source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/BaggingTraining.java @ 9

Last change on this file since 9 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 3.5 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashSet;
5import java.util.LinkedList;
6import java.util.List;
7import java.util.Set;
8
9import org.apache.commons.collections4.list.SetUniqueList;
10import org.apache.commons.io.output.NullOutputStream;
11
12import weka.classifiers.AbstractClassifier;
13import weka.classifiers.Classifier;
14import weka.core.DenseInstance;
15import weka.core.Instance;
16import weka.core.Instances;
17
18public abstract class BaggingTraining implements ISetWiseTrainingStrategy, WekaCompatibleTrainer {
19
20        protected abstract Classifier setupClassifier();
21       
22        private final TraindatasetBagging classifier = new TraindatasetBagging();
23       
24        public void apply(SetUniqueList<Instances> traindataSet) {
25                PrintStream errStr      = System.err;
26                System.setErr(new PrintStream(new NullOutputStream()));
27                try {
28                        classifier.buildClassifier(traindataSet);
29                } catch (Exception e) {
30                        throw new RuntimeException(e);
31                } finally {
32                        System.setErr(errStr);
33                }
34        }
35       
36        @Override
37        public Classifier getClassifier() {
38                return classifier;
39        }
40       
41        @Override
42        public void setParameter(String parameters) {
43                // TODO should allow passing of weka parameters to the classifier
44        }
45       
46        public class TraindatasetBagging extends AbstractClassifier {
47               
48                /**
49                 *
50                 */
51                private static final long serialVersionUID = 1L;
52
53                private List<Instances> trainingData = null;
54               
55                private List<Classifier> classifiers = null;
56       
57                @Override
58                public double classifyInstance(Instance instance) {
59                        if( classifiers==null ) {
60                                return 0.0; // TODO check how WEKA expects classifyInstance to behave if no classifier exists yet
61                        }
62                       
63                        double classification = 0.0;
64                        for( int i=0 ; i<classifiers.size(); i++ ) {
65                                Classifier classifier = classifiers.get(i);
66                                Instances traindata = trainingData.get(i);
67                               
68                                Set<String> attributeNames = new HashSet<>();
69                                for( int j=0; j<traindata.numAttributes(); j++ ) {
70                                        attributeNames.add(traindata.attribute(j).name());
71                                }
72                               
73                                double[] values = new double[traindata.numAttributes()];
74                                int index = 0;
75                                for( int j=0; j<instance.numAttributes(); j++ ) {
76                                        if( attributeNames.contains(instance.attribute(j).name())) {
77                                                values[index] = instance.value(j);
78                                                index++;
79                                        }
80                                }
81                               
82                                Instances tmp = new Instances(traindata);
83                                tmp.clear();
84                                Instance instCopy = new DenseInstance(instance.weight(), values);
85                                instCopy.setDataset(tmp);
86                                try {
87                                        classification += classifier.classifyInstance(instCopy);
88                                } catch (Exception e) {
89                                        throw new RuntimeException("bagging classifier could not classify an instance", e);
90                                }
91                        }
92                        classification /= classifiers.size();
93                        return (classification>=0.5) ? 1.0 : 0.0;
94                }
95               
96                public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
97                        classifiers = new LinkedList<>();
98                        trainingData = new LinkedList<>();
99                        for( Instances traindata : traindataSet ) {
100                                Classifier classifier = setupClassifier();
101                                classifier.buildClassifier(traindata);
102                                classifiers.add(classifier);
103                                trainingData.add(new Instances(traindata));
104                        }
105                }
106       
107                @Override
108                public void buildClassifier(Instances traindata) throws Exception {
109                        classifiers = new LinkedList<>();
110                        trainingData = new LinkedList<>();
111                        final Classifier classifier = setupClassifier();
112                        classifier.buildClassifier(traindata);
113                        classifiers.add(classifier);
114                        trainingData.add(new Instances(traindata));
115                }
116        }
117
118}
Note: See TracBrowser for help on using the repository browser.