package de.ugoe.cs.cpdp.training;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;
import org.apache.commons.io.output.NullOutputStream;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* WekaClusterTraining2
*
* 1. Cluster traindata
* 2. for each cluster train a classifier with traindata from cluster
* 3. match testdata instance to a cluster, then classify with classifier from the cluster
*
* XML config:
* (the example configuration is missing from this comment; TODO: restore it)
*
* Questions:
* - how do we configure the clustering params?
*/
public class WekaClusterTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {
private final TraindatasetCluster classifier = new TraindatasetCluster();
/**
 * Returns the classifier instance held by this training strategy.
 *
 * @return the {@code TraindatasetCluster} stored in the {@code classifier} field
 */
@Override
public Classifier getClassifier() {
    return this.classifier;
}
/**
 * Builds the held classifier from the given training data.
 *
 * <p>{@code System.err} is pointed at a discarding stream while
 * {@code buildClassifier} runs and is switched back to its previous value
 * afterwards, whether or not building succeeds.
 *
 * @param traindata the training data handed to {@code buildClassifier}
 * @throws RuntimeException wrapping any exception raised while building the classifier
 */
@Override
public void apply(Instances traindata) {
    final PrintStream previousErr = System.err;
    final PrintStream discardingErr = new PrintStream(new NullOutputStream());
    // Silence stderr for the duration of the build only.
    System.setErr(discardingErr);
    try {
        this.classifier.buildClassifier(traindata);
    } catch (Exception buildFailure) {
        throw new RuntimeException(buildFailure);
    } finally {
        // Always restore the stream that was active before the build.
        System.setErr(previousErr);
    }
}
/**
 * Weka classifier that holds an EM clusterer plus two maps: one with a
 * classifier per key and one with the training data each of those classifiers
 * is built from. The keys are presumably cluster numbers (judging by the local
 * names used below) - TODO confirm.
 *
 * NOTE(review): the text of this class is truncated inside createInstance. The
 * for-loop header there runs straight into the iterator declaration of a later
 * loop, so the remainder of createInstance and the start of the method that
 * trains the per-cluster classifiers are not visible here. The missing code
 * must be restored before this class compiles.
 */
public class TraindatasetCluster extends AbstractClassifier {
private static final long serialVersionUID = 1L;
// EM clusterer; starts as null, its assignment is not visible in this section
private EM clusterer = null;
// one classifier per key; filled in the loop at the end of this class
private HashMap cclassifier = new HashMap();
// training data per key; each entry is passed to buildClassifier in the loop below
private HashMap ctraindata = new HashMap();
// Takes a dataset and a single instance and returns an Instance. Only the start
// of the body is visible (it begins by creating a set for attribute names);
// what it does with them cannot be determined from here.
private Instance createInstance(Instances instances, Instance instance) {
// attributes for feeding instance to classifier
Set attributeNames = new HashSet<>();
// NOTE(review): the next line is where the text is truncated - the loop over j
// is cut off and merges into the declaration of an iterator over the keys of
// ctraindata, which belongs to a later method.
for( int j=0; j clusternumber = ctraindata.keySet().iterator();
// For every key in ctraindata: create a classifier via setupClassifier()
// (defined outside this section), store it under the same key in cclassifier,
// and build it from that key's training data.
while ( clusternumber.hasNext() ) {
cnumber = clusternumber.next();
cclassifier.put(cnumber,setupClassifier());
cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
//Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
}
}
}
}