package de.ugoe.cs.cpdp.training;

import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;

import org.apache.commons.io.output.NullOutputStream;

import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * WekaClusterTraining2: a cluster-then-classify training strategy.
 *
 * <ol>
 * <li>Cluster the training data with Weka's {@link EM} clusterer.</li>
 * <li>For each cluster, train a separate classifier on the training instances assigned to that
 * cluster.</li>
 * <li>Match each test instance to a cluster, then classify it with that cluster's classifier.</li>
 * </ol>
 *
 * <p>XML config: the example configuration that belonged here appears to have been lost from this
 * copy of the file (its markup was stripped). TODO restore it from version control.
 *
 * <p>Open question: how should the clustering parameters be configured? At present the training
 * code uses EM's default options (an {@code -N} override is left commented out).
 */
public class WekaClusterTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {

    /** The cluster-based classifier that {@link #apply(Instances)} trains. */
    private final TraindatasetCluster classifier = new TraindatasetCluster();

    /**
     * Returns the cluster-based classifier.
     *
     * @return the {@link TraindatasetCluster} held by this strategy; it is trained by
     *         {@link #apply(Instances)}
     */
    @Override
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * Trains the cluster-based classifier on the given training data.
     *
     * <p>While training runs, {@code System.err} is redirected to a discarding stream (presumably
     * to silence console output from Weka). It is restored in a {@code finally} block, so it is
     * reset even when training fails.
     *
     * <p>NOTE(review): {@code System.setErr} changes process-wide state. Anything else writing to
     * {@code System.err} while training runs (e.g. other threads) is silenced as well.
     *
     * @param traindata the training data
     * @throws RuntimeException wrapping any exception thrown while building the classifier
     */
    @Override
    public void apply(Instances traindata) {
        PrintStream errStr = System.err;
        System.setErr(new PrintStream(new NullOutputStream()));
        try {
            classifier.buildClassifier(traindata);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            System.setErr(errStr);
        }
    }

    /**
     * Weka classifier that clusters its training data with {@link EM} and keeps one classifier per
     * cluster.
     *
     * <p>This is an inner (non-static) class. Its training code calls {@code setupClassifier()},
     * which is presumably inherited by the enclosing class from {@code WekaBaseTraining2}.
     *
     * <p>NOTE(review): {@link AbstractClassifier} is serializable, and serializing an instance of
     * this inner class also tries to serialize the enclosing {@code WekaClusterTraining2}. Confirm
     * that this works if trained models are ever persisted.
     */
    public class TraindatasetCluster extends AbstractClassifier {

        private static final long serialVersionUID = 1L;

        /** EM clusterer; it is created and fitted to the training data during training. */
        private EM clusterer = null;

        /**
         * Maps a cluster number to the classifier trained on that cluster's instances.
         *
         * <p>The raw type is there because this copy of the file lost its generic type arguments.
         * TODO restore them: keys are the {@code int} cluster numbers returned by
         * {@code clusterer.clusterInstance}, and values come from {@code setupClassifier()}.
         */
        private HashMap cclassifier;

        /**
         * Maps a cluster number to the training instances assigned to that cluster.
         *
         * <p>The instances are taken from {@code traindata}, so they keep all of its attributes.
         * The raw type is there for the same reason as on {@code cclassifier}.
         */
        private HashMap ctraindata;

        private Instance createInstance(Instances instances, Instance instance) {
            /*
             * NOTE(review): the rest of this line is corrupted in this copy of the file:
             *
             * - Generic type arguments were stripped, e.g. "Set attributeNames" and
             *   "new HashMap()".
             * - A span of code is missing between "for( int j=0; j" and
             *   "(); ctraindata = new HashMap();". It held the rest of createInstance, presumably
             *   classifyInstance, and the start of the training method (presumably
             *   buildClassifier).
             * - The original line breaks are gone, so everything from the first "//" below to the
             *   end of its line is a single line comment.
             *
             * Restore this class from version control rather than repairing it by hand.
             */
            // attributes for feeding instance to classifier Set attributeNames = new HashSet<>(); for( int j=0; j(); ctraindata = new HashMap(); // use standard params for now clusterer = new EM(); //String[] params = {"-N", "100"}; //clusterer.setOptions(params); clusterer.buildClusterer(train); // set max num to traindata size 
            /*
             * NOTE(review): intended statement order, as far as it can be read from this copy.
             * clusterer.buildClusterer(train) runs first (end of the previous line), and only
             * afterwards is setMaximumNumberOfClusters called here. Weka's EM picks the number of
             * clusters inside buildClusterer (when -N is -1, its default). Setting the maximum
             * afterwards therefore presumably has no effect on the model already built. Move this
             * call before buildClusterer (TODO confirm against the EM documentation).
             *
             * NOTE(review): "new Instances(traindata)" followed by "delete()" copies every training
             * instance only to discard them again. "new Instances(traindata, 0)" gives the same
             * empty dataset with the same header, without the copy.
             *
             * Cluster membership is computed from row j of ctrain (a copy of train), but the row
             * added to that cluster's data is traindata.get(j). This relies on train and traindata
             * having the same rows in the same order. Presumably train is traindata minus the class
             * attribute; the code that builds it is lost from this copy.
             *
             * Only clusters that received at least one training instance get a classifier. A test
             * instance assigned to any other cluster would find no entry in cclassifier. The
             * classification code is lost from this copy, so confirm how that case is handled.
             */
            clusterer.setMaximumNumberOfClusters(train.size()); // 4. get cluster membership of our traindata //AddCluster cfilter = new AddCluster(); //cfilter.setClusterer(clusterer); //cfilter.setInputFormat(train); //Instances ctrain = Filter.useFilter(train, cfilter); Instances ctrain = new Instances(train); // get traindata per cluster int cnumber; for ( int j=0; j < ctrain.numInstances(); j++ ) { // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; cnumber = clusterer.clusterInstance(ctrain.get(j)); // add training data to list of instances for this cluster number if ( !ctraindata.containsKey(cnumber) ) { ctraindata.put(cnumber, new Instances(traindata)); ctraindata.get(cnumber).delete(); } ctraindata.get(cnumber).add(traindata.get(j)); } // Debug output //Console.traceln(Level.INFO, String.format("number of clusters: " + clusterer.numberOfClusters())); // train one classifier per cluster, we get the clusternumber from the traindata Iterator clusternumber = ctraindata.keySet().iterator(); while ( clusternumber.hasNext() ) { cnumber = clusternumber.next(); cclassifier.put(cnumber,setupClassifier()); cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); //Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber + " with " + ctraindata.get(cnumber).size() + " traindata instances")); } } } }