source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining2.java @ 9

Last change on this file since 9 was 7, checked in by sherbold, 10 years ago
  • changed WekaBaseTraining2 to allow multiple CV parameters
  • Property svn:mime-type set to text/plain
File size: 2.9 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.util.Arrays;
4import java.util.logging.Level;
5
6import de.ugoe.cs.util.console.Console;
7import weka.core.OptionHandler;
8import weka.classifiers.Classifier;
9import weka.classifiers.meta.CVParameterSelection;
10
11public abstract class WekaBaseTraining2 implements WekaCompatibleTrainer {
12       
13        protected Classifier classifier = null;
14        protected String classifierClassName;
15        protected String classifierName;
16        protected String[] classifierParams;
17       
18        @Override
19        public void setParameter(String parameters) {
20                String[] params = parameters.split(" ");
21
22                // first is classifierName
23                classifierName = params[0];
24               
25                // all following parameters can be copied from weka!
26               
27                // second param is classifierClassName
28                classifierClassName = params[1];
29       
30                // rest are params to the specified classifier
31                classifierParams = Arrays.copyOfRange(params, 2, params.length);
32               
33                classifier = setupClassifier();
34        }
35
36        @Override
37        public Classifier getClassifier() {
38                return classifier;
39        }
40
41        public Classifier setupClassifier() {
42                Classifier cl = null;
43                try{
44                        @SuppressWarnings("rawtypes")
45                        Class c = Class.forName(classifierClassName);
46                        Classifier obj = (Classifier) c.newInstance();
47                       
48                        // Filter -CVPARAM
49                        String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
50                        String[] cvparam = {};
51                        boolean cv = false;
52                        for ( int i=0; i < classifierParams.length; i++ ) {
53                                if(classifierParams[i].equals("-CVPARAM")) {
54                                        // rest of array are cvparam
55                                        cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length);
56                                       
57                                        // before this we have normal params
58                                        param = Arrays.copyOfRange(classifierParams, 0, i);
59                                       
60                                        cv = true;
61                                        break;
62                                }
63                        }
64                       
65                        // set classifier params
66                        ((OptionHandler)obj).setOptions(param);
67                        cl = obj;
68                       
69                        // we have cross val params
70                        // cant check on cvparam.length may not be initialized                 
71                        if(cv) {
72                                final CVParameterSelection ps = new CVParameterSelection();
73                                ps.setClassifier(obj);
74                                ps.setNumFolds(5);
75                                //ps.addCVParameter("I 5 25 5");
76                                for( int i=1 ; i<cvparam.length/4 ; i++ ) {
77                                        ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4*i)).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
78                                }
79                               
80                                cl = ps;
81                        }
82
83                }catch(ClassNotFoundException e) {
84                        Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
85                        e.printStackTrace();
86                } catch (InstantiationException e) {
87                        Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString()));
88                        e.printStackTrace();
89                } catch (IllegalAccessException e) {
90                        Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString()));
91                        e.printStackTrace();
92                } catch (Exception e) {
93                        Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
94                        e.printStackTrace();
95                }
96               
97                return cl;
98        }
99
100        @Override
101        public String getName() {
102                return classifierName;
103        }
104       
105}
Note: See TracBrowser for help on using the repository browser.