package de.ugoe.cs.cpdp.training;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Set;
import java.util.logging.Level;
import org.apache.commons.io.output.NullOutputStream;
import de.ugoe.cs.util.console.Console;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* WekaLocalEMTraining
*
* Local Trainer with EM Clustering for data partitioning.
* Currently supports only EM Clustering.
*
* 1. Cluster training data
* 2. for each cluster train a classifier with training data from cluster
* 3. match test data instance to a cluster, then classify with classifier from the cluster
*
 * XML configuration:
 * <pre>
 * {@code
 * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
 * }
 * </pre>
 */
public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
private final TraindatasetCluster classifier = new TraindatasetCluster();
@Override
public Classifier getClassifier() {
return classifier;
}
@Override
public void apply(Instances traindata) {
PrintStream errStr = System.err;
System.setErr(new PrintStream(new NullOutputStream()));
try {
classifier.buildClassifier(traindata);
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
System.setErr(errStr);
}
}
public class TraindatasetCluster extends AbstractClassifier {
private static final long serialVersionUID = 1L;
private EM clusterer = null;
private HashMap cclassifier;
private HashMap ctraindata;
/**
* Helper method that gives us a clean instance copy with
* the values of the instancelist of the first parameter.
*
* @param instancelist with attributes
* @param instance with only values
* @return copy of the instance
*/
private Instance createInstance(Instances instances, Instance instance) {
// attributes for feeding instance to classifier
Set attributeNames = new HashSet<>();
for( int j=0; j();
ctraindata = new HashMap();
Instances ctrain;
int maxNumClusters = train.size();
boolean sufficientInstancesInEachCluster;
do { // while(onlyTarget)
sufficientInstancesInEachCluster = true;
clusterer = new EM();
clusterer.setMaximumNumberOfClusters(maxNumClusters);
clusterer.buildClusterer(train);
// 4. get cluster membership of our traindata
//AddCluster cfilter = new AddCluster();
//cfilter.setClusterer(clusterer);
//cfilter.setInputFormat(train);
//Instances ctrain = Filter.useFilter(train, cfilter);
ctrain = new Instances(train);
ctraindata = new HashMap<>();
// get traindata per cluster
for ( int j=0; j < ctrain.numInstances(); j++ ) {
// get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n
//cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1;
int cnumber = clusterer.clusterInstance(ctrain.get(j));
// add training data to list of instances for this cluster number
if ( !ctraindata.containsKey(cnumber) ) {
ctraindata.put(cnumber, new Instances(traindata));
ctraindata.get(cnumber).delete();
}
ctraindata.get(cnumber).add(traindata.get(j));
}
for( Entry entry : ctraindata.entrySet() ) {
Instances instances = entry.getValue();
int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
for( int count : counts ) {
sufficientInstancesInEachCluster &= count>0;
}
sufficientInstancesInEachCluster &= instances.numInstances()>=5;
}
maxNumClusters = clusterer.numberOfClusters()-1;
} while(!sufficientInstancesInEachCluster);
// train one classifier per cluster, we get the cluster number from the training data
Iterator clusternumber = ctraindata.keySet().iterator();
while ( clusternumber.hasNext() ) {
int cnumber = clusternumber.next();
cclassifier.put(cnumber,setupClassifier());
cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
//Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
}
}
}
}