package de.ugoe.cs.cpdp.training;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;
import org.apache.commons.io.output.NullOutputStream;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* WekaClusterTraining2
*
* 1. Cluster the training data.
* 2. For each cluster, train a classifier on the training data assigned to that cluster.
* 3. Assign each test data instance to a cluster, then classify it with that cluster's classifier.
*
* XML config:
*
*
*
*
*
*
* Questions:
* - how do we configure the clustering params?
*/
public class WekaClusterTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {
private final TraindatasetCluster classifier = new TraindatasetCluster();
/**
 * Returns the cluster-based classifier owned by this training strategy.
 *
 * @return the {@code TraindatasetCluster} instance held in the {@code classifier} field
 */
@Override
public Classifier getClassifier() {
    return this.classifier;
}
/**
 * Builds the cluster-based classifier from the given training data.
 *
 * <p>{@code System.err} is redirected to a discarding stream while the classifier is being
 * built and is put back afterwards, whether or not the build succeeds.
 *
 * @param traindata training data handed to the classifier's {@code buildClassifier}
 * @throws RuntimeException wrapping any exception raised while building the classifier
 */
@Override
public void apply(Instances traindata) {
    final PrintStream originalErr = System.err;
    final PrintStream discardingErr = new PrintStream(new NullOutputStream());
    System.setErr(discardingErr);
    try {
        classifier.buildClassifier(traindata);
    }
    catch (Exception cause) {
        throw new RuntimeException(cause);
    }
    finally {
        // always restore the original error stream, even when the build failed
        System.setErr(originalErr);
    }
}
/**
 * Classifier that clusters the training data with EM and then trains one
 * classifier per cluster (see the cluster/train loops below).
 *
 * NOTE(review): the text of this class appears to be truncated. The generic type
 * arguments of the maps/sets/iterator are missing, and the code between the start
 * of createInstance and the map initialisation further down is cut off, so the
 * declarations of "train" and "traindata" are not visible here. Restore the class
 * from version control before changing any logic.
 */
public class TraindatasetCluster extends AbstractClassifier {
private static final long serialVersionUID = 1L;
// EM clusterer; instantiated and built in the cluster-building code below
private EM clusterer = null;
// cluster number -> classifier built from that cluster's training data
private HashMap cclassifier;
// cluster number -> training instances that were assigned to that cluster
private HashMap ctraindata;
// NOTE(review): only the first lines of createInstance survive; its purpose beyond
// collecting attribute names cannot be determined from the visible code.
private Instance createInstance(Instances instances, Instance instance) {
// attributes for feeding instance to classifier
Set attributeNames = new HashSet<>();
// NOTE(review): the next line is truncated; the remainder of this loop, the end of
// createInstance and the header of the method that contains the code below are missing.
for( int j=0; j();
ctraindata = new HashMap();
// use standard params for now
clusterer = new EM();
//String[] params = {"-N", "100"};
//clusterer.setOptions(params);
clusterer.buildClusterer(train);
// set max num to traindata size
// NOTE(review): this setter runs after buildClusterer(train), so it presumably has no
// effect on the clustering that was just built - confirm whether it should come first.
clusterer.setMaximumNumberOfClusters(train.size());
// 4. get cluster membership of our traindata
//AddCluster cfilter = new AddCluster();
//cfilter.setClusterer(clusterer);
//cfilter.setInputFormat(train);
//Instances ctrain = Filter.useFilter(train, cfilter);
Instances ctrain = new Instances(train);
// get traindata per cluster
int cnumber;
for ( int j=0; j < ctrain.numInstances(); j++ ) {
// get the cluster number from the attributes, subtract 1 because if we clusterInstance we get 0-n, and this is 1-n
//cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1;
cnumber = clusterer.clusterInstance(ctrain.get(j));
// add training data to list of instances for this cluster number
if ( !ctraindata.containsKey(cnumber) ) {
// first instance for this cluster: start from a copy of traindata and empty it,
// which leaves a dataset with traindata's structure but no instances
ctraindata.put(cnumber, new Instances(traindata));
ctraindata.get(cnumber).delete();
}
// NOTE(review): the cluster is computed from ctrain.get(j) but the stored row is
// traindata.get(j); this relies on train and traindata being row-aligned - confirm.
ctraindata.get(cnumber).add(traindata.get(j));
}
// Debug output
//Console.traceln(Level.INFO, String.format("number of clusters: " + clusterer.numberOfClusters()));
// train one classifier per cluster, we get the clusternumber from the traindata
Iterator clusternumber = ctraindata.keySet().iterator();
while ( clusternumber.hasNext() ) {
cnumber = clusternumber.next();
// create a fresh classifier via setupClassifier() and build it on this cluster's data only
cclassifier.put(cnumber,setupClassifier());
cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
//Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber + " with " + ctraindata.get(cnumber).size() + " traindata instances"));
}
}
}
}