| 1 | // Copyright 2015 Georg-August-Universität Göttingen, Germany |
|---|
| 2 | // |
|---|
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
|---|
| 4 | // you may not use this file except in compliance with the License. |
|---|
| 5 | // You may obtain a copy of the License at |
|---|
| 6 | // |
|---|
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
|---|
| 8 | // |
|---|
| 9 | // Unless required by applicable law or agreed to in writing, software |
|---|
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
|---|
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|---|
| 12 | // See the License for the specific language governing permissions and |
|---|
| 13 | // limitations under the License. |
|---|
| 14 | |
|---|
| 15 | package de.ugoe.cs.cpdp.dataselection; |
|---|
| 16 | |
|---|
| 17 | import java.util.LinkedList; |
|---|
| 18 | import java.util.List; |
|---|
| 19 | import java.util.logging.Level; |
|---|
| 20 | |
|---|
| 21 | import org.apache.commons.collections4.list.SetUniqueList; |
|---|
| 22 | |
|---|
| 23 | import weka.clusterers.EM; |
|---|
| 24 | import weka.core.Instances; |
|---|
| 25 | import weka.filters.Filter; |
|---|
| 26 | import weka.filters.unsupervised.attribute.AddCluster; |
|---|
| 27 | import weka.filters.unsupervised.attribute.Remove; |
|---|
| 28 | import de.ugoe.cs.util.console.Console; |
|---|
| 29 | |
|---|
| 30 | /** |
|---|
| 31 | * Use in Config: |
|---|
| 32 | * |
|---|
| 33 | * Specify number of clusters -N = Num Clusters <pointwiseselector |
|---|
| 34 | * name="PointWiseEMClusterSelection" param="-N 10"/> |
|---|
| 35 | * |
|---|
| 36 | * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross |
|---|
| 37 | * evaluation -max = max number of clusters <pointwiseselector name="PointWiseEMClusterSelection" |
|---|
| 38 | * param="-I 10 -X 5 -max 300"/> |
|---|
| 39 | * |
|---|
| 40 | * Don't forget to add: <preprocessor name="Normalization" param=""/> |
|---|
| 41 | */ |
|---|
| 42 | public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { |
|---|
| 43 | |
|---|
| 44 | private String[] params; |
|---|
| 45 | |
|---|
| 46 | @Override |
|---|
| 47 | public void setParameter(String parameters) { |
|---|
| 48 | params = parameters.split(" "); |
|---|
| 49 | } |
|---|
| 50 | |
|---|
| 51 | /** |
|---|
| 52 | * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3. |
|---|
| 53 | * select only traindata from the clusters we found in our testdata |
|---|
| 54 | * |
|---|
| 55 | * @returns the selected training data |
|---|
| 56 | */ |
|---|
| 57 | @Override |
|---|
| 58 | public Instances apply(Instances testdata, Instances traindata) { |
|---|
| 59 | // final Attribute classAttribute = testdata.classAttribute(); |
|---|
| 60 | |
|---|
| 61 | final List<Integer> selectedCluster = |
|---|
| 62 | SetUniqueList.setUniqueList(new LinkedList<Integer>()); |
|---|
| 63 | |
|---|
| 64 | // 1. copy train- and testdata |
|---|
| 65 | Instances train = new Instances(traindata); |
|---|
| 66 | Instances test = new Instances(testdata); |
|---|
| 67 | |
|---|
| 68 | Instances selected = null; |
|---|
| 69 | |
|---|
| 70 | try { |
|---|
| 71 | // remove class attribute from traindata |
|---|
| 72 | Remove filter = new Remove(); |
|---|
| 73 | filter.setAttributeIndices("" + (train.classIndex() + 1)); |
|---|
| 74 | filter.setInputFormat(train); |
|---|
| 75 | train = Filter.useFilter(train, filter); |
|---|
| 76 | |
|---|
| 77 | Console.traceln(Level.INFO, String.format("starting clustering")); |
|---|
| 78 | |
|---|
| 79 | // 3. cluster data |
|---|
| 80 | EM clusterer = new EM(); |
|---|
| 81 | clusterer.setOptions(params); |
|---|
| 82 | clusterer.buildClusterer(train); |
|---|
| 83 | int numClusters = clusterer.getNumClusters(); |
|---|
| 84 | if (numClusters == -1) { |
|---|
| 85 | Console.traceln(Level.INFO, String.format("we have unlimited clusters")); |
|---|
| 86 | } |
|---|
| 87 | else { |
|---|
| 88 | Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters")); |
|---|
| 89 | } |
|---|
| 90 | |
|---|
| 91 | // 4. classify testdata, save cluster int |
|---|
| 92 | |
|---|
| 93 | // remove class attribute from testdata? |
|---|
| 94 | Remove filter2 = new Remove(); |
|---|
| 95 | filter2.setAttributeIndices("" + (test.classIndex() + 1)); |
|---|
| 96 | filter2.setInputFormat(test); |
|---|
| 97 | test = Filter.useFilter(test, filter2); |
|---|
| 98 | |
|---|
| 99 | int cnum; |
|---|
| 100 | for (int i = 0; i < test.numInstances(); i++) { |
|---|
| 101 | cnum = ((EM) clusterer).clusterInstance(test.get(i)); |
|---|
| 102 | |
|---|
| 103 | // we dont want doubles (maybe use a hashset instead of list?) |
|---|
| 104 | if (!selectedCluster.contains(cnum)) { |
|---|
| 105 | selectedCluster.add(cnum); |
|---|
| 106 | // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); |
|---|
| 107 | } |
|---|
| 108 | } |
|---|
| 109 | |
|---|
| 110 | Console.traceln(Level.INFO, |
|---|
| 111 | String.format("our testdata is in: " + selectedCluster.size() + |
|---|
| 112 | " different clusters")); |
|---|
| 113 | |
|---|
| 114 | // 5. get cluster membership of our traindata |
|---|
| 115 | AddCluster cfilter = new AddCluster(); |
|---|
| 116 | cfilter.setClusterer(clusterer); |
|---|
| 117 | cfilter.setInputFormat(train); |
|---|
| 118 | Instances ctrain = Filter.useFilter(train, cfilter); |
|---|
| 119 | |
|---|
| 120 | // 6. for all traindata get the cluster int, if it is in our list of testdata cluster |
|---|
| 121 | // int add the traindata |
|---|
| 122 | // of this cluster to our returned traindata |
|---|
| 123 | int cnumber; |
|---|
| 124 | selected = new Instances(traindata); |
|---|
| 125 | selected.delete(); |
|---|
| 126 | |
|---|
| 127 | for (int j = 0; j < ctrain.numInstances(); j++) { |
|---|
| 128 | // get the cluster number from the attributes |
|---|
| 129 | cnumber = |
|---|
| 130 | Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1) |
|---|
| 131 | .replace("cluster", "")); |
|---|
| 132 | |
|---|
| 133 | // Console.traceln(Level.INFO, |
|---|
| 134 | // String.format("instance "+j+" is in cluster: "+cnumber)); |
|---|
| 135 | if (selectedCluster.contains(cnumber)) { |
|---|
| 136 | // this only works if the index does not change |
|---|
| 137 | selected.add(traindata.get(j)); |
|---|
| 138 | // check for differences, just one attribute, we are pretty sure the index does |
|---|
| 139 | // not change |
|---|
| 140 | if (traindata.get(j).value(3) != ctrain.get(j).value(3)) { |
|---|
| 141 | Console.traceln(Level.WARNING, String |
|---|
| 142 | .format("we have a difference between train an ctrain!")); |
|---|
| 143 | } |
|---|
| 144 | } |
|---|
| 145 | } |
|---|
| 146 | |
|---|
| 147 | Console.traceln(Level.INFO, |
|---|
| 148 | String.format("that leaves us with: " + selected.numInstances() + |
|---|
| 149 | " traindata instances from " + traindata.numInstances())); |
|---|
| 150 | } |
|---|
| 151 | catch (Exception e) { |
|---|
| 152 | Console.traceln(Level.WARNING, String.format("ERROR")); |
|---|
| 153 | throw new RuntimeException("error in pointwise em", e); |
|---|
| 154 | } |
|---|
| 155 | |
|---|
| 156 | return selected; |
|---|
| 157 | } |
|---|
| 158 | |
|---|
| 159 | } |
|---|