package de.ugoe.cs.cpdp.dataselection;

import java.util.LinkedList;
import java.util.List;
import java.util.logging.Level;

import org.apache.commons.collections4.list.SetUniqueList;

import weka.clusterers.EM;
import weka.core.Instances;
import weka.filters.Filter;
// Retained for compatibility; apply() no longer routes the training data through AddCluster.
import weka.filters.unsupervised.attribute.AddCluster;
import weka.filters.unsupervised.attribute.Remove;

import de.ugoe.cs.util.console.Console;

/**
 * Point-wise training data selection based on EM clustering: the training data is clustered and
 * only training instances from clusters that also contain test instances are kept.
 *
 * The parameter string given to {@link #setParameter(String)} is split on single spaces and
 * passed to the Weka {@link EM} clusterer as its options.
 *
 * Use in Config:
 *
 * Specify number of clusters
 * -N = Num Clusters
 *
 * Try to determine the number of clusters:
 * -I 10 = max iterations
 * -X 5 = 5 folds for cross evaluation
 * -max = max number of clusters
 *
 * Don't forget to add:
 * NOTE(review): the configuration snippets that originally followed "Use in Config:" and
 * "Don't forget to add:" are missing from this file (presumably stripped markup) - restore them
 * from the project documentation.
 */
public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy {

    /** Options for the EM clusterer; never {@code null}. */
    private String[] params = new String[0];

    /**
     * Stores the clusterer options.
     *
     * @param parameters
     *            space separated option string; {@code null} is treated as "no options"
     */
    @Override
    public void setParameter(String parameters) {
        params = (parameters == null) ? new String[0] : parameters.split(" ");
    }

    /**
     * 1. Cluster the traindata
     * 2. for each instance in the testdata find the assigned cluster
     * 3. select only traindata from the clusters we found in our testdata
     *
     * Neither {@code testdata} nor {@code traindata} is modified; all filtering happens on copies.
     *
     * @param testdata
     *            test data whose cluster membership decides which clusters are selected
     * @param traindata
     *            training data that is clustered and filtered
     * @return the selected training data (same header as {@code traindata}, possibly empty)
     * @throws RuntimeException
     *             wrapping any exception raised while filtering or clustering
     */
    @Override
    public Instances apply(Instances testdata, Instances traindata) {
        // SetUniqueList silently ignores duplicates, so no explicit contains() check is needed on add
        final List<Integer> selectedCluster =
            SetUniqueList.setUniqueList(new LinkedList<Integer>());

        Instances selected = null;

        try {
            // 1. work on copies without the class attribute
            final Instances train = removeClassAttribute(new Instances(traindata));

            Console.traceln(Level.INFO, "starting clustering");

            // 2. cluster the training data
            final EM clusterer = new EM();
            // Hand over a copy: per the Weka option-handling documentation, parsed entries of the
            // passed array are blanked out, which would make later apply() calls fall back to
            // default options.
            clusterer.setOptions(params.clone());
            clusterer.buildClusterer(train);

            final int numClusters = clusterer.getNumClusters();
            if (numClusters == -1) {
                Console.traceln(Level.INFO, "we have unlimited clusters");
            }
            else {
                Console.traceln(Level.INFO, "we have: " + numClusters + " clusters");
            }

            // 3. collect the clusters our testdata falls into
            final Instances test = removeClassAttribute(new Instances(testdata));
            for (int i = 0; i < test.numInstances(); i++) {
                selectedCluster.add(clusterer.clusterInstance(test.get(i)));
            }

            Console.traceln(Level.INFO, "our testdata is in: " + selectedCluster.size() +
                " different clusters");

            // The class-attribute filter must not drop or reorder rows, because the selection
            // below maps row j of the filtered copy back to row j of the original traindata.
            if (train.numInstances() != traindata.numInstances()) {
                throw new IllegalStateException("filtered traindata has " + train.numInstances() +
                    " instances, expected " + traindata.numInstances());
            }

            // 4. keep every training instance whose cluster was hit by the testdata.
            // The cluster is obtained with the same clusterInstance() call that was used for the
            // testdata, so both sides use the same cluster numbering (no label parsing and no
            // second clustering run).
            selected = new Instances(traindata, 0);
            for (int j = 0; j < train.numInstances(); j++) {
                final int cnumber = clusterer.clusterInstance(train.get(j));
                if (selectedCluster.contains(cnumber)) {
                    selected.add(traindata.get(j));
                }
            }

            Console.traceln(Level.INFO, "that leaves us with: " + selected.numInstances() +
                " traindata instances from " + traindata.numInstances());
        }
        catch (Exception e) {
            Console.traceln(Level.WARNING, "ERROR");
            throw new RuntimeException("error in pointwise em", e);
        }

        return selected;
    }

    /**
     * Returns a filtered version of the data without its class attribute.
     *
     * @param data
     *            data with a class attribute set
     * @return the result of applying a {@link Remove} filter for the class attribute
     * @throws Exception
     *             if the filter cannot be configured or applied
     */
    private static Instances removeClassAttribute(Instances data) throws Exception {
        final Remove filter = new Remove();
        // classIndex() is shifted by one for the filter's index string, as in the original code
        filter.setAttributeIndices(String.valueOf(data.classIndex() + 1));
        filter.setInputFormat(data);
        return Filter.useFilter(data, filter);
    }
}