source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java @ 121

Last change on this file since 121 was 99, checked in by sherbold, 9 years ago
  • improved error reporting
  • Property svn:mime-type set to text/plain
File size: 4.8 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.collections4.list.SetUniqueList;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
29
30/**
31 * Programmatic WekaBaggingTraining
32 *
33 * first parameter is Trainer Name. second parameter is class name
34 *
35 * all subsequent parameters are configuration params (for example for trees) Cross Validation
36 * params always come last and are prepended with -CVPARAM
37 *
38 * XML Configurations for Weka Classifiers:
39 *
40 * <pre>
41 * {@code
42 * <!-- examples -->
43 * <setwisetrainer name="WekaBaggingTraining" param="NaiveBayesBagging weka.classifiers.bayes.NaiveBayes" />
44 * <setwisetrainer name="WekaBaggingTraining" param="LogisticBagging weka.classifiers.functions.Logistic -R 1.0E-8 -M -1" />
45 * }
46 * </pre>
47 *
48 */
49public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy {
50
51    private final TraindatasetBagging classifier = new TraindatasetBagging();
52
53    @Override
54    public Classifier getClassifier() {
55        return classifier;
56    }
57
58    @Override
59    public void apply(SetUniqueList<Instances> traindataSet) {
60        try {
61            classifier.buildClassifier(traindataSet);
62        }
63        catch (Exception e) {
64            throw new RuntimeException(e);
65        }
66    }
67
68    public class TraindatasetBagging extends AbstractClassifier {
69
70        private static final long serialVersionUID = 1L;
71
72        private List<Instances> trainingData = null;
73
74        private List<Classifier> classifiers = null;
75
76        @Override
77        public double classifyInstance(Instance instance) {
78            if (classifiers == null) {
79                return 0.0;
80            }
81
82            double classification = 0.0;
83            for (int i = 0; i < classifiers.size(); i++) {
84                Classifier classifier = classifiers.get(i);
85                Instances traindata = trainingData.get(i);
86
87                Set<String> attributeNames = new HashSet<>();
88                for (int j = 0; j < traindata.numAttributes(); j++) {
89                    attributeNames.add(traindata.attribute(j).name());
90                }
91
92                double[] values = new double[traindata.numAttributes()];
93                int index = 0;
94                for (int j = 0; j < instance.numAttributes(); j++) {
95                    if (attributeNames.contains(instance.attribute(j).name())) {
96                        values[index] = instance.value(j);
97                        index++;
98                    }
99                }
100
101                Instances tmp = new Instances(traindata);
102                tmp.clear();
103                Instance instCopy = new DenseInstance(instance.weight(), values);
104                instCopy.setDataset(tmp);
105                try {
106                    classification += classifier.classifyInstance(instCopy);
107                }
108                catch (Exception e) {
109                    throw new RuntimeException("bagging classifier could not classify an instance",
110                                               e);
111                }
112            }
113            classification /= classifiers.size();
114            return (classification >= 0.5) ? 1.0 : 0.0;
115        }
116
117        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
118            classifiers = new LinkedList<>();
119            trainingData = new LinkedList<>();
120            for (Instances traindata : traindataSet) {
121                Classifier classifier = setupClassifier();
122                classifier.buildClassifier(traindata);
123                classifiers.add(classifier);
124                trainingData.add(new Instances(traindata));
125            }
126        }
127
128        @Override
129        public void buildClassifier(Instances traindata) throws Exception {
130            classifiers = new LinkedList<>();
131            trainingData = new LinkedList<>();
132            final Classifier classifier = setupClassifier();
133            classifier.buildClassifier(traindata);
134            classifiers.add(classifier);
135            trainingData.add(new Instances(traindata));
136        }
137    }
138}
Note: See TracBrowser for help on using the repository browser.