source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java @ 140

Last change on this file since 140 was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 6.0 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import de.ugoe.cs.util.console.Console;
21
22import weka.core.OptionHandler;
23import weka.classifiers.Classifier;
24import weka.classifiers.bayes.BayesNet;
25import weka.classifiers.meta.CVParameterSelection;
26import weka.classifiers.meta.Vote;
27
28/**
29 * <p>
30 * Allows specification of the Weka classifier and its params in the XML experiment configuration.
31 * </p>
32 * <p>
33 * Important conventions of the XML format: Cross Validation params always come last and are
34 * prepended with -CVPARAM.<br>
35 * Example:
36 *
37 * <pre>
38 * {@code
39 * <trainer name="WekaTraining" param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/>
40 * }
41 * </pre>
42 *
43 * @author Alexander Trautsch
44 */
45public abstract class WekaBaseTraining implements IWekaCompatibleTrainer {
46
47    /**
48     * reference to the Weka classifier
49     */
50    protected Classifier classifier = null;
51
52    /**
53     * qualified class name of the weka classifier
54     */
55    protected String classifierClassName;
56
57    /**
58     * name of the classifier
59     */
60    protected String classifierName;
61
62    /**
63     * parameters of the training
64     */
65    protected String[] classifierParams;
66
67    /*
68     * (non-Javadoc)
69     *
70     * @see de.ugoe.cs.cpdp.IParameterizable#setParameter(java.lang.String)
71     */
72    @Override
73    public void setParameter(String parameters) {
74        String[] params = parameters.split(" ");
75
76        // first part of the params is the classifierName (e.g. SMORBF)
77        classifierName = params[0];
78
79        // the following parameters can be copied from weka!
80
81        // second param is classifierClassName (e.g. weka.classifiers.functions.SMO)
82        classifierClassName = params[1];
83
84        // rest are params to the specified classifier (e.g. -K
85        // weka.classifiers.functions.supportVector.RBFKernel)
86        classifierParams = Arrays.copyOfRange(params, 2, params.length);
87
88        // classifier = setupClassifier();
89    }
90
91    /*
92     * (non-Javadoc)
93     *
94     * @see de.ugoe.cs.cpdp.training.IWekaCompatibleTrainer#getClassifier()
95     */
96    @Override
97    public Classifier getClassifier() {
98        return classifier;
99    }
100
101    /**
102     * <p>
103     * helper function that sets up the Weka classifier including its parameters
104     * </p>
105     *
106     * @return
107     */
108    protected Classifier setupClassifier() {
109        Classifier cl = null;
110        try {
111            @SuppressWarnings("rawtypes")
112            Class c = Class.forName(classifierClassName);
113            Classifier obj = (Classifier) c.newInstance();
114
115            // Filter out -CVPARAM, these are special because they do not belong to the Weka
116            // classifier class as parameters
117            String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
118            String[] cvparam = { };
119            boolean cv = false;
120            for (int i = 0; i < classifierParams.length; i++) {
121                if (classifierParams[i].equals("-CVPARAM")) {
122                    // rest of array are cvparam
123                    cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length);
124
125                    // before this we have normal params
126                    param = Arrays.copyOfRange(classifierParams, 0, i);
127
128                    cv = true;
129                    break;
130                }
131            }
132
133            // set classifier params
134            ((OptionHandler) obj).setOptions(param);
135            cl = obj;
136
137            if (cl instanceof Vote) {
138                Vote votingClassifier = (Vote) cl;
139                for (Classifier classifier : votingClassifier.getClassifiers()) {
140                    if (classifier instanceof BayesNet) {
141                        ((BayesNet) classifier).setUseADTree(false);
142                    }
143                }
144            }
145            // we have cross val params
146            // cant check on cvparam.length here, it may not be initialized
147            if (cv) {
148                final CVParameterSelection ps = new CVParameterSelection();
149                ps.setClassifier(obj);
150                ps.setNumFolds(5);
151                // ps.addCVParameter("I 5 25 5");
152                for (int i = 1; i < cvparam.length / 4; i++) {
153                    ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i))
154                        .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
155                }
156
157                cl = ps;
158            }
159
160        }
161        catch (ClassNotFoundException e) {
162            Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
163            e.printStackTrace();
164        }
165        catch (InstantiationException e) {
166            Console.traceln(Level.WARNING,
167                            String.format("Instantiation Exception: %s", e.toString()));
168            e.printStackTrace();
169        }
170        catch (IllegalAccessException e) {
171            Console.traceln(Level.WARNING,
172                            String.format("Illegal Access Exception: %s", e.toString()));
173            e.printStackTrace();
174        }
175        catch (Exception e) {
176            Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
177            e.printStackTrace();
178        }
179
180        return cl;
181    }
182
183    /*
184     * (non-Javadoc)
185     *
186     * @see de.ugoe.cs.cpdp.training.IWekaCompatibleTrainer#getName()
187     */
188    @Override
189    public String getName() {
190        return classifierName;
191    }
192
193}
Note: See TracBrowser for help on using the repository browser.