source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining2.java @ 2

Last change on this file since 2 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 2.8 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.util.Arrays;
4import java.util.logging.Level;
5
6import de.ugoe.cs.util.console.Console;
7import weka.core.OptionHandler;
8import weka.classifiers.Classifier;
9import weka.classifiers.meta.CVParameterSelection;
10
11public abstract class WekaBaseTraining2 implements WekaCompatibleTrainer {
12       
13        protected Classifier classifier = null;
14        protected String classifierClassName;
15        protected String classifierName;
16        protected String[] classifierParams;
17       
18        @Override
19        public void setParameter(String parameters) {
20                String[] params = parameters.split(" ");
21
22                // first is classifierName
23                classifierName = params[0];
24               
25                // all following parameters can be copied from weka!
26               
27                // second param is classifierClassName
28                classifierClassName = params[1];
29       
30                // rest are params to the specified classifier
31                classifierParams = Arrays.copyOfRange(params, 2, params.length);
32               
33                classifier = setupClassifier();
34        }
35
36        @Override
37        public Classifier getClassifier() {
38                return classifier;
39        }
40
41        public Classifier setupClassifier() {
42                Classifier cl = null;
43                try{
44                        @SuppressWarnings("rawtypes")
45                        Class c = Class.forName(classifierClassName);
46                        Classifier obj = (Classifier) c.newInstance();
47                       
48                        // Filter -CVPARAM
49                        String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
50                        String[] cvparam = {};
51                        boolean cv = false;
52                        for ( int i=0; i < classifierParams.length; i++ ) {
53                                if(classifierParams[i].equals("-CVPARAM")) {
54                                        // rest of array are cvparam
55                                        cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length);
56                                       
57                                        // before this we have normal params
58                                        param = Arrays.copyOfRange(classifierParams, 0, i);
59                                       
60                                        cv = true;
61                                        break;
62                                }
63                        }
64                       
65                        // set classifier params
66                        ((OptionHandler)obj).setOptions(param);
67                        cl = obj;
68                       
69                        // we have cross val params
70                        // cant check on cvparam.length may not be initialized                 
71                        if(cv) {
72                                final CVParameterSelection ps = new CVParameterSelection();
73                                ps.setClassifier(obj);
74                                ps.setNumFolds(5);
75                                //ps.addCVParameter("I 5 25 5");
76                                ps.addCVParameter(Arrays.asList(cvparam).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
77                               
78                                cl = ps;
79                        }
80
81                }catch(ClassNotFoundException e) {
82                        Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
83                        e.printStackTrace();
84                } catch (InstantiationException e) {
85                        Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString()));
86                        e.printStackTrace();
87                } catch (IllegalAccessException e) {
88                        Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString()));
89                        e.printStackTrace();
90                } catch (Exception e) {
91                        Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
92                        e.printStackTrace();
93                }
94               
95                return cl;
96        }
97
98        @Override
99        public String getName() {
100                return classifierName;
101        }
102       
103}
Note: See TracBrowser for help on using the repository browser.