// Copyright 2015 Georg-August-Universität Göttingen, Germany // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package de.ugoe.cs.cpdp.training; import java.util.LinkedList; import java.util.List; import de.ugoe.cs.cpdp.util.WekaUtils; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; /** *

* Implements training following the LASER classification scheme. *

* * @author Steffen Herbold */ public class WekaLASERTraining extends WekaBaseTraining implements ITrainingStrategy { /** * Internal classifier used for LASER. */ private final LASERClassifier internalClassifier = new LASERClassifier(); /* * (non-Javadoc) * * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier() */ @Override public Classifier getClassifier() { return internalClassifier; } /* * (non-Javadoc) * * @see de.ugoe.cs.cpdp.training.ITrainingStrategy#apply(weka.core.Instances) */ @Override public void apply(Instances traindata) { try { internalClassifier.buildClassifier(traindata); } catch (Exception e) { throw new RuntimeException(e); } } /** *

* Internal helper class that defines the laser classifier. *

* * @author Steffen Herbold */ public class LASERClassifier extends AbstractClassifier { /** * Default serial ID. */ private static final long serialVersionUID = 1L; /** * Internal reference to the classifier. */ private Classifier laserClassifier = null; /** * Internal storage of the training data required for NN analysis. */ private Instances traindata = null; /* * (non-Javadoc) * * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance) */ @Override public double classifyInstance(Instance instance) throws Exception { List closestInstances = new LinkedList<>(); double minDistance = Double.MAX_VALUE; for (int i = 0; i < traindata.size(); i++) { double distance = WekaUtils.hammingDistance(instance, traindata.get(i)); if (distance < minDistance) { minDistance = distance; } } for (int i = 0; i < traindata.size(); i++) { double distance = WekaUtils.hammingDistance(instance, traindata.get(i)); if (distance <= minDistance) { closestInstances.add(i); } } if (closestInstances.size() == 1) { int closestIndex = closestInstances.get(0); Instance closestTrainingInstance = traindata.get(closestIndex); List closestToTrainingInstance = new LinkedList<>(); double minTrainingDistance = Double.MAX_VALUE; for (int i = 0; i < traindata.size(); i++) { if (closestIndex != i) { double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i)); if (distance < minTrainingDistance) { minTrainingDistance = distance; } } } for (int i = 0; i < traindata.size(); i++) { if (closestIndex != i) { double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i)); if (distance <= minTrainingDistance) { closestToTrainingInstance.add(i); } } } if (closestToTrainingInstance.size() == 1) { return laserClassifier.classifyInstance(instance); } else { double label = Double.NaN; boolean allEqual = true; for (Integer index : closestToTrainingInstance) { if (Double.isNaN(label)) { label = traindata.get(index).classValue(); } else if (label != traindata.get(index).classValue()) { allEqual = false; break; } } if (allEqual) { return label; } else { return laserClassifier.classifyInstance(instance); } } } else { double label = Double.NaN; boolean allEqual = true; for (Integer index : closestInstances) { if (Double.isNaN(label)) { label = traindata.get(index).classValue(); } else if (label != traindata.get(index).classValue()) { allEqual = false; break; } } if (allEqual) { return label; } else { return laserClassifier.classifyInstance(instance); } } } /* * (non-Javadoc) * * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances) */ @Override public void buildClassifier(Instances traindata) throws Exception { this.traindata = new Instances(traindata); laserClassifier = setupClassifier(); laserClassifier.buildClassifier(traindata); } } }