package de.ugoe.cs.cpdp.training;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.collections4.list.SetUniqueList;
import org.apache.commons.io.output.NullOutputStream;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

18 | public abstract class BaggingTraining implements ISetWiseTrainingStrategy, WekaCompatibleTrainer {
|
---|
19 |
|
---|
20 | protected abstract Classifier setupClassifier();
|
---|
21 |
|
---|
22 | private final TraindatasetBagging classifier = new TraindatasetBagging();
|
---|
23 |
|
---|
24 | public void apply(SetUniqueList<Instances> traindataSet) {
|
---|
25 | PrintStream errStr = System.err;
|
---|
26 | System.setErr(new PrintStream(new NullOutputStream()));
|
---|
27 | try {
|
---|
28 | classifier.buildClassifier(traindataSet);
|
---|
29 | } catch (Exception e) {
|
---|
30 | throw new RuntimeException(e);
|
---|
31 | } finally {
|
---|
32 | System.setErr(errStr);
|
---|
33 | }
|
---|
34 | }
|
---|
35 |
|
---|
36 | @Override
|
---|
37 | public Classifier getClassifier() {
|
---|
38 | return classifier;
|
---|
39 | }
|
---|
40 |
|
---|
41 | @Override
|
---|
42 | public void setParameter(String parameters) {
|
---|
43 | // TODO should allow passing of weka parameters to the classifier
|
---|
44 | }
|
---|
45 |
|
---|
46 | public class TraindatasetBagging extends AbstractClassifier {
|
---|
47 |
|
---|
48 | /**
|
---|
49 | *
|
---|
50 | */
|
---|
51 | private static final long serialVersionUID = 1L;
|
---|
52 |
|
---|
53 | private List<Instances> trainingData = null;
|
---|
54 |
|
---|
55 | private List<Classifier> classifiers = null;
|
---|
56 |
|
---|
57 | @Override
|
---|
58 | public double classifyInstance(Instance instance) {
|
---|
59 | if( classifiers==null ) {
|
---|
60 | return 0.0;
|
---|
61 | }
|
---|
62 |
|
---|
63 | double classification = 0.0;
|
---|
64 | for( int i=0 ; i<classifiers.size(); i++ ) {
|
---|
65 | Classifier classifier = classifiers.get(i);
|
---|
66 | Instances traindata = trainingData.get(i);
|
---|
67 |
|
---|
68 | Set<String> attributeNames = new HashSet<>();
|
---|
69 | for( int j=0; j<traindata.numAttributes(); j++ ) {
|
---|
70 | attributeNames.add(traindata.attribute(j).name());
|
---|
71 | }
|
---|
72 |
|
---|
73 | double[] values = new double[traindata.numAttributes()];
|
---|
74 | int index = 0;
|
---|
75 | for( int j=0; j<instance.numAttributes(); j++ ) {
|
---|
76 | if( attributeNames.contains(instance.attribute(j).name())) {
|
---|
77 | values[index] = instance.value(j);
|
---|
78 | index++;
|
---|
79 | }
|
---|
80 | }
|
---|
81 |
|
---|
82 | Instances tmp = new Instances(traindata);
|
---|
83 | tmp.clear();
|
---|
84 | Instance instCopy = new DenseInstance(instance.weight(), values);
|
---|
85 | instCopy.setDataset(tmp);
|
---|
86 | try {
|
---|
87 | classification += classifier.classifyInstance(instCopy);
|
---|
88 | } catch (Exception e) {
|
---|
89 | throw new RuntimeException("bagging classifier could not classify an instance", e);
|
---|
90 | }
|
---|
91 | }
|
---|
92 | classification /= classifiers.size();
|
---|
93 | return (classification>=0.5) ? 1.0 : 0.0;
|
---|
94 | }
|
---|
95 |
|
---|
96 | public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
|
---|
97 | classifiers = new LinkedList<>();
|
---|
98 | trainingData = new LinkedList<>();
|
---|
99 | for( Instances traindata : traindataSet ) {
|
---|
100 | Classifier classifier = setupClassifier();
|
---|
101 | classifier.buildClassifier(traindata);
|
---|
102 | classifiers.add(classifier);
|
---|
103 | trainingData.add(new Instances(traindata));
|
---|
104 | }
|
---|
105 | }
|
---|
106 |
|
---|
107 | @Override
|
---|
108 | public void buildClassifier(Instances traindata) throws Exception {
|
---|
109 | classifiers = new LinkedList<>();
|
---|
110 | trainingData = new LinkedList<>();
|
---|
111 | final Classifier classifier = setupClassifier();
|
---|
112 | classifier.buildClassifier(traindata);
|
---|
113 | classifiers.add(classifier);
|
---|
114 | trainingData.add(new Instances(traindata));
|
---|
115 | }
|
---|
116 | }
|
---|
117 |
|
---|
118 | }
|
---|