source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining2.java @ 10

Last change on this file since 10 was 10, checked in by sherbold, 10 years ago
  • added some source code comments
  • Property svn:mime-type set to text/plain
File size: 3.6 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashSet;
5import java.util.LinkedList;
6import java.util.List;
7import java.util.Set;
8
9import org.apache.commons.collections4.list.SetUniqueList;
10import org.apache.commons.io.output.NullOutputStream;
11
12import weka.classifiers.AbstractClassifier;
13import weka.classifiers.Classifier;
14import weka.core.DenseInstance;
15import weka.core.Instance;
16import weka.core.Instances;
17
18/**
19 * Programmatic WekaBaggingTraining
20 *
21 * first parameter is Trainer Name.
22 * second parameter is class name
23 *
24 * all subsequent parameters are configuration params (for example for trees)
25 *
26 * XML Configurations for Weka Classifiers:
27 * <pre>
28 * {@code
29 * <!-- examples -->
30 * <setwisetrainer name="WekaBaggingTraining2" param="NaiveBayesBagging weka.classifiers.bayes.NaiveBayes" />
31 * <setwisetrainer name="WekaBaggingTraining2" param="LogisticBagging weka.classifiers.functions.Logistic -R 1.0E-8 -M -1" />
32 * }
33 * </pre>
34 *
35 */
36public class WekaBaggingTraining2 extends WekaBaseTraining2 implements ISetWiseTrainingStrategy {
37
38        private final TraindatasetBagging classifier = new TraindatasetBagging();
39       
40        @Override
41        public Classifier getClassifier() {
42                return classifier;
43        }
44       
45        @Override
46        public void apply(SetUniqueList<Instances> traindataSet) {
47                PrintStream errStr      = System.err;
48                System.setErr(new PrintStream(new NullOutputStream()));
49                try {
50                        classifier.buildClassifier(traindataSet);
51                } catch (Exception e) {
52                        throw new RuntimeException(e);
53                } finally {
54                        System.setErr(errStr);
55                }
56        }
57       
58        public class TraindatasetBagging extends AbstractClassifier {
59               
60                private static final long serialVersionUID = 1L;
61
62                private List<Instances> trainingData = null;
63               
64                private List<Classifier> classifiers = null;
65       
66                @Override
67                public double classifyInstance(Instance instance) {
68                        if( classifiers==null ) {
69                                return 0.0;
70                        }
71                       
72                        double classification = 0.0;
73                        for( int i=0 ; i<classifiers.size(); i++ ) {
74                                Classifier classifier = classifiers.get(i);
75                                Instances traindata = trainingData.get(i);
76                               
77                                Set<String> attributeNames = new HashSet<>();
78                                for( int j=0; j<traindata.numAttributes(); j++ ) {
79                                        attributeNames.add(traindata.attribute(j).name());
80                                }
81                               
82                                double[] values = new double[traindata.numAttributes()];
83                                int index = 0;
84                                for( int j=0; j<instance.numAttributes(); j++ ) {
85                                        if( attributeNames.contains(instance.attribute(j).name())) {
86                                                values[index] = instance.value(j);
87                                                index++;
88                                        }
89                                }
90                               
91                                Instances tmp = new Instances(traindata);
92                                tmp.clear();
93                                Instance instCopy = new DenseInstance(instance.weight(), values);
94                                instCopy.setDataset(tmp);
95                                try {
96                                        classification += classifier.classifyInstance(instCopy);
97                                } catch (Exception e) {
98                                        throw new RuntimeException("bagging classifier could not classify an instance", e);
99                                }
100                        }
101                        classification /= classifiers.size();
102                        return (classification>=0.5) ? 1.0 : 0.0;
103                }
104               
105                public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
106                        classifiers = new LinkedList<>();
107                        trainingData = new LinkedList<>();
108                        for( Instances traindata : traindataSet ) {
109                                Classifier classifier = setupClassifier();
110                                classifier.buildClassifier(traindata);
111                                classifiers.add(classifier);
112                                trainingData.add(new Instances(traindata));
113                        }
114                }
115       
116                @Override
117                public void buildClassifier(Instances traindata) throws Exception {
118                        classifiers = new LinkedList<>();
119                        trainingData = new LinkedList<>();
120                        final Classifier classifier = setupClassifier();
121                        classifier.buildClassifier(traindata);
122                        classifiers.add(classifier);
123                        trainingData.add(new Instances(traindata));
124                }
125        }
126}
Note: See TracBrowser for help on using the repository browser.