source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/VCBSVM.java @ 72

Last change on this file since 72 was 72, checked in by sherbold, 8 years ago

VCBSVM after Ryu et al., 2014

  • Property svn:mime-type set to text/plain
File size: 12.5 KB
// Copyright 2015 Georg-August-Universität Göttingen, Germany
//
//   Licensed under the Apache License, Version 2.0 (the "License");
//   you may not use this file except in compliance with the License.
//   You may obtain a copy of the License at
//
//       http://www.apache.org/licenses/LICENSE-2.0
//
//   Unless required by applicable law or agreed to in writing, software
//   distributed under the License is distributed on an "AS IS" BASIS,
//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//   See the License for the specific language governing permissions and
//   limitations under the License.

package de.ugoe.cs.cpdp.wekaclassifier;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;

import de.lmu.ifi.dbs.elki.logging.Logging.Level;
import de.ugoe.cs.cpdp.util.SortUtils;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;

/**
 * <p>
 * VCBSVM after Ryu et al. (2014)
 * </p>
 *
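 * <p>
 * Minimal usage sketch (assuming {@code traindata} and {@code testdata} are already loaded
 * {@link weka.core.Instances}; the test data must be set before training, because it is required
 * to compute the similarity weights):
 * </p>
 *
 * <pre>
 * VCBSVM vcbsvm = new VCBSVM();
 * vcbsvm.setOptions(new String[] { "-B", "5", "-L", "0.5" }); // boosting iterations, penalty lambda
 * vcbsvm.setTestdata(testdata);
 * vcbsvm.buildClassifier(traindata);
 * double prediction = vcbsvm.classifyInstance(testdata.instance(0));
 * </pre>
 *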
 * @author Steffen Herbold
 */
public class VCBSVM extends AbstractClassifier implements ITestAwareClassifier {

    /**
     * Default id
     */
    private static final long serialVersionUID = 1L;

    /**
     * Test data. CLASSIFICATION MUST BE IGNORED!
     */
    private Instances testdata = null;

    /**
     * Number of boosting iterations
     */
    private int boostingIterations = 5;

    /**
     * Penalty parameter lambda
     */
    private double lamda = 0.5;

    /**
     * Classifiers trained in each boosting iteration
     */
    private List<Classifier> boostingClassifiers;

    /**
     * Weights for each boosting iteration
     */
    private List<Double> classifierWeights;

    /*
     * (non-Javadoc)
     *
     * @see weka.classifiers.AbstractClassifier#getCapabilities()
     */
    @Override
    public Capabilities getCapabilities() {
        return new SMO().getCapabilities();
    }

    /*
     * (non-Javadoc)
     *
     * @see weka.classifiers.AbstractClassifier#setOptions(java.lang.String[])
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String lamdaString = Utils.getOption('L', options);
        String boostingIterString = Utils.getOption('B', options);
        if (!boostingIterString.equals("")) {
            boostingIterations = Integer.parseInt(boostingIterString);
        }
        if (!lamdaString.equals("")) {
            lamda = Double.parseDouble(lamdaString);
        }
    }

    /*
     * (non-Javadoc)
     *
     * @see de.ugoe.cs.cpdp.wekaclassifier.ITestAwareClassifier#setTestdata(weka.core.Instances)
     */
    @Override
    public void setTestdata(Instances testdata) {
        this.testdata = testdata;
    }

    /*
     * (non-Javadoc)
     *
     * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        double classification = 0.0;
        Iterator<Classifier> classifierIter = boostingClassifiers.iterator();
        Iterator<Double> weightIter = classifierWeights.iterator();
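        // weighted vote of all boosting classifiers: each classifier votes with its weight alpha
        // for or against class 1; a non-negative sum yields 1.0, otherwise 0.0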
        while (classifierIter.hasNext()) {
            Classifier classifier = classifierIter.next();
            Double weight = weightIter.next();
            if (classifier.classifyInstance(instance) > 0.5d) {
                classification += weight;
            }
            else {
                classification -= weight;
            }
        }
        return classification >= 0 ? 1.0d : 0.0d;
    }

    /*
     * (non-Javadoc)
     *
     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        // get validation set
        Resample resample = new Resample();
        resample.setSampleSizePercent(50);
        Instances validationCandidates;
        try {
            resample.setInputFormat(data);
            validationCandidates = Filter.useFilter(data, resample);
        }
        catch (Exception e) {
            Console.traceln(Level.SEVERE, "failure during validation set selection of VCBSVM");
            throw new RuntimeException(e);
        }
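        // rank the resampled candidates by their similarity to the test data; the first 20 % of
        // the sorted candidates become the validation data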
        Double[] validationCandidateWeights = calculateSimilarityWeights(validationCandidates);
        int[] indexSet = new int[validationCandidateWeights.length];
        IntStream.range(0, indexSet.length).forEach(val -> indexSet[val] = val);
        SortUtils.quicksort(validationCandidateWeights, indexSet, true);
        Instances validationdata = new Instances(validationCandidates);
        validationdata.clear();
        int numValidationInstances = (int) Math.ceil(indexSet.length * 0.2);
        for (int i = 0; i < numValidationInstances; i++) {
            validationdata.add(validationCandidates.get(indexSet[i]));
        }

        // setup training data (data-validationdata)
        Instances traindata = new Instances(data);
        traindata.removeAll(validationdata);
        Double[] similarityWeights = calculateSimilarityWeights(traindata);

        double[] boostingWeights = new double[traindata.size()];
        for (int i = 0; i < boostingWeights.length; i++) {
            boostingWeights[i] = 1.0d;
        }
        double bestAuc = 0.0;
        boostingClassifiers = new LinkedList<>();
        classifierWeights = new LinkedList<>();
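        // in each boosting iteration an SMO (SVM) is trained on the weighted training data; from
        // the second iteration on, the training data is resampled based on the similarity weights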
        for (int boostingIter = 0; boostingIter < boostingIterations; boostingIter++) {
            for (int i = 0; i < boostingWeights.length; i++) {
                traindata.get(i).setWeight(boostingWeights[i]);
            }

            Instances traindataCurrentLoop;
            if (boostingIter > 0) {
                traindataCurrentLoop = sampleData(traindata, similarityWeights);
            }
            else {
                traindataCurrentLoop = traindata;
            }

            SMO internalClassifier = new SMO();
            internalClassifier.buildClassifier(traindataCurrentLoop);

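            // epsilon is the weighted error of the current classifier on its training sample;
            // alpha is the classifier's voting weight, scaled by the penalty parameter lamda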
            double sumWeightedMisclassifications = 0.0d;
            double sumWeights = 0.0d;
            for (int i = 0; i < traindataCurrentLoop.size(); i++) {
                Instance inst = traindataCurrentLoop.get(i);
                if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
                    sumWeightedMisclassifications += inst.weight();
                }
                sumWeights += inst.weight();
            }
            double epsilon = sumWeightedMisclassifications / sumWeights;
            double alpha = lamda * Math.log((1.0d - epsilon) / epsilon);
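            // update the boosting weights: misclassified instances are weighted up for the next
            // iteration, correctly classified instances are weighted down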
            for (int i = 0; i < traindata.size(); i++) {
                Instance inst = traindata.get(i);
                if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
                    boostingWeights[i] *= Math.exp(alpha);
                }
                else {
                    boostingWeights[i] *= Math.exp(-alpha);
                }
            }
            classifierWeights.add(alpha);
            boostingClassifiers.add(internalClassifier);

            final Evaluation eval = new Evaluation(validationdata);
            eval.evaluateModel(this, validationdata);
            double currentAuc = eval.areaUnderROC(1);
            final Evaluation eval2 = new Evaluation(validationdata);
            eval2.evaluateModel(internalClassifier, validationdata);

            if (currentAuc >= bestAuc) {
                bestAuc = currentAuc;
            }
            else {
                // performance drop, abort boosting, classifier of current iteration is dropped
                Console.traceln(Level.INFO, "no gain for boosting iteration " + (boostingIter + 1) +
                    "; aborting boosting");
                classifierWeights.remove(classifierWeights.size() - 1);
                boostingClassifiers.remove(boostingClassifiers.size() - 1);
                return;
            }
        }
    }

    /**
     * <p>
     * Calculates the similarity weights of the instances with respect to the test data
     * </p>
     *
     * @param data
     *            data for which the similarity weights are calculated
     * @return vector with similarity weights
     */
    private Double[] calculateSimilarityWeights(Instances data) {
        double[] minAttValues = new double[data.numAttributes()];
        double[] maxAttValues = new double[data.numAttributes()];
        Double[] weights = new Double[data.numInstances()];

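        // an instance's similarity weight is the fraction of its attribute values that lie within
        // the [min, max] range of the respective attribute in the test data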
        for (int j = 0; j < data.numAttributes(); j++) {
            if (j != data.classIndex()) {
                minAttValues[j] = testdata.attributeStats(j).numericStats.min;
                maxAttValues[j] = testdata.attributeStats(j).numericStats.max;
            }
        }

        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            int similar = 0;
            for (int j = 0; j < data.numAttributes(); j++) {
                if (j != data.classIndex()) {
                    if (inst.value(j) >= minAttValues[j] && inst.value(j) <= maxAttValues[j]) {
                        similar++;
                    }
                }
            }
            weights[i] = similar / (data.numAttributes() - 1.0d);
        }
        return weights;
    }

    /**
     * <p>
     * Samples the data according to the similarity weights. The data is split into similar/not
     * similar and positive/negative partitions, which are then resampled to equal sizes.
     * </p>
     *
     * @param data
     *            data that is sampled
     * @param similarityWeights
     *            similarity weights of the instances
     * @return sampled data
     */
    private Instances sampleData(Instances data, Double[] similarityWeights) {
        // split data into four sets
        Instances similarPositive = new Instances(data);
        similarPositive.clear();
        Instances similarNegative = new Instances(data);
        similarNegative.clear();
        Instances notsimiPositive = new Instances(data);
        notsimiPositive.clear();
        Instances notsimiNegative = new Instances(data);
        notsimiNegative.clear();
        for (int i = 0; i < data.numInstances(); i++) {
            if (data.get(i).classValue() == 1.0) {
                if (similarityWeights[i] == 1.0) {
                    similarPositive.add(data.get(i));
                }
                else {
                    notsimiPositive.add(data.get(i));
                }
            }
            else {
                if (similarityWeights[i] == 1.0) {
                    similarNegative.add(data.get(i));
                }
                else {
                    notsimiNegative.add(data.get(i));
                }
            }
        }

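        // each of the four partitions is resampled to the same size, half the number of positive
        // instances, so that the result is balanced between positive and negative instances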
        int sampleSizes = (similarPositive.size() + notsimiPositive.size()) / 2;

        similarPositive = weightedResample(similarPositive, sampleSizes);
        notsimiPositive = weightedResample(notsimiPositive, sampleSizes);
        similarNegative = weightedResample(similarNegative, sampleSizes);
        notsimiNegative = weightedResample(notsimiNegative, sampleSizes);
        similarPositive.addAll(similarNegative);
        similarPositive.addAll(notsimiPositive);
        similarPositive.addAll(notsimiNegative);
        return similarPositive;
    }

    /**
     * <p>
     * This is just my interpretation of the resampling. Details are missing from the paper.
     * </p>
     *
     * @param data
     *            data that is sampled
     * @param size
     *            desired size of the sample
     * @return sampled data
     */
    private Instances weightedResample(final Instances data, final int size) {
        final Instances resampledData = new Instances(data);
        resampledData.clear();
        double sumOfWeights = data.sumOfWeights();
        Random rand = new Random();
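        // sampling with replacement: each draw selects an instance with probability proportional
        // to its current weight (roulette-wheel selection)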
        while (resampledData.size() < size) {
            double randVal = rand.nextDouble() * sumOfWeights;
            double currentWeightSum = 0.0;
            for (int i = 0; i < data.size(); i++) {
                currentWeightSum += data.get(i).weight();
                if (currentWeightSum >= randVal) {
                    resampledData.add(data.get(i));
                    break;
                }
            }
        }

        return resampledData;
    }
}