source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java @ 25

Last change on this file since 25 was 25, checked in by atrautsch, 10 years ago

comment fixes

  • Property svn:mime-type set to text/plain
File size: 3.5 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.util.Arrays;
4import java.util.logging.Level;
5
6import de.ugoe.cs.util.console.Console;
7
8import weka.core.OptionHandler;
9import weka.classifiers.Classifier;
10import weka.classifiers.meta.CVParameterSelection;
11
12/**
13 * WekaBaseTraining2
14 *
15 * Allows specification of the Weka classifier and its params in the XML experiment configuration.
16 *
17 * Important conventions of the XML format:
18 * Cross Validation params always come last and are prepended with -CVPARAM
19 * Example: <trainer name="WekaTraining" param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/>
20 */
21public abstract class WekaBaseTraining implements IWekaCompatibleTrainer {
22       
23        protected Classifier classifier = null;
24        protected String classifierClassName;
25        protected String classifierName;
26        protected String[] classifierParams;
27       
28        @Override
29        public void setParameter(String parameters) {
30                String[] params = parameters.split(" ");
31
32                // first part of the params is the classifierName (e.g. SMORBF)
33                classifierName = params[0];
34               
35                // the following parameters can be copied from weka!
36               
37                // second param is classifierClassName (e.g. weka.classifiers.functions.SMO)
38                classifierClassName = params[1];
39       
40                // rest are params to the specified classifier (e.g. -K weka.classifiers.functions.supportVector.RBFKernel)
41                classifierParams = Arrays.copyOfRange(params, 2, params.length);
42               
43                classifier = setupClassifier();
44        }
45
46        @Override
47        public Classifier getClassifier() {
48                return classifier;
49        }
50
51        public Classifier setupClassifier() {
52                Classifier cl = null;
53                try{
54                        @SuppressWarnings("rawtypes")
55                        Class c = Class.forName(classifierClassName);
56                        Classifier obj = (Classifier) c.newInstance();
57                       
58                        // Filter out -CVPARAM, these are special because they do not belong to the Weka classifier class as parameters
59                        String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
60                        String[] cvparam = {};
61                        boolean cv = false;
62                        for ( int i=0; i < classifierParams.length; i++ ) {
63                                if(classifierParams[i].equals("-CVPARAM")) {
64                                        // rest of array are cvparam
65                                        cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length);
66                                       
67                                        // before this we have normal params
68                                        param = Arrays.copyOfRange(classifierParams, 0, i);
69                                       
70                                        cv = true;
71                                        break;
72                                }
73                        }
74                       
75                        // set classifier params
76                        ((OptionHandler)obj).setOptions(param);
77                        cl = obj;
78                       
79                        // we have cross val params
80                        // cant check on cvparam.length here, it may not be initialized                 
81                        if(cv) {
82                                final CVParameterSelection ps = new CVParameterSelection();
83                                ps.setClassifier(obj);
84                                ps.setNumFolds(5);
85                                //ps.addCVParameter("I 5 25 5");
86                                for( int i=1 ; i<cvparam.length/4 ; i++ ) {
87                                        ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4*i)).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
88                                }
89                               
90                                cl = ps;
91                        }
92
93                }catch(ClassNotFoundException e) {
94                        Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
95                        e.printStackTrace();
96                } catch (InstantiationException e) {
97                        Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString()));
98                        e.printStackTrace();
99                } catch (IllegalAccessException e) {
100                        Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString()));
101                        e.printStackTrace();
102                } catch (Exception e) {
103                        Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
104                        e.printStackTrace();
105                }
106               
107                return cl;
108        }
109
110        @Override
111        public String getName() {
112                return classifierName;
113        }
114       
115}
Note: See TracBrowser for help on using the repository browser.