package de.ugoe.cs.cpdp.training;

import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;

import org.apache.commons.io.output.NullOutputStream;

import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * WekaClusterTraining2
 *
 * Training strategy that partitions the training data into clusters and trains
 * one classifier per cluster.
 *
 * Currently supports only EM Clustering.
 *
 * 1. Cluster training data
 * 2. for each cluster train a classifier with training data from cluster
 * 3. match test data instance to a cluster, then classify with classifier from the cluster
 *
 * XML configuration:
 *
 * NOTE(review): the XML configuration example that belonged here is missing from
 * this copy (only empty comment lines remained) - presumably it was stripped as
 * markup. Restore it from version control.
 */
public class WekaClusterTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {

    /** The cluster-based classifier; it is (re)trained on every call to {@link #apply(Instances)}. */
    private final TraindatasetCluster classifier = new TraindatasetCluster();

    /**
     * Returns the cluster-based classifier held by this strategy.
     *
     * @return the classifier; it is only trained after {@link #apply(Instances)} has run
     */
    @Override
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * Trains the cluster-based classifier on the given data.
     *
     * While training runs, System.err is redirected to a discarding stream
     * (presumably to silence console output from Weka) and is restored in the
     * finally block even if training fails. Note that System.setErr is
     * JVM-wide, so output other threads write to stderr during training is also
     * discarded.
     *
     * @param traindata training data
     * @throws RuntimeException wrapping any exception thrown while building the classifier
     */
    @Override
    public void apply(Instances traindata) {
        // remember the real stderr so it can be put back afterwards
        PrintStream errStr = System.err;
        System.setErr(new PrintStream(new NullOutputStream()));
        try {
            classifier.buildClassifier(traindata);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            System.setErr(errStr);
        }
    }

    /**
     * Classifier that clusters its training data with EM and trains one
     * classifier (obtained from setupClassifier()) per cluster.
     *
     * This is a non-static inner class, so each instance is bound to an
     * enclosing WekaClusterTraining2.
     */
    public class TraindatasetCluster extends AbstractClassifier {

        private static final long serialVersionUID = 1L;

        /** EM clusterer; null until buildClassifier assigns and builds it. */
        private EM clusterer = null;

        // Cluster number -> classifier trained on that cluster's instances.
        // NOTE(review): raw type - the generic arguments appear to have been stripped from
        // this copy. The while loop in buildClassifier assigns next() to an int and calls
        // buildClassifier on get(...), which only compiles with typed collections
        // (presumably HashMap<Integer, Classifier>). Restore from version control.
        private HashMap cclassifier;

        // Cluster number -> training instances assigned to that cluster.
        // NOTE(review): raw type, presumably HashMap<Integer, Instances> originally.
        private HashMap ctraindata;

        /**
         * Helper method that gives us a clean instance copy with
         * the values of the instancelist of the first parameter.
         *
         * NOTE(review): most of this method's body is missing from this copy, so
         * the description above comes from the original Javadoc and could not be
         * checked against the code.
         *
         * @param instances instance list whose attributes define the layout of the copy
         * @param instance instance supplying only the values
         * @return copy of the instance
         */
        private Instance createInstance(Instances instances, Instance instance) {
            // attributes for feeding instance to classifier
            Set attributeNames = new HashSet<>();
            // NOTE(review): this copy of the file is damaged. Everything from the '<'
            // that followed "j" up to a later '>' is missing (it looks like generic
            // type arguments were stripped as if they were markup), so the next line
            // is not valid Java. Lost with it:
            //   - the rest of createInstance;
            //   - whatever was declared between createInstance and buildClassifier
            //     (presumably the classifyInstance override implementing step 3 of
            //     the class Javadoc);
            //   - the start of buildClassifier, including where 'train' is declared
            //     and where 'cclassifier' is initialized.
            // The surviving tokens are kept byte-for-byte instead of being guessed at;
            // restore this region from version control.
            for( int j=0; j(); ctraindata = new HashMap();

            // 3. cluster data
            // use standard params for now
            clusterer = new EM();
            // we can set options like so:
            //String[] params = {"-N", "100"};
            //clusterer.setOptions(params);

            // set max num of clusters to train data size (although we do not want that)
            // NOTE(review): 'train' is declared in the missing region; presumably it is a
            // copy of traindata with the class attribute removed (Filter/Remove are
            // imported but unused in the visible code) - TODO confirm.
            clusterer.setMaximumNumberOfClusters(train.size());

            // build clusterer
            clusterer.buildClusterer(train);

            Instances ctrain = new Instances(train);

            // get train data per cluster
            // ctrain and traindata are indexed with the same j, which assumes 'train'
            // has the same instances in the same order as traindata - TODO confirm.
            int cnumber;
            for ( int j=0; j < ctrain.numInstances(); j++ ) {
                cnumber = clusterer.clusterInstance(ctrain.get(j));

                // add training data to list of instances for this cluster number
                if ( !ctraindata.containsKey(cnumber) ) {
                    // copy traindata and then empty it, leaving an empty set that keeps
                    // the full header (including the class attribute). Weka's
                    // new Instances(traindata, 0) would give the same result without
                    // copying every instance first.
                    ctraindata.put(cnumber, new Instances(traindata));
                    ctraindata.get(cnumber).delete();
                }
                // store the original instance (with class attribute) for this cluster
                ctraindata.get(cnumber).add(traindata.get(j));
            }

            // train one classifier per cluster, we get the cluster number from the training data
            // Only clusters that received at least one training instance get a classifier.
            // setupClassifier() is not defined in this class; presumably it comes from
            // WekaBaseTraining2 via the enclosing instance - TODO confirm.
            Iterator clusternumber = ctraindata.keySet().iterator();
            while ( clusternumber.hasNext() ) {
                cnumber = clusternumber.next();
                cclassifier.put(cnumber,setupClassifier());
                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));

                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
            }
        }
    }
}