// Copyright 2015 Georg-August-Universität Göttingen, Germany // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package de.ugoe.cs.cpdp.training; import java.io.PrintStream; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; import org.apache.commons.collections4.list.SetUniqueList; import org.apache.commons.io.output.NullOutputStream; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; /** * Programmatic WekaBaggingTraining * * first parameter is Trainer Name. second parameter is class name * * all subsequent parameters are configuration params (for example for trees) Cross Validation * params always come last and are prepended with -CVPARAM * * XML Configurations for Weka Classifiers: * *
 * {@code
 * 
 * 
 * 
 * }
 * 
* */
public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy {

    /** Ensemble holding one trained classifier per training data set. */
    private final TraindatasetBagging classifier = new TraindatasetBagging();

    /**
     * Returns the bagging ensemble managed by this trainer.
     *
     * @return the ensemble; it answers {@code 0.0} for every instance until it has been built
     */
    @Override
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * Trains one classifier per data set in {@code traindataSet}.
     * <p>
     * {@code System.err} is redirected to a null stream while training runs and is restored
     * afterwards, also when training fails.
     *
     * @param traindataSet
     *            the training data sets, one ensemble member is built for each
     * @throws RuntimeException
     *             wrapping any exception raised while building a classifier
     */
    @Override
    public void apply(SetUniqueList<Instances> traindataSet) {
        final PrintStream originalErr = System.err;
        System.setErr(new PrintStream(new NullOutputStream()));
        try {
            classifier.buildClassifier(traindataSet);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            System.setErr(originalErr);
        }
    }

    /**
     * Ensemble classifier that trains one classifier per training data set and combines their
     * predictions by averaging.
     */
    public class TraindatasetBagging extends AbstractClassifier {

        private static final long serialVersionUID = 1L;

        /** Copy of the data each ensemble member was trained on; same order as {@link #classifiers}. */
        private List<Instances> trainingData = null;

        /** Trained ensemble members; {@code null} until one of the build methods was called. */
        private List<Classifier> classifiers = null;

        /**
         * Classifies an instance with every ensemble member and averages the results.
         * <p>
         * For each member the instance is first projected onto the attributes of that member's
         * training data (matched by attribute name).
         *
         * @param instance
         *            the instance to classify; it must be assigned to a dataset
         * @return {@code 1.0} if the mean of the member classifications is at least {@code 0.5},
         *         otherwise {@code 0.0}; always {@code 0.0} if the ensemble was never built
         * @throws RuntimeException
         *             if an ensemble member fails to classify the instance
         */
        @Override
        public double classifyInstance(Instance instance) {
            if (classifiers == null) {
                return 0.0;
            }

            double classificationSum = 0.0;
            for (int i = 0; i < classifiers.size(); i++) {
                final Instance projected = projectOnto(trainingData.get(i), instance);
                try {
                    classificationSum += classifiers.get(i).classifyInstance(projected);
                }
                catch (Exception e) {
                    throw new RuntimeException("bagging classifier could not classify an instance",
                                               e);
                }
            }
            // An empty ensemble yields NaN here, which fails the comparison and maps to 0.0.
            final double meanClassification = classificationSum / classifiers.size();
            return (meanClassification >= 0.5) ? 1.0 : 0.0;
        }

        /**
         * Trains one ensemble member per data set, replacing any previously built ensemble.
         *
         * @param traindataSet
         *            the training data sets
         * @throws Exception
         *             if building one of the classifiers fails
         */
        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
            classifiers = new LinkedList<>();
            trainingData = new LinkedList<>();
            for (Instances traindata : traindataSet) {
                addMember(traindata);
            }
        }

        /**
         * Builds an ensemble with a single member trained on {@code traindata}, replacing any
         * previously built ensemble.
         *
         * @param traindata
         *            the training data
         * @throws Exception
         *             if building the classifier fails
         */
        @Override
        public void buildClassifier(Instances traindata) throws Exception {
            classifiers = new LinkedList<>();
            trainingData = new LinkedList<>();
            addMember(traindata);
        }

        /** Trains a fresh classifier on {@code traindata} and appends it to the ensemble. */
        private void addMember(Instances traindata) throws Exception {
            final Classifier member = setupClassifier();
            member.buildClassifier(traindata);
            classifiers.add(member);
            trainingData.add(new Instances(traindata));
        }

        /**
         * Creates a copy of {@code instance} that has exactly the attributes of {@code traindata}.
         * <p>
         * Values are looked up by attribute name and written to the position the attribute has
         * in {@code traindata}, so the result is correct even when the instance orders its
         * attributes differently. Attributes that the instance does not have keep the value
         * {@code 0.0}.
         */
        private Instance projectOnto(Instances traindata, Instance instance) {
            final Instances source = instance.dataset();
            if (source == null) {
                throw new weka.core.UnassignedDatasetException(
                    "instance to classify is not assigned to a dataset");
            }

            final double[] values = new double[traindata.numAttributes()];
            for (int j = 0; j < traindata.numAttributes(); j++) {
                final weka.core.Attribute match =
                    source.attribute(traindata.attribute(j).name());
                if (match != null) {
                    values[j] = instance.value(match);
                }
            }

            final Instance projected = new DenseInstance(instance.weight(), values);
            // Only the header is needed as dataset; capacity 0 avoids copying the training rows.
            projected.setDataset(new Instances(traindata, 0));
            return projected;
        }
    }
}