| 1 | package de.ugoe.cs.cpdp.dataselection; |
|---|
| 2 | |
|---|
| 3 | import java.util.LinkedList; |
|---|
| 4 | import java.util.List; |
|---|
| 5 | import java.util.logging.Level; |
|---|
| 6 | |
|---|
| 7 | import org.apache.commons.collections4.list.SetUniqueList; |
|---|
| 8 | |
|---|
| 9 | import weka.clusterers.EM; |
|---|
| 10 | import weka.core.Instances; |
|---|
| 11 | import weka.filters.Filter; |
|---|
| 12 | import weka.filters.unsupervised.attribute.AddCluster; |
|---|
| 13 | import weka.filters.unsupervised.attribute.Remove; |
|---|
| 14 | import de.ugoe.cs.util.console.Console; |
|---|
| 15 | |
|---|
| 16 | |
|---|
| 17 | /** |
|---|
| 18 | * Use in Config: |
|---|
| 19 | * |
|---|
| 20 | * Specify number of clusters |
|---|
| 21 | * -N = Num Clusters |
|---|
| 22 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/> |
|---|
| 23 | * |
|---|
| 24 | * Try to determine the number of clusters: |
|---|
| 25 | * -I 10 = max iterations |
|---|
| 26 | * -X 5 = 5 folds for cross evaluation |
|---|
| 27 | * -max = max number of clusters |
|---|
| 28 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/> |
|---|
| 29 | * |
|---|
| 30 | * Don't forget to add: |
|---|
| 31 | * <preprocessor name="Normalization" param=""/> |
|---|
| 32 | */ |
|---|
| 33 | public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { |
|---|
| 34 | |
|---|
| 35 | private String[] params; |
|---|
| 36 | |
|---|
| 37 | @Override |
|---|
| 38 | public void setParameter(String parameters) { |
|---|
| 39 | params = parameters.split(" "); |
|---|
| 40 | } |
|---|
| 41 | |
|---|
| 42 | |
|---|
| 43 | /** |
|---|
| 44 | * 1. Cluster the traindata |
|---|
| 45 | * 2. for each instance in the testdata find the assigned cluster |
|---|
| 46 | * 3. select only traindata from the clusters we found in our testdata |
|---|
| 47 | * |
|---|
| 48 | * @returns the selected training data |
|---|
| 49 | */ |
|---|
| 50 | @Override |
|---|
| 51 | public Instances apply(Instances testdata, Instances traindata) { |
|---|
| 52 | //final Attribute classAttribute = testdata.classAttribute(); |
|---|
| 53 | |
|---|
| 54 | final List<Integer> selectedCluster = SetUniqueList.setUniqueList(new LinkedList<Integer>()); |
|---|
| 55 | |
|---|
| 56 | // 1. copy train- and testdata |
|---|
| 57 | Instances train = new Instances(traindata); |
|---|
| 58 | Instances test = new Instances(testdata); |
|---|
| 59 | |
|---|
| 60 | Instances selected = null; |
|---|
| 61 | |
|---|
| 62 | try { |
|---|
| 63 | // remove class attribute from traindata |
|---|
| 64 | Remove filter = new Remove(); |
|---|
| 65 | filter.setAttributeIndices("" + (train.classIndex() + 1)); |
|---|
| 66 | filter.setInputFormat(train); |
|---|
| 67 | train = Filter.useFilter(train, filter); |
|---|
| 68 | |
|---|
| 69 | Console.traceln(Level.INFO, String.format("starting clustering")); |
|---|
| 70 | |
|---|
| 71 | // 3. cluster data |
|---|
| 72 | EM clusterer = new EM(); |
|---|
| 73 | clusterer.setOptions(params); |
|---|
| 74 | clusterer.buildClusterer(train); |
|---|
| 75 | int numClusters = clusterer.getNumClusters(); |
|---|
| 76 | if ( numClusters == -1) { |
|---|
| 77 | Console.traceln(Level.INFO, String.format("we have unlimited clusters")); |
|---|
| 78 | }else { |
|---|
| 79 | Console.traceln(Level.INFO, String.format("we have: "+numClusters+" clusters")); |
|---|
| 80 | } |
|---|
| 81 | |
|---|
| 82 | |
|---|
| 83 | // 4. classify testdata, save cluster int |
|---|
| 84 | |
|---|
| 85 | // remove class attribute from testdata? |
|---|
| 86 | Remove filter2 = new Remove(); |
|---|
| 87 | filter2.setAttributeIndices("" + (test.classIndex() + 1)); |
|---|
| 88 | filter2.setInputFormat(test); |
|---|
| 89 | test = Filter.useFilter(test, filter2); |
|---|
| 90 | |
|---|
| 91 | int cnum; |
|---|
| 92 | for( int i=0; i < test.numInstances(); i++ ) { |
|---|
| 93 | cnum = ((EM)clusterer).clusterInstance(test.get(i)); |
|---|
| 94 | |
|---|
| 95 | // we dont want doubles (maybe use a hashset instead of list?) |
|---|
| 96 | if ( !selectedCluster.contains(cnum) ) { |
|---|
| 97 | selectedCluster.add(cnum); |
|---|
| 98 | //Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); |
|---|
| 99 | } |
|---|
| 100 | } |
|---|
| 101 | |
|---|
| 102 | Console.traceln(Level.INFO, String.format("our testdata is in: "+selectedCluster.size()+" different clusters")); |
|---|
| 103 | |
|---|
| 104 | // 5. get cluster membership of our traindata |
|---|
| 105 | AddCluster cfilter = new AddCluster(); |
|---|
| 106 | cfilter.setClusterer(clusterer); |
|---|
| 107 | cfilter.setInputFormat(train); |
|---|
| 108 | Instances ctrain = Filter.useFilter(train, cfilter); |
|---|
| 109 | |
|---|
| 110 | |
|---|
| 111 | // 6. for all traindata get the cluster int, if it is in our list of testdata cluster int add the traindata |
|---|
| 112 | // of this cluster to our returned traindata |
|---|
| 113 | int cnumber; |
|---|
| 114 | selected = new Instances(traindata); |
|---|
| 115 | selected.delete(); |
|---|
| 116 | |
|---|
| 117 | for ( int j=0; j < ctrain.numInstances(); j++ ) { |
|---|
| 118 | // get the cluster number from the attributes |
|---|
| 119 | cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")); |
|---|
| 120 | |
|---|
| 121 | //Console.traceln(Level.INFO, String.format("instance "+j+" is in cluster: "+cnumber)); |
|---|
| 122 | if ( selectedCluster.contains(cnumber) ) { |
|---|
| 123 | // this only works if the index does not change |
|---|
| 124 | selected.add(traindata.get(j)); |
|---|
| 125 | // check for differences, just one attribute, we are pretty sure the index does not change |
|---|
| 126 | if ( traindata.get(j).value(3) != ctrain.get(j).value(3) ) { |
|---|
| 127 | Console.traceln(Level.WARNING, String.format("we have a difference between train an ctrain!")); |
|---|
| 128 | } |
|---|
| 129 | } |
|---|
| 130 | } |
|---|
| 131 | |
|---|
| 132 | Console.traceln(Level.INFO, String.format("that leaves us with: "+selected.numInstances()+" traindata instances from "+traindata.numInstances())); |
|---|
| 133 | }catch( Exception e ) { |
|---|
| 134 | Console.traceln(Level.WARNING, String.format("ERROR")); |
|---|
| 135 | throw new RuntimeException("error in pointwise em", e); |
|---|
| 136 | } |
|---|
| 137 | |
|---|
| 138 | return selected; |
|---|
| 139 | } |
|---|
| 140 | |
|---|
| 141 | } |
|---|