[86] | 1 | // Copyright 2015 Georg-August-Universität Göttingen, Germany |
---|
[41] | 2 | // |
---|
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
---|
| 4 | // you may not use this file except in compliance with the License. |
---|
| 5 | // You may obtain a copy of the License at |
---|
| 6 | // |
---|
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
---|
| 8 | // |
---|
| 9 | // Unless required by applicable law or agreed to in writing, software |
---|
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
---|
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
---|
| 12 | // See the License for the specific language governing permissions and |
---|
| 13 | // limitations under the License. |
---|
| 14 | |
---|
[2] | 15 | package de.ugoe.cs.cpdp.dataselection; |
---|
| 16 | |
---|
| 17 | import java.util.LinkedList; |
---|
| 18 | import java.util.List; |
---|
| 19 | import java.util.logging.Level; |
---|
| 20 | |
---|
| 21 | import org.apache.commons.collections4.list.SetUniqueList; |
---|
| 22 | |
---|
| 23 | import weka.clusterers.EM; |
---|
| 24 | import weka.core.Instances; |
---|
| 25 | import weka.filters.Filter; |
---|
| 26 | import weka.filters.unsupervised.attribute.AddCluster; |
---|
| 27 | import weka.filters.unsupervised.attribute.Remove; |
---|
| 28 | import de.ugoe.cs.util.console.Console; |
---|
| 29 | |
---|
| 30 | /** |
---|
| 31 | * Use in Config: |
---|
| 32 | * |
---|
[135] | 33 | * Specify number of clusters -N = Num Clusters |
---|
| 34 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/> |
---|
[2] | 35 | * |
---|
[41] | 36 | * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross |
---|
[135] | 37 | * evaluation -max = max number of clusters |
---|
| 38 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/> |
---|
[41] | 39 | * |
---|
| 40 | * Don't forget to add: <preprocessor name="Normalization" param=""/> |
---|
[2] | 41 | */ |
---|
| 42 | public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { |
---|
| 43 | |
---|
[135] | 44 | /** |
---|
| 45 | * paramters passed to the selection |
---|
| 46 | */ |
---|
[41] | 47 | private String[] params; |
---|
[2] | 48 | |
---|
[135] | 49 | /* |
---|
| 50 | * (non-Javadoc) |
---|
| 51 | * |
---|
| 52 | * @see de.ugoe.cs.cpdp.IParameterizable#setParameter(java.lang.String) |
---|
| 53 | */ |
---|
[41] | 54 | @Override |
---|
| 55 | public void setParameter(String parameters) { |
---|
| 56 | params = parameters.split(" "); |
---|
| 57 | } |
---|
[2] | 58 | |
---|
[41] | 59 | /** |
---|
| 60 | * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3. |
---|
| 61 | * select only traindata from the clusters we found in our testdata |
---|
| 62 | * |
---|
| 63 | * @returns the selected training data |
---|
| 64 | */ |
---|
| 65 | @Override |
---|
| 66 | public Instances apply(Instances testdata, Instances traindata) { |
---|
| 67 | // final Attribute classAttribute = testdata.classAttribute(); |
---|
[2] | 68 | |
---|
[41] | 69 | final List<Integer> selectedCluster = |
---|
| 70 | SetUniqueList.setUniqueList(new LinkedList<Integer>()); |
---|
| 71 | |
---|
| 72 | // 1. copy train- and testdata |
---|
| 73 | Instances train = new Instances(traindata); |
---|
| 74 | Instances test = new Instances(testdata); |
---|
| 75 | |
---|
| 76 | Instances selected = null; |
---|
| 77 | |
---|
| 78 | try { |
---|
| 79 | // remove class attribute from traindata |
---|
| 80 | Remove filter = new Remove(); |
---|
| 81 | filter.setAttributeIndices("" + (train.classIndex() + 1)); |
---|
| 82 | filter.setInputFormat(train); |
---|
| 83 | train = Filter.useFilter(train, filter); |
---|
| 84 | |
---|
| 85 | Console.traceln(Level.INFO, String.format("starting clustering")); |
---|
| 86 | |
---|
| 87 | // 3. cluster data |
---|
| 88 | EM clusterer = new EM(); |
---|
| 89 | clusterer.setOptions(params); |
---|
| 90 | clusterer.buildClusterer(train); |
---|
| 91 | int numClusters = clusterer.getNumClusters(); |
---|
| 92 | if (numClusters == -1) { |
---|
| 93 | Console.traceln(Level.INFO, String.format("we have unlimited clusters")); |
---|
| 94 | } |
---|
| 95 | else { |
---|
| 96 | Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters")); |
---|
| 97 | } |
---|
| 98 | |
---|
| 99 | // 4. classify testdata, save cluster int |
---|
| 100 | |
---|
| 101 | // remove class attribute from testdata? |
---|
| 102 | Remove filter2 = new Remove(); |
---|
| 103 | filter2.setAttributeIndices("" + (test.classIndex() + 1)); |
---|
| 104 | filter2.setInputFormat(test); |
---|
| 105 | test = Filter.useFilter(test, filter2); |
---|
| 106 | |
---|
| 107 | int cnum; |
---|
| 108 | for (int i = 0; i < test.numInstances(); i++) { |
---|
| 109 | cnum = ((EM) clusterer).clusterInstance(test.get(i)); |
---|
| 110 | |
---|
| 111 | // we dont want doubles (maybe use a hashset instead of list?) |
---|
| 112 | if (!selectedCluster.contains(cnum)) { |
---|
| 113 | selectedCluster.add(cnum); |
---|
| 114 | // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); |
---|
| 115 | } |
---|
| 116 | } |
---|
| 117 | |
---|
[135] | 118 | Console.traceln(Level.INFO, String |
---|
| 119 | .format("our testdata is in: " + selectedCluster.size() + " different clusters")); |
---|
[41] | 120 | |
---|
| 121 | // 5. get cluster membership of our traindata |
---|
| 122 | AddCluster cfilter = new AddCluster(); |
---|
| 123 | cfilter.setClusterer(clusterer); |
---|
| 124 | cfilter.setInputFormat(train); |
---|
| 125 | Instances ctrain = Filter.useFilter(train, cfilter); |
---|
| 126 | |
---|
| 127 | // 6. for all traindata get the cluster int, if it is in our list of testdata cluster |
---|
| 128 | // int add the traindata |
---|
| 129 | // of this cluster to our returned traindata |
---|
| 130 | int cnumber; |
---|
| 131 | selected = new Instances(traindata); |
---|
| 132 | selected.delete(); |
---|
| 133 | |
---|
| 134 | for (int j = 0; j < ctrain.numInstances(); j++) { |
---|
| 135 | // get the cluster number from the attributes |
---|
[135] | 136 | cnumber = Integer.parseInt(ctrain.get(j) |
---|
| 137 | .stringValue(ctrain.get(j).numAttributes() - 1).replace("cluster", "")); |
---|
[41] | 138 | |
---|
| 139 | // Console.traceln(Level.INFO, |
---|
| 140 | // String.format("instance "+j+" is in cluster: "+cnumber)); |
---|
| 141 | if (selectedCluster.contains(cnumber)) { |
---|
| 142 | // this only works if the index does not change |
---|
| 143 | selected.add(traindata.get(j)); |
---|
| 144 | // check for differences, just one attribute, we are pretty sure the index does |
---|
| 145 | // not change |
---|
| 146 | if (traindata.get(j).value(3) != ctrain.get(j).value(3)) { |
---|
| 147 | Console.traceln(Level.WARNING, String |
---|
| 148 | .format("we have a difference between train an ctrain!")); |
---|
| 149 | } |
---|
| 150 | } |
---|
| 151 | } |
---|
| 152 | |
---|
[135] | 153 | Console.traceln(Level.INFO, String.format("that leaves us with: " + |
---|
| 154 | selected.numInstances() + " traindata instances from " + traindata.numInstances())); |
---|
[41] | 155 | } |
---|
| 156 | catch (Exception e) { |
---|
| 157 | Console.traceln(Level.WARNING, String.format("ERROR")); |
---|
| 158 | throw new RuntimeException("error in pointwise em", e); |
---|
| 159 | } |
---|
| 160 | |
---|
| 161 | return selected; |
---|
| 162 | } |
---|
| 163 | |
---|
[2] | 164 | } |
---|