source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java @ 88

Last change on this file since 88 was 86, checked in by sherbold, 9 years ago
  • switched workspace encoding to UTF-8 and fixed broken characters
  • Property svn:mime-type set to text/plain
File size: 4.8 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import de.ugoe.cs.util.console.Console;
21
22import weka.core.OptionHandler;
23import weka.classifiers.Classifier;
24import weka.classifiers.meta.CVParameterSelection;
25
26/**
27 * WekaBaseTraining2
28 *
29 * Allows specification of the Weka classifier and its params in the XML experiment configuration.
30 *
31 * Important conventions of the XML format: Cross Validation params always come last and are
32 * prepended with -CVPARAM Example: <trainer name="WekaTraining"
33 * param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/>
34 */
35public abstract class WekaBaseTraining implements IWekaCompatibleTrainer {
36
37    protected Classifier classifier = null;
38    protected String classifierClassName;
39    protected String classifierName;
40    protected String[] classifierParams;
41
42    @Override
43    public void setParameter(String parameters) {
44        String[] params = parameters.split(" ");
45
46        // first part of the params is the classifierName (e.g. SMORBF)
47        classifierName = params[0];
48
49        // the following parameters can be copied from weka!
50
51        // second param is classifierClassName (e.g. weka.classifiers.functions.SMO)
52        classifierClassName = params[1];
53
54        // rest are params to the specified classifier (e.g. -K
55        // weka.classifiers.functions.supportVector.RBFKernel)
56        classifierParams = Arrays.copyOfRange(params, 2, params.length);
57
58        //classifier = setupClassifier();
59    }
60
61    @Override
62    public Classifier getClassifier() {
63        return classifier;
64    }
65
66    protected Classifier setupClassifier() {
67        Classifier cl = null;
68        try {
69            @SuppressWarnings("rawtypes")
70            Class c = Class.forName(classifierClassName);
71            Classifier obj = (Classifier) c.newInstance();
72
73            // Filter out -CVPARAM, these are special because they do not belong to the Weka
74            // classifier class as parameters
75            String[] param = Arrays.copyOf(classifierParams, classifierParams.length);
76            String[] cvparam = { };
77            boolean cv = false;
78            for (int i = 0; i < classifierParams.length; i++) {
79                if (classifierParams[i].equals("-CVPARAM")) {
80                    // rest of array are cvparam
81                    cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length);
82
83                    // before this we have normal params
84                    param = Arrays.copyOfRange(classifierParams, 0, i);
85
86                    cv = true;
87                    break;
88                }
89            }
90
91            // set classifier params
92            ((OptionHandler) obj).setOptions(param);
93            cl = obj;
94
95            // we have cross val params
96            // cant check on cvparam.length here, it may not be initialized
97            if (cv) {
98                final CVParameterSelection ps = new CVParameterSelection();
99                ps.setClassifier(obj);
100                ps.setNumFolds(5);
101                // ps.addCVParameter("I 5 25 5");
102                for (int i = 1; i < cvparam.length / 4; i++) {
103                    ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i))
104                        .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", ""));
105                }
106
107                cl = ps;
108            }
109
110        }
111        catch (ClassNotFoundException e) {
112            Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString()));
113            e.printStackTrace();
114        }
115        catch (InstantiationException e) {
116            Console.traceln(Level.WARNING,
117                            String.format("Instantiation Exception: %s", e.toString()));
118            e.printStackTrace();
119        }
120        catch (IllegalAccessException e) {
121            Console.traceln(Level.WARNING,
122                            String.format("Illegal Access Exception: %s", e.toString()));
123            e.printStackTrace();
124        }
125        catch (Exception e) {
126            Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString()));
127            e.printStackTrace();
128        }
129
130        return cl;
131    }
132
133    @Override
134    public String getName() {
135        return classifierName;
136    }
137
138}
Note: See TracBrowser for help on using the repository browser.