// Copyright 2015 Georg-August-Universität Göttingen, Germany
//
//   Licensed under the Apache License, Version 2.0 (the "License");
//   you may not use this file except in compliance with the License.
//   You may obtain a copy of the License at
//
//       http://www.apache.org/licenses/LICENSE-2.0
//
//   Unless required by applicable law or agreed to in writing, software
//   distributed under the License is distributed on an "AS IS" BASIS,
//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//   See the License for the specific language governing permissions and
//   limitations under the License.

package de.ugoe.cs.cpdp.wekaclassifier;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;

import de.lmu.ifi.dbs.elki.logging.Logging.Level;
import de.ugoe.cs.cpdp.util.SortUtils;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;
36 |
|
---|
37 | /**
|
---|
38 | * <p>
|
---|
39 | * VCBSVM after Ryu et al. (2014)
|
---|
40 | * </p>
|
---|
41 | *
|
---|
42 | * @author Steffen Herbold
|
---|
43 | */
|
---|
44 | public class VCBSVM extends AbstractClassifier implements ITestAwareClassifier {
|
---|
45 |
|
---|
46 | /**
|
---|
47 | * Default id
|
---|
48 | */
|
---|
49 | private static final long serialVersionUID = 1L;
|
---|
50 |
|
---|
51 | /**
|
---|
52 | * Test data. CLASSIFICATION MUST BE IGNORED!
|
---|
53 | */
|
---|
54 | private Instances testdata = null;
|
---|
55 |
|
---|
56 | /**
|
---|
57 | * Number of boosting iterations
|
---|
58 | */
|
---|
59 | private int boostingIterations = 5;
|
---|
60 |
|
---|
61 | /**
|
---|
62 | * Penalty parameter lamda
|
---|
63 | */
|
---|
64 | private double lamda = 0.5;
|
---|
65 |
|
---|
66 | /**
|
---|
67 | * Classifier trained in each boosting iteration
|
---|
68 | */
|
---|
69 | private List<Classifier> boostingClassifiers;
|
---|
70 |
|
---|
71 | /**
|
---|
72 | * Weights for each boosting iteration
|
---|
73 | */
|
---|
74 | private List<Double> classifierWeights;
|
---|
75 |
|
---|
76 | /*
|
---|
77 | * (non-Javadoc)
|
---|
78 | *
|
---|
79 | * @see weka.classifiers.AbstractClassifier#getCapabilities()
|
---|
80 | */
|
---|
81 | @Override
|
---|
82 | public Capabilities getCapabilities() {
|
---|
83 | return new SMO().getCapabilities();
|
---|
84 | }
|
---|
85 |
|
---|
86 | /*
|
---|
87 | * (non-Javadoc)
|
---|
88 | *
|
---|
89 | * @see weka.classifiers.AbstractClassifier#setOptions(java.lang.String[])
|
---|
90 | */
|
---|
91 | @Override
|
---|
92 | public void setOptions(String[] options) throws Exception {
|
---|
93 | String lamdaString = Utils.getOption('L', options);
|
---|
94 | String boostingIterString = Utils.getOption('B', options);
|
---|
95 | if (!boostingIterString.isEmpty()) {
|
---|
96 | boostingIterations = Integer.parseInt(boostingIterString);
|
---|
97 | }
|
---|
98 | if (!lamdaString.isEmpty()) {
|
---|
99 | lamda = Double.parseDouble(lamdaString);
|
---|
100 | }
|
---|
101 | }
|
---|
102 |
|
---|
103 | /*
|
---|
104 | * (non-Javadoc)
|
---|
105 | *
|
---|
106 | * @see de.ugoe.cs.cpdp.wekaclassifier.ITestAwareClassifier#setTestdata(weka.core.Instances)
|
---|
107 | */
|
---|
108 | @Override
|
---|
109 | public void setTestdata(Instances testdata) {
|
---|
110 | this.testdata = testdata;
|
---|
111 | }
|
---|
112 |
|
---|
113 | /*
|
---|
114 | * (non-Javadoc)
|
---|
115 | *
|
---|
116 | * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
|
---|
117 | */
|
---|
118 | @Override
|
---|
119 | public double classifyInstance(Instance instance) throws Exception {
|
---|
120 | double classification = 0.0;
|
---|
121 | Iterator<Classifier> classifierIter = boostingClassifiers.iterator();
|
---|
122 | Iterator<Double> weightIter = classifierWeights.iterator();
|
---|
123 | while (classifierIter.hasNext()) {
|
---|
124 | Classifier classifier = classifierIter.next();
|
---|
125 | Double weight = weightIter.next();
|
---|
126 | if (classifier.classifyInstance(instance) > 0.5d) {
|
---|
127 | classification += weight;
|
---|
128 | }
|
---|
129 | else {
|
---|
130 | classification -= weight;
|
---|
131 | }
|
---|
132 | }
|
---|
133 | return classification >= 0 ? 1.0d : 0.0d;
|
---|
134 | }
|
---|
135 |
|
---|
136 | /*
|
---|
137 | * (non-Javadoc)
|
---|
138 | *
|
---|
139 | * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
|
---|
140 | */
|
---|
141 | @Override
|
---|
142 | public void buildClassifier(Instances data) throws Exception {
|
---|
143 | // get validation set
|
---|
144 | Resample resample = new Resample();
|
---|
145 | resample.setSampleSizePercent(50);
|
---|
146 | Instances validationCandidates;
|
---|
147 | try {
|
---|
148 | resample.setInputFormat(data);
|
---|
149 | validationCandidates = Filter.useFilter(data, resample);
|
---|
150 | }
|
---|
151 | catch (Exception e) {
|
---|
152 | Console.traceln(Level.SEVERE, "failure during validation set selection of VCBSVM");
|
---|
153 | throw new RuntimeException(e);
|
---|
154 | }
|
---|
155 | Double[] validationCandidateWeights = calculateSimilarityWeights(validationCandidates);
|
---|
156 | int[] indexSet = new int[validationCandidateWeights.length];
|
---|
157 | IntStream.range(0, indexSet.length).forEach(val -> indexSet[val] = val);
|
---|
158 | SortUtils.quicksort(validationCandidateWeights, indexSet, true);
|
---|
159 | Instances validationdata = new Instances(validationCandidates);
|
---|
160 | validationdata.clear();
|
---|
161 | int numValidationInstances = (int) Math.ceil(indexSet.length * 0.2);
|
---|
162 | for (int i = 0; i < numValidationInstances; i++) {
|
---|
163 | validationdata.add(validationCandidates.get(indexSet[i]));
|
---|
164 | }
|
---|
165 |
|
---|
166 | // setup training data (data-validationdata)
|
---|
167 | Instances traindata = new Instances(data);
|
---|
168 | traindata.removeAll(validationdata);
|
---|
169 | Double[] similarityWeights = calculateSimilarityWeights(traindata);
|
---|
170 |
|
---|
171 | double[] boostingWeights = new double[traindata.size()];
|
---|
172 | for (int i = 0; i < boostingWeights.length; i++) {
|
---|
173 | boostingWeights[i] = 1.0d;
|
---|
174 | }
|
---|
175 | double bestAuc = 0.0;
|
---|
176 | boostingClassifiers = new LinkedList<>();
|
---|
177 | classifierWeights = new LinkedList<>();
|
---|
178 | for (int boostingIter = 0; boostingIter < boostingIterations; boostingIter++) {
|
---|
179 | for (int i = 0; i < boostingWeights.length; i++) {
|
---|
180 | traindata.get(i).setWeight(boostingWeights[i]);
|
---|
181 | }
|
---|
182 |
|
---|
183 | Instances traindataCurrentLoop;
|
---|
184 | if (boostingIter > 0) {
|
---|
185 | traindataCurrentLoop = sampleData(traindata, similarityWeights);
|
---|
186 | }
|
---|
187 | else {
|
---|
188 | traindataCurrentLoop = traindata;
|
---|
189 | }
|
---|
190 |
|
---|
191 | SMO internalClassifier = new SMO();
|
---|
192 | internalClassifier.buildClassifier(traindataCurrentLoop);
|
---|
193 |
|
---|
194 | double sumWeightedMisclassifications = 0.0d;
|
---|
195 | double sumWeights = 0.0d;
|
---|
196 | for (int i = 0; i < traindataCurrentLoop.size(); i++) {
|
---|
197 | Instance inst = traindataCurrentLoop.get(i);
|
---|
198 | if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
|
---|
199 | sumWeightedMisclassifications += inst.weight();
|
---|
200 | }
|
---|
201 | sumWeights += inst.weight();
|
---|
202 | }
|
---|
203 | double epsilon = sumWeightedMisclassifications / sumWeights;
|
---|
204 | double alpha = lamda * Math.log((1.0d - epsilon) / epsilon);
|
---|
205 | for (int i = 0; i < traindata.size(); i++) {
|
---|
206 | Instance inst = traindata.get(i);
|
---|
207 | if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
|
---|
208 | boostingWeights[i] *= boostingWeights[i] * Math.exp(alpha);
|
---|
209 | }
|
---|
210 | else {
|
---|
211 | boostingWeights[i] *= boostingWeights[i] * Math.exp(-alpha);
|
---|
212 | }
|
---|
213 | }
|
---|
214 | classifierWeights.add(alpha);
|
---|
215 | boostingClassifiers.add(internalClassifier);
|
---|
216 |
|
---|
217 | final Evaluation eval = new Evaluation(validationdata);
|
---|
218 | eval.evaluateModel(this, validationdata);
|
---|
219 | double currentAuc = eval.areaUnderROC(1);
|
---|
220 | final Evaluation eval2 = new Evaluation(validationdata);
|
---|
221 | eval2.evaluateModel(internalClassifier, validationdata);
|
---|
222 |
|
---|
223 | if (currentAuc >= bestAuc) {
|
---|
224 | bestAuc = currentAuc;
|
---|
225 | }
|
---|
226 | else {
|
---|
227 | // performance drop, abort boosting, classifier of current iteration is dropped
|
---|
228 | Console.traceln(Level.INFO, "no gain for boosting iteration " + (boostingIter + 1) +
|
---|
229 | "; aborting boosting");
|
---|
230 | classifierWeights.remove(classifierWeights.size() - 1);
|
---|
231 | boostingClassifiers.remove(boostingClassifiers.size() - 1);
|
---|
232 | return;
|
---|
233 | }
|
---|
234 | }
|
---|
235 | }
|
---|
236 |
|
---|
237 | /**
|
---|
238 | * <p>
|
---|
239 | * Calculates the similarity weights for the training data
|
---|
240 | * </p>
|
---|
241 | *
|
---|
242 | * @param data
|
---|
243 | * training data
|
---|
244 | * @return vector with similarity weights
|
---|
245 | */
|
---|
246 | private Double[] calculateSimilarityWeights(Instances data) {
|
---|
247 | double[] minAttValues = new double[data.numAttributes()];
|
---|
248 | double[] maxAttValues = new double[data.numAttributes()];
|
---|
249 | Double[] weights = new Double[data.numInstances()];
|
---|
250 |
|
---|
251 | for (int j = 0; j < data.numAttributes(); j++) {
|
---|
252 | if (j != data.classIndex()) {
|
---|
253 | minAttValues[j] = testdata.attributeStats(j).numericStats.min;
|
---|
254 | maxAttValues[j] = testdata.attributeStats(j).numericStats.max;
|
---|
255 | }
|
---|
256 | }
|
---|
257 |
|
---|
258 | for (int i = 0; i < data.numInstances(); i++) {
|
---|
259 | Instance inst = data.instance(i);
|
---|
260 | int similar = 0;
|
---|
261 | for (int j = 0; j < data.numAttributes(); j++) {
|
---|
262 | if (j != data.classIndex()) {
|
---|
263 | if (inst.value(j) >= minAttValues[j] && inst.value(j) <= maxAttValues[j]) {
|
---|
264 | similar++;
|
---|
265 | }
|
---|
266 | }
|
---|
267 | }
|
---|
268 | weights[i] = similar / (data.numAttributes() - 1.0d);
|
---|
269 | }
|
---|
270 | return weights;
|
---|
271 | }
|
---|
272 |
|
---|
273 | /**
|
---|
274 | *
|
---|
275 | * <p>
|
---|
276 | * Samples data according to the similarity weights. This sampling
|
---|
277 | * </p>
|
---|
278 | *
|
---|
279 | * @param data
|
---|
280 | * @param similarityWeights
|
---|
281 | * @return sampled data
|
---|
282 | */
|
---|
283 | private Instances sampleData(Instances data, Double[] similarityWeights) {
|
---|
284 | // split data into four sets;
|
---|
285 | Instances similarPositive = new Instances(data);
|
---|
286 | similarPositive.clear();
|
---|
287 | Instances similarNegative = new Instances(data);
|
---|
288 | similarNegative.clear();
|
---|
289 | Instances notsimiPositive = new Instances(data);
|
---|
290 | notsimiPositive.clear();
|
---|
291 | Instances notsimiNegative = new Instances(data);
|
---|
292 | notsimiNegative.clear();
|
---|
293 | for (int i = 0; i < data.numInstances(); i++) {
|
---|
294 | if (data.get(i).classValue() == 1.0) {
|
---|
295 | if (similarityWeights[i] == 1.0) {
|
---|
296 | similarPositive.add(data.get(i));
|
---|
297 | }
|
---|
298 | else {
|
---|
299 | notsimiPositive.add(data.get(i));
|
---|
300 | }
|
---|
301 | }
|
---|
302 | else {
|
---|
303 | if (similarityWeights[i] == 1.0) {
|
---|
304 | similarNegative.add(data.get(i));
|
---|
305 | }
|
---|
306 | else {
|
---|
307 | notsimiNegative.add(data.get(i));
|
---|
308 | }
|
---|
309 | }
|
---|
310 | }
|
---|
311 |
|
---|
312 | int sampleSizes = (similarPositive.size() + notsimiPositive.size()) / 2;
|
---|
313 |
|
---|
314 | similarPositive = weightedResample(similarPositive, sampleSizes);
|
---|
315 | notsimiPositive = weightedResample(notsimiPositive, sampleSizes);
|
---|
316 | similarNegative = weightedResample(similarNegative, sampleSizes);
|
---|
317 | notsimiNegative = weightedResample(notsimiNegative, sampleSizes);
|
---|
318 | similarPositive.addAll(similarNegative);
|
---|
319 | similarPositive.addAll(notsimiPositive);
|
---|
320 | similarPositive.addAll(notsimiNegative);
|
---|
321 | return similarPositive;
|
---|
322 | }
|
---|
323 |
|
---|
324 | /**
|
---|
325 | * <p>
|
---|
326 | * This is just my interpretation of the resampling. Details are missing from the paper.
|
---|
327 | * </p>
|
---|
328 | *
|
---|
329 | * @param data
|
---|
330 | * data that is sampled
|
---|
331 | * @param size
|
---|
332 | * desired size of the sample
|
---|
333 | * @return sampled data
|
---|
334 | */
|
---|
335 | private Instances weightedResample(final Instances data, final int size) {
|
---|
336 | final Instances resampledData = new Instances(data);
|
---|
337 | resampledData.clear();
|
---|
338 | double sumOfWeights = data.sumOfWeights();
|
---|
339 | Random rand = new Random();
|
---|
340 | while (resampledData.size() < size) {
|
---|
341 | double randVal = rand.nextDouble() * sumOfWeights;
|
---|
342 | double currentWeightSum = 0.0;
|
---|
343 | for (int i = 0; i < data.size(); i++) {
|
---|
344 | currentWeightSum += data.get(i).weight();
|
---|
345 | if (currentWeightSum >= randVal) {
|
---|
346 | resampledData.add(data.get(i));
|
---|
347 | break;
|
---|
348 | }
|
---|
349 | }
|
---|
350 | }
|
---|
351 |
|
---|
352 | return resampledData;
|
---|
353 | }
|
---|
354 | }
|
---|