source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java

Last change on this file was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 6.1 KB
RevLine 
[86]1// Copyright 2015 Georg-August-Universität Göttingen, Germany
[41]2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
[2]15package de.ugoe.cs.cpdp.training;
16
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import org.apache.commons.collections4.list.SetUniqueList;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnassignedDatasetException;
29
30/**
[135]31 * <p>
32 * The first parameter is the trainer name, second parameter is class name. All subsequent
33 * parameters are configuration parameters of the algorithms. Cross validation parameters always
34 * come last and are prepended with -CVPARAM
35 * </p>
36 * <p>
[2]37 * XML Configurations for Weka Classifiers:
[41]38 *
[2]39 * <pre>
40 * {@code
41 * <!-- examples -->
[25]42 * <setwisetrainer name="WekaBaggingTraining" param="NaiveBayesBagging weka.classifiers.bayes.NaiveBayes" />
43 * <setwisetrainer name="WekaBaggingTraining" param="LogisticBagging weka.classifiers.functions.Logistic -R 1.0E-8 -M -1" />
[2]44 * }
45 * </pre>
[135]46 * </p>
[2]47 *
[135]48 * @author Alexander Trautsch
[2]49 */
[23]50public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy {
[2]51
[135]52    /**
53     * the classifier
54     */
[41]55    private final TraindatasetBagging classifier = new TraindatasetBagging();
[2]56
[135]57    /*
58     * (non-Javadoc)
59     *
60     * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier()
61     */
[41]62    @Override
63    public Classifier getClassifier() {
64        return classifier;
65    }
66
[135]67    /*
68     * (non-Javadoc)
69     *
70     * @see
71     * de.ugoe.cs.cpdp.training.ISetWiseTrainingStrategy#apply(org.apache.commons.collections4.list.
72     * SetUniqueList)
73     */
[41]74    @Override
75    public void apply(SetUniqueList<Instances> traindataSet) {
76        try {
77            classifier.buildClassifier(traindataSet);
78        }
79        catch (Exception e) {
80            throw new RuntimeException(e);
81        }
82    }
83
[135]84    /**
85     * <p>
86     * Helper class for bagging classifiers.
87     * </p>
88     *
89     * @author Steffen Herbold
90     */
[41]91    public class TraindatasetBagging extends AbstractClassifier {
92
[135]93        /**
94         * default serialization ID.
95         */
[41]96        private static final long serialVersionUID = 1L;
97
[135]98        /**
99         * internal storage of the training data
100         */
[41]101        private List<Instances> trainingData = null;
102
[135]103        /**
104         * bagging classifier for each training data set
105         */
[41]106        private List<Classifier> classifiers = null;
107
[135]108        /*
109         * (non-Javadoc)
110         *
111         * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
112         */
[41]113        @Override
114        public double classifyInstance(Instance instance) {
115            if (classifiers == null) {
116                return 0.0;
117            }
118
119            double classification = 0.0;
120            for (int i = 0; i < classifiers.size(); i++) {
121                Classifier classifier = classifiers.get(i);
122                Instances traindata = trainingData.get(i);
123
124                Set<String> attributeNames = new HashSet<>();
125                for (int j = 0; j < traindata.numAttributes(); j++) {
126                    attributeNames.add(traindata.attribute(j).name());
127                }
128
129                double[] values = new double[traindata.numAttributes()];
130                int index = 0;
131                for (int j = 0; j < instance.numAttributes(); j++) {
132                    if (attributeNames.contains(instance.attribute(j).name())) {
133                        values[index] = instance.value(j);
134                        index++;
135                    }
136                }
137
138                Instances tmp = new Instances(traindata);
139                tmp.clear();
140                Instance instCopy = new DenseInstance(instance.weight(), values);
141                instCopy.setDataset(tmp);
142                try {
143                    classification += classifier.classifyInstance(instCopy);
144                }
145                catch (Exception e) {
146                    throw new RuntimeException("bagging classifier could not classify an instance",
147                                               e);
148                }
149            }
150            classification /= classifiers.size();
151            return (classification >= 0.5) ? 1.0 : 0.0;
152        }
153
[135]154        /**
155         * <p>
156         * trains a new dataset wise bagging classifier
157         * </p>
158         *
159         * @param traindataSet
160         *            the training data per prodcut
161         * @throws Exception
162         *             thrown if an error occurs during the training of the classifiers for any
163         *             product
164         */
[41]165        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
166            classifiers = new LinkedList<>();
167            trainingData = new LinkedList<>();
168            for (Instances traindata : traindataSet) {
169                Classifier classifier = setupClassifier();
170                classifier.buildClassifier(traindata);
171                classifiers.add(classifier);
172                trainingData.add(new Instances(traindata));
173            }
174        }
175
[135]176        /*
177         * (non-Javadoc)
178         *
179         * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
180         */
[41]181        @Override
182        public void buildClassifier(Instances traindata) throws Exception {
183            classifiers = new LinkedList<>();
184            trainingData = new LinkedList<>();
185            final Classifier classifier = setupClassifier();
186            classifier.buildClassifier(traindata);
187            classifiers.add(classifier);
188            trainingData.add(new Instances(traindata));
189        }
190    }
[2]191}
Note: See TracBrowser for help on using the repository browser.