source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java @ 37

Last change on this file since 37 was 25, checked in by atrautsch, 10 years ago

comment fixes

  • Property svn:mime-type set to text/plain
File size: 3.7 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashSet;
5import java.util.LinkedList;
6import java.util.List;
7import java.util.Set;
8
9import org.apache.commons.collections4.list.SetUniqueList;
10import org.apache.commons.io.output.NullOutputStream;
11
12import weka.classifiers.AbstractClassifier;
13import weka.classifiers.Classifier;
14import weka.core.DenseInstance;
15import weka.core.Instance;
16import weka.core.Instances;
17
18/**
19 * Programmatic WekaBaggingTraining
20 *
21 * first parameter is Trainer Name.
22 * second parameter is class name
23 *
24 * all subsequent parameters are configuration params (for example for trees)
25 * Cross Validation params always come last and are prepended with -CVPARAM
26 *
27 * XML Configurations for Weka Classifiers:
28 * <pre>
29 * {@code
30 * <!-- examples -->
31 * <setwisetrainer name="WekaBaggingTraining" param="NaiveBayesBagging weka.classifiers.bayes.NaiveBayes" />
32 * <setwisetrainer name="WekaBaggingTraining" param="LogisticBagging weka.classifiers.functions.Logistic -R 1.0E-8 -M -1" />
33 * }
34 * </pre>
35 *
36 */
37public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy {
38
39        private final TraindatasetBagging classifier = new TraindatasetBagging();
40       
41        @Override
42        public Classifier getClassifier() {
43                return classifier;
44        }
45       
46        @Override
47        public void apply(SetUniqueList<Instances> traindataSet) {
48                PrintStream errStr      = System.err;
49                System.setErr(new PrintStream(new NullOutputStream()));
50                try {
51                        classifier.buildClassifier(traindataSet);
52                } catch (Exception e) {
53                        throw new RuntimeException(e);
54                } finally {
55                        System.setErr(errStr);
56                }
57        }
58       
59        public class TraindatasetBagging extends AbstractClassifier {
60               
61                private static final long serialVersionUID = 1L;
62
63                private List<Instances> trainingData = null;
64               
65                private List<Classifier> classifiers = null;
66       
67                @Override
68                public double classifyInstance(Instance instance) {
69                        if( classifiers==null ) {
70                                return 0.0;
71                        }
72                       
73                        double classification = 0.0;
74                        for( int i=0 ; i<classifiers.size(); i++ ) {
75                                Classifier classifier = classifiers.get(i);
76                                Instances traindata = trainingData.get(i);
77                               
78                                Set<String> attributeNames = new HashSet<>();
79                                for( int j=0; j<traindata.numAttributes(); j++ ) {
80                                        attributeNames.add(traindata.attribute(j).name());
81                                }
82                               
83                                double[] values = new double[traindata.numAttributes()];
84                                int index = 0;
85                                for( int j=0; j<instance.numAttributes(); j++ ) {
86                                        if( attributeNames.contains(instance.attribute(j).name())) {
87                                                values[index] = instance.value(j);
88                                                index++;
89                                        }
90                                }
91                               
92                                Instances tmp = new Instances(traindata);
93                                tmp.clear();
94                                Instance instCopy = new DenseInstance(instance.weight(), values);
95                                instCopy.setDataset(tmp);
96                                try {
97                                        classification += classifier.classifyInstance(instCopy);
98                                } catch (Exception e) {
99                                        throw new RuntimeException("bagging classifier could not classify an instance", e);
100                                }
101                        }
102                        classification /= classifiers.size();
103                        return (classification>=0.5) ? 1.0 : 0.0;
104                }
105               
106                public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
107                        classifiers = new LinkedList<>();
108                        trainingData = new LinkedList<>();
109                        for( Instances traindata : traindataSet ) {
110                                Classifier classifier = setupClassifier();
111                                classifier.buildClassifier(traindata);
112                                classifiers.add(classifier);
113                                trainingData.add(new Instances(traindata));
114                        }
115                }
116       
117                @Override
118                public void buildClassifier(Instances traindata) throws Exception {
119                        classifiers = new LinkedList<>();
120                        trainingData = new LinkedList<>();
121                        final Classifier classifier = setupClassifier();
122                        classifier.buildClassifier(traindata);
123                        classifiers.add(classifier);
124                        trainingData.add(new Instances(traindata));
125                }
126        }
127}
Note: See TracBrowser for help on using the repository browser.