// Copyright 2015 Georg-August-Universität Göttingen, Germany
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package de.ugoe.cs.cpdp.wekaclassifier;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.stream.IntStream;

import de.ugoe.cs.cpdp.util.SortUtils;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;

/**
 * <p>
 * VCBSVM after Ryu et al. (2014)
 * </p>
 *
 * @author Steffen Herbold
 */
public class VCBSVM extends AbstractClassifier implements ITestAwareClassifier {

    /**
     * Default id
     */
    private static final long serialVersionUID = 1L;

    /**
     * Test data. CLASSIFICATION MUST BE IGNORED!
     */
    private Instances testdata = null;

    /**
     * Number of boosting iterations
     */
    private int boostingIterations = 5;

    /**
     * Penalty parameter lambda
     */
    private double lamda = 0.5;

    /**
     * Classifiers trained in the boosting iterations
     */
    private List<Classifier> boostingClassifiers;

    /**
     * Weights for each boosting iteration
     */
    private List<Double> classifierWeights;

    /*
     * (non-Javadoc)
     * 
     * @see weka.classifiers.AbstractClassifier#getCapabilities()
     */
    @Override
    public Capabilities getCapabilities() {
        return new SMO().getCapabilities();
    }

    /*
     * (non-Javadoc)
     * 
     * @see weka.classifiers.AbstractClassifier#setOptions(java.lang.String[])
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String lamdaString = Utils.getOption('L', options);
        String boostingIterString = Utils.getOption('B', options);
        if (!boostingIterString.isEmpty()) {
            boostingIterations = Integer.parseInt(boostingIterString);
        }
        if (!lamdaString.isEmpty()) {
            lamda = Double.parseDouble(lamdaString);
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see de.ugoe.cs.cpdp.wekaclassifier.ITestAwareClassifier#setTestdata(weka.core.Instances)
     */
    @Override
    public void setTestdata(Instances testdata) {
        this.testdata = testdata;
    }

    /*
     * (non-Javadoc)
     * 
     * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        double classification = 0.0;
        Iterator<Classifier> classifierIter = boostingClassifiers.iterator();
        Iterator<Double> weightIter = classifierWeights.iterator();
        while (classifierIter.hasNext()) {
            Classifier classifier = classifierIter.next();
            Double weight = weightIter.next();
            if (classifier.classifyInstance(instance) > 0.5d) {
                classification += weight;
            }
            else {
                classification -= weight;
            }
        }
        return classification >= 0 ? 1.0d : 0.0d;
    }
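
    // A worked example of the weighted voting above (hypothetical alpha values,
    // not from Ryu et al.): with classifier weights {0.6, 0.2, 0.3} and votes
    // {1, 0, 0} for an instance, the score is 0.6 - 0.2 - 0.3 = 0.1 >= 0, so
    // the ensemble predicts class 1.0.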
    /*
     * (non-Javadoc)
     * 
     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        // get validation set: draw a 50% sample of the data and keep the 20% of
        // its instances that are most similar to the test data
        Resample resample = new Resample();
        resample.setSampleSizePercent(50);
        Instances validationCandidates;
        try {
            resample.setInputFormat(data);
            validationCandidates = Filter.useFilter(data, resample);
        }
        catch (Exception e) {
            Console.traceln(Level.SEVERE, "failure during validation set selection of VCBSVM");
            throw new RuntimeException(e);
        }
        Double[] validationCandidateWeights = calculateSimilarityWeights(validationCandidates);
        int[] indexSet = new int[validationCandidateWeights.length];
        IntStream.range(0, indexSet.length).forEach(val -> indexSet[val] = val);
        SortUtils.quicksort(validationCandidateWeights, indexSet, true);
        Instances validationdata = new Instances(validationCandidates);
        validationdata.clear();
        int numValidationInstances = (int) Math.ceil(indexSet.length * 0.2);
        for (int i = 0; i < numValidationInstances; i++) {
            validationdata.add(validationCandidates.get(indexSet[i]));
        }

        // setup training data (data - validationdata)
        Instances traindata = new Instances(data);
        traindata.removeAll(validationdata);

        Double[] similarityWeights = calculateSimilarityWeights(traindata);

        // all instances start with the same boosting weight
        double[] boostingWeights = new double[traindata.size()];
        for (int i = 0; i < boostingWeights.length; i++) {
            boostingWeights[i] = 1.0d;
        }
        double bestAuc = 0.0;
        boostingClassifiers = new LinkedList<>();
        classifierWeights = new LinkedList<>();
        for (int boostingIter = 0; boostingIter < boostingIterations; boostingIter++) {
            for (int i = 0; i < boostingWeights.length; i++) {
                traindata.get(i).setWeight(boostingWeights[i]);
            }
            Instances traindataCurrentLoop;
            if (boostingIter > 0) {
                traindataCurrentLoop = sampleData(traindata, similarityWeights);
            }
            else {
                traindataCurrentLoop = traindata;
            }
            SMO internalClassifier = new SMO();
            internalClassifier.buildClassifier(traindataCurrentLoop);

            // weighted training error epsilon of the current iteration
            double sumWeightedMisclassifications = 0.0d;
            double sumWeights = 0.0d;
            for (int i = 0; i < traindataCurrentLoop.size(); i++) {
                Instance inst = traindataCurrentLoop.get(i);
                if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
                    sumWeightedMisclassifications += inst.weight();
                }
                sumWeights += inst.weight();
            }
            double epsilon = sumWeightedMisclassifications / sumWeights;
            double alpha = lamda * Math.log((1.0d - epsilon) / epsilon);

            // update boosting weights: misclassified instances gain weight,
            // correctly classified instances lose weight
            for (int i = 0; i < traindata.size(); i++) {
                Instance inst = traindata.get(i);
                if (inst.classValue() != internalClassifier.classifyInstance(inst)) {
                    boostingWeights[i] *= Math.exp(alpha);
                }
                else {
                    boostingWeights[i] *= Math.exp(-alpha);
                }
            }
            classifierWeights.add(alpha);
            boostingClassifiers.add(internalClassifier);

            // evaluate the current ensemble on the validation data
            final Evaluation eval = new Evaluation(validationdata);
            eval.evaluateModel(this, validationdata);
            double currentAuc = eval.areaUnderROC(1);
            if (currentAuc >= bestAuc) {
                bestAuc = currentAuc;
            }
            else {
                // performance drop, abort boosting, classifier of current iteration is dropped
                Console.traceln(Level.INFO, "no gain for boosting iteration " + (boostingIter + 1) +
                    "; aborting boosting");
                classifierWeights.remove(classifierWeights.size() - 1);
                boostingClassifiers.remove(boostingClassifiers.size() - 1);
                return;
            }
        }
    }

    /**
     * <p>
     * Calculates the similarity weights for the training data, i.e., for each instance the
     * fraction of attributes whose values lie within the value ranges of the test data.
     * </p>
     *
     * @param data
     *            training data
     * @return vector with similarity weights
     */
    private Double[] calculateSimilarityWeights(Instances data) {
        double[] minAttValues = new double[data.numAttributes()];
        double[] maxAttValues = new double[data.numAttributes()];
        Double[] weights = new Double[data.numInstances()];

        // attribute value ranges are determined from the test data
        for (int j = 0; j < data.numAttributes(); j++) {
            if (j != data.classIndex()) {
                minAttValues[j] = testdata.attributeStats(j).numericStats.min;
                maxAttValues[j] = testdata.attributeStats(j).numericStats.max;
            }
        }

        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            int similar = 0;
            for (int j = 0; j < data.numAttributes(); j++) {
                if (j != data.classIndex()) {
                    if (inst.value(j) >= minAttValues[j] && inst.value(j) <= maxAttValues[j]) {
                        similar++;
                    }
                }
            }
            weights[i] = similar / (data.numAttributes() - 1.0d);
        }
        return weights;
    }
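
    // Worked example for the similarity weights (hypothetical values): with
    // three non-class attributes whose test-data ranges are all [0, 1], a
    // training instance with the values (0.4, 0.9, 1.7) lies inside two of the
    // three ranges and gets the weight 2/3; a weight of 1.0 marks an instance
    // that lies completely within the value ranges of the test data.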
    /**
     * <p>
     * Samples data according to the similarity weights. This sampling splits the data into four
     * partitions (similar/not similar to the test data, crossed with positive/negative class) and
     * resamples each partition to the same size, such that the result is balanced.
     * </p>
     *
     * @param data
     *            data that is sampled
     * @param similarityWeights
     *            similarity weights of the data
     * @return sampled data
     */
    private Instances sampleData(Instances data, Double[] similarityWeights) {
        // split data into four sets
        Instances similarPositive = new Instances(data);
        similarPositive.clear();
        Instances similarNegative = new Instances(data);
        similarNegative.clear();
        Instances notsimiPositive = new Instances(data);
        notsimiPositive.clear();
        Instances notsimiNegative = new Instances(data);
        notsimiNegative.clear();
        for (int i = 0; i < data.numInstances(); i++) {
            if (data.get(i).classValue() == 1.0) {
                if (similarityWeights[i] == 1.0) {
                    similarPositive.add(data.get(i));
                }
                else {
                    notsimiPositive.add(data.get(i));
                }
            }
            else {
                if (similarityWeights[i] == 1.0) {
                    similarNegative.add(data.get(i));
                }
                else {
                    notsimiNegative.add(data.get(i));
                }
            }
        }
        // each partition is resampled to half the number of positive instances
        int sampleSizes = (similarPositive.size() + notsimiPositive.size()) / 2;
        similarPositive = weightedResample(similarPositive, sampleSizes);
        notsimiPositive = weightedResample(notsimiPositive, sampleSizes);
        similarNegative = weightedResample(similarNegative, sampleSizes);
        notsimiNegative = weightedResample(notsimiNegative, sampleSizes);
        similarPositive.addAll(similarNegative);
        similarPositive.addAll(notsimiPositive);
        similarPositive.addAll(notsimiNegative);
        return similarPositive;
    }
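
    // Note on the partition size in sampleData (hypothetical numbers): with 40
    // positive instances overall, each of the four partitions is resampled to
    // 20 instances, so the merged result has 80 instances with a 1:1 ratio of
    // positive and negative classes.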
    /**
     * <p>
     * This is just my interpretation of the resampling. Details are missing from the paper.
     * </p>
     *
     * @param data
     *            data that is sampled
     * @param size
     *            desired size of the sample
     * @return sampled data
     */
    private Instances weightedResample(final Instances data, final int size) {
        if (data.isEmpty()) {
            return data;
        }
        final Instances resampledData = new Instances(data);
        resampledData.clear();
        double sumOfWeights = data.sumOfWeights();
        Random rand = new Random();
        while (resampledData.size() < size) {
            // roulette-wheel selection: draw an instance with probability
            // proportional to its weight (sampling with replacement)
            double randVal = rand.nextDouble() * sumOfWeights;
            double currentWeightSum = 0.0;
            for (int i = 0; i < data.size(); i++) {
                currentWeightSum += data.get(i).weight();
                if (currentWeightSum >= randVal) {
                    resampledData.add(data.get(i));
                    break;
                }
            }
        }
        return resampledData;
    }
}
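
/*
 * Example usage (a minimal sketch; the data loading and the option values are
 * assumptions, not part of CrossPare):
 *
 *   Instances traindata = ...;  // labeled cross-project training data
 *   Instances testdata = ...;   // target project data; class labels are ignored
 *   VCBSVM vcbsvm = new VCBSVM();
 *   vcbsvm.setOptions(new String[] { "-L", "0.5", "-B", "5" });
 *   vcbsvm.setTestdata(testdata);
 *   vcbsvm.buildClassifier(traindata);
 *   double label = vcbsvm.classifyInstance(testdata.firstInstance());
 */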