source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java @ 132

Last change on this file since 132 was 131, checked in by sherbold, 9 years ago
  • added workaround to WekaBaseTraining to allow setting that no ADTree is used for BayesNet when using a Vote classifier (required due to a bug in WEKA's option parser)
  • Property svn:mime-type set to text/plain
File size: 5.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import de.ugoe.cs.util.console.Console;
21
22import weka.core.OptionHandler;
23import weka.classifiers.Classifier;
24import weka.classifiers.bayes.BayesNet;
25import weka.classifiers.meta.CVParameterSelection;
26import weka.classifiers.meta.Vote;
27
28/**
29 * WekaBaseTraining2
30 *
31 * Allows specification of the Weka classifier and its params in the XML experiment configuration.
32 *
33 * Important conventions of the XML format: Cross Validation params always come last and are
34 * prepended with -CVPARAM Example: <trainer name="WekaTraining"
35 * param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/>
36 */
37public abstract class WekaBaseTraining implements IWekaCompatibleTrainer {
38
39    protected Classifier classifier = null;
40    protected String classifierClassName;
41    protected String classifierName;
42    protected String[] classifierParams;
43
44    @Override
45    public void setParameter(String parameters) {
46        String[] params = parameters.split(" ");
47
48        // first part of the params is the classifierName (e.g. SMORBF)
49        classifierName = params[0];
50
51        // the following parameters can be copied from weka!
52
53        // second param is classifierClassName (e.g. weka.classifiers.functions.SMO)
54        classifierClassName = params[1];
55
56        // rest are params to the specified classifier (e.g. -K
57        // weka.classifiers.functions.supportVector.RBFKernel)
58        classifierParams = Arrays.copyOfRange(params, 2, params.length);
59
60        //classifier = setupClassifier();
61    }
62
63    @Override
64    public Classifier getClassifier() {
65        return classifier;
66    }
67
68    protected Classifier setupClassifier() {
69        Classifier cl = null;
70        try {
71            @SuppressWarnings("rawtypes")
72            Class c = Class.forName(classifierClassName);
73            Classifier obj = (Classifier) c.newInstance();
74
75            // Filter out -CVPARAM, these are special because they do not belong to the Weka
76            // classifier class as parameters
77            String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
78            String[] cvparam = { };
79            boolean cv = false;
80            for (int i = 0; i < classifierParams.length; i++) {
81                if (classifierParams[i].equals("-CVPARAM")) {
82                    // rest of array are cvparam
83                    cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length);
84
85                    // before this we have normal params
86                    param = Arrays.copyOfRange(classifierParams, 0, i);
87
88                    cv = true;
89                    break;
90                }
91            }
92
93            // set classifier params
94            ((OptionHandler) obj).setOptions(param);
95            cl = obj;
96
97            if( cl instanceof Vote ) {
98                Vote votingClassifier = (Vote) cl;
99                for( Classifier classifier : votingClassifier.getClassifiers() ) {
100                    if( classifier instanceof BayesNet ) {
101                        ((BayesNet) classifier).setUseADTree(false);
102                    }
103                }
104            }
105            // we have cross val params
106            // cant check on cvparam.length here, it may not be initialized
107            if (cv) {
108                final CVParameterSelection ps = new CVParameterSelection();
109                ps.setClassifier(obj);
110                ps.setNumFolds(5);
111                // ps.addCVParameter("I 5 25 5");
112                for (int i = 1; i < cvparam.length / 4; i++) {
113                    ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i))
114                        .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
115                }
116
117                cl = ps;
118            }
119
120        }
121        catch (ClassNotFoundException e) {
122            Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
123            e.printStackTrace();
124        }
125        catch (InstantiationException e) {
126            Console.traceln(Level.WARNING,
127                            String.format("Instantiation Exception: %s", e.toString()));
128            e.printStackTrace();
129        }
130        catch (IllegalAccessException e) {
131            Console.traceln(Level.WARNING,
132                            String.format("Illegal Access Exception: %s", e.toString()));
133            e.printStackTrace();
134        }
135        catch (Exception e) {
136            Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
137            e.printStackTrace();
138        }
139
140        return cl;
141    }
142
143    @Override
144    public String getName() {
145        return classifierName;
146    }
147
148}
Note: See TracBrowser for help on using the repository browser.