source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java @ 41

Last change on this file since 41 was 41, checked in by sherbold, 9 years ago
  • formatted code and added copyrights
  • Property svn:mime-type set to text/plain
File size: 5.0 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.collections4.list.SetUniqueList;
import org.apache.commons.io.output.NullOutputStream;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
31
32/**
33 * Programmatic WekaBaggingTraining
34 *
35 * first parameter is Trainer Name. second parameter is class name
36 *
37 * all subsequent parameters are configuration params (for example for trees) Cross Validation
38 * params always come last and are prepended with -CVPARAM
39 *
40 * XML Configurations for Weka Classifiers:
41 *
42 * <pre>
43 * {@code
44 * <!-- examples -->
45 * <setwisetrainer name="WekaBaggingTraining" param="NaiveBayesBagging weka.classifiers.bayes.NaiveBayes" />
46 * <setwisetrainer name="WekaBaggingTraining" param="LogisticBagging weka.classifiers.functions.Logistic -R 1.0E-8 -M -1" />
47 * }
48 * </pre>
49 *
50 */
51public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy {
52
53    private final TraindatasetBagging classifier = new TraindatasetBagging();
54
55    @Override
56    public Classifier getClassifier() {
57        return classifier;
58    }
59
60    @Override
61    public void apply(SetUniqueList<Instances> traindataSet) {
62        PrintStream errStr = System.err;
63        System.setErr(new PrintStream(new NullOutputStream()));
64        try {
65            classifier.buildClassifier(traindataSet);
66        }
67        catch (Exception e) {
68            throw new RuntimeException(e);
69        }
70        finally {
71            System.setErr(errStr);
72        }
73    }
74
75    public class TraindatasetBagging extends AbstractClassifier {
76
77        private static final long serialVersionUID = 1L;
78
79        private List<Instances> trainingData = null;
80
81        private List<Classifier> classifiers = null;
82
83        @Override
84        public double classifyInstance(Instance instance) {
85            if (classifiers == null) {
86                return 0.0;
87            }
88
89            double classification = 0.0;
90            for (int i = 0; i < classifiers.size(); i++) {
91                Classifier classifier = classifiers.get(i);
92                Instances traindata = trainingData.get(i);
93
94                Set<String> attributeNames = new HashSet<>();
95                for (int j = 0; j < traindata.numAttributes(); j++) {
96                    attributeNames.add(traindata.attribute(j).name());
97                }
98
99                double[] values = new double[traindata.numAttributes()];
100                int index = 0;
101                for (int j = 0; j < instance.numAttributes(); j++) {
102                    if (attributeNames.contains(instance.attribute(j).name())) {
103                        values[index] = instance.value(j);
104                        index++;
105                    }
106                }
107
108                Instances tmp = new Instances(traindata);
109                tmp.clear();
110                Instance instCopy = new DenseInstance(instance.weight(), values);
111                instCopy.setDataset(tmp);
112                try {
113                    classification += classifier.classifyInstance(instCopy);
114                }
115                catch (Exception e) {
116                    throw new RuntimeException("bagging classifier could not classify an instance",
117                                               e);
118                }
119            }
120            classification /= classifiers.size();
121            return (classification >= 0.5) ? 1.0 : 0.0;
122        }
123
124        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
125            classifiers = new LinkedList<>();
126            trainingData = new LinkedList<>();
127            for (Instances traindata : traindataSet) {
128                Classifier classifier = setupClassifier();
129                classifier.buildClassifier(traindata);
130                classifiers.add(classifier);
131                trainingData.add(new Instances(traindata));
132            }
133        }
134
135        @Override
136        public void buildClassifier(Instances traindata) throws Exception {
137            classifiers = new LinkedList<>();
138            trainingData = new LinkedList<>();
139            final Classifier classifier = setupClassifier();
140            classifier.buildClassifier(traindata);
141            classifiers.add(classifier);
142            trainingData.add(new Instances(traindata));
143        }
144    }
145}
Note: See TracBrowser for help on using the repository browser.