- Timestamp:
- 09/24/15 10:59:05 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/PointWiseEMClusterSelection.java
r2 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.dataselection; 2 16 … … 14 28 import de.ugoe.cs.util.console.Console; 15 29 16 17 30 /** 18 31 * Use in Config: 19 32 * 20 * Specify number of clusters 21 * -N = Num Clusters 22 * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/> 23 * 24 * Try to determine the number of clusters: 25 * -I 10 = max iterations 26 * -X 5 = 5 folds for cross evaluation 27 * -max = max number of clusters 28 * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/> 33 * Specify number of clusters -N = Num Clusters <pointwiseselector 34 * name="PointWiseEMClusterSelection" param="-N 10"/> 29 35 * 30 * Don't forget to add: 31 * <preprocessor name="Normalization" param=""/> 36 * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross 37 * evaluation -max = max number of clusters <pointwiseselector name="PointWiseEMClusterSelection" 38 * param="-I 10 -X 5 -max 300"/> 39 * 40 * Don't forget to add: <preprocessor name="Normalization" param=""/> 32 41 */ 33 42 public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { 34 35 private String[] params;36 37 @Override38 public void setParameter(String parameters) {39 params = parameters.split(" ");40 }41 43 42 43 /** 44 * 1. Cluster the traindata 45 * 2. for each instance in the testdata find the assigned cluster 46 * 3. select only traindata from the clusters we found in our testdata 47 * 48 * @returns the selected training data 49 */ 50 @Override 51 public Instances apply(Instances testdata, Instances traindata) { 52 //final Attribute classAttribute = testdata.classAttribute(); 53 54 final List<Integer> selectedCluster = SetUniqueList.setUniqueList(new LinkedList<Integer>()); 44 private String[] params; 55 45 56 // 1. copy train- and testdata 57 Instances train = new Instances(traindata); 58 Instances test = new Instances(testdata); 59 60 Instances selected = null; 61 62 try { 63 // remove class attribute from traindata 64 Remove filter = new Remove(); 65 filter.setAttributeIndices("" + (train.classIndex() + 1)); 66 filter.setInputFormat(train); 67 train = Filter.useFilter(train, filter); 68 69 Console.traceln(Level.INFO, String.format("starting clustering")); 70 71 // 3. cluster data 72 EM clusterer = new EM(); 73 clusterer.setOptions(params); 74 clusterer.buildClusterer(train); 75 int numClusters = clusterer.getNumClusters(); 76 if ( numClusters == -1) { 77 Console.traceln(Level.INFO, String.format("we have unlimited clusters")); 78 }else { 79 Console.traceln(Level.INFO, String.format("we have: "+numClusters+" clusters")); 80 } 81 82 83 // 4. classify testdata, save cluster int 84 85 // remove class attribute from testdata? 86 Remove filter2 = new Remove(); 87 filter2.setAttributeIndices("" + (test.classIndex() + 1)); 88 filter2.setInputFormat(test); 89 test = Filter.useFilter(test, filter2); 90 91 int cnum; 92 for( int i=0; i < test.numInstances(); i++ ) { 93 cnum = ((EM)clusterer).clusterInstance(test.get(i)); 46 @Override 47 public void setParameter(String parameters) { 48 params = parameters.split(" "); 49 } 94 50 95 // we dont want doubles (maybe use a hashset instead of list?) 96 if ( !selectedCluster.contains(cnum) ) { 97 selectedCluster.add(cnum); 98 //Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); 99 } 100 } 101 102 Console.traceln(Level.INFO, String.format("our testdata is in: "+selectedCluster.size()+" different clusters")); 103 104 // 5. get cluster membership of our traindata 105 AddCluster cfilter = new AddCluster(); 106 cfilter.setClusterer(clusterer); 107 cfilter.setInputFormat(train); 108 Instances ctrain = Filter.useFilter(train, cfilter); 109 110 111 // 6. for all traindata get the cluster int, if it is in our list of testdata cluster int add the traindata 112 // of this cluster to our returned traindata 113 int cnumber; 114 selected = new Instances(traindata); 115 selected.delete(); 116 117 for ( int j=0; j < ctrain.numInstances(); j++ ) { 118 // get the cluster number from the attributes 119 cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")); 120 121 //Console.traceln(Level.INFO, String.format("instance "+j+" is in cluster: "+cnumber)); 122 if ( selectedCluster.contains(cnumber) ) { 123 // this only works if the index does not change 124 selected.add(traindata.get(j)); 125 // check for differences, just one attribute, we are pretty sure the index does not change 126 if ( traindata.get(j).value(3) != ctrain.get(j).value(3) ) { 127 Console.traceln(Level.WARNING, String.format("we have a difference between train an ctrain!")); 128 } 129 } 130 } 131 132 Console.traceln(Level.INFO, String.format("that leaves us with: "+selected.numInstances()+" traindata instances from "+traindata.numInstances())); 133 }catch( Exception e ) { 134 Console.traceln(Level.WARNING, String.format("ERROR")); 135 throw new RuntimeException("error in pointwise em", e); 136 } 137 138 return selected; 139 } 51 /** 52 * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3. 53 * select only traindata from the clusters we found in our testdata 54 * 55 * @returns the selected training data 56 */ 57 @Override 58 public Instances apply(Instances testdata, Instances traindata) { 59 // final Attribute classAttribute = testdata.classAttribute(); 60 61 final List<Integer> selectedCluster = 62 SetUniqueList.setUniqueList(new LinkedList<Integer>()); 63 64 // 1. copy train- and testdata 65 Instances train = new Instances(traindata); 66 Instances test = new Instances(testdata); 67 68 Instances selected = null; 69 70 try { 71 // remove class attribute from traindata 72 Remove filter = new Remove(); 73 filter.setAttributeIndices("" + (train.classIndex() + 1)); 74 filter.setInputFormat(train); 75 train = Filter.useFilter(train, filter); 76 77 Console.traceln(Level.INFO, String.format("starting clustering")); 78 79 // 3. cluster data 80 EM clusterer = new EM(); 81 clusterer.setOptions(params); 82 clusterer.buildClusterer(train); 83 int numClusters = clusterer.getNumClusters(); 84 if (numClusters == -1) { 85 Console.traceln(Level.INFO, String.format("we have unlimited clusters")); 86 } 87 else { 88 Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters")); 89 } 90 91 // 4. classify testdata, save cluster int 92 93 // remove class attribute from testdata? 94 Remove filter2 = new Remove(); 95 filter2.setAttributeIndices("" + (test.classIndex() + 1)); 96 filter2.setInputFormat(test); 97 test = Filter.useFilter(test, filter2); 98 99 int cnum; 100 for (int i = 0; i < test.numInstances(); i++) { 101 cnum = ((EM) clusterer).clusterInstance(test.get(i)); 102 103 // we dont want doubles (maybe use a hashset instead of list?) 104 if (!selectedCluster.contains(cnum)) { 105 selectedCluster.add(cnum); 106 // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); 107 } 108 } 109 110 Console.traceln(Level.INFO, 111 String.format("our testdata is in: " + selectedCluster.size() + 112 " different clusters")); 113 114 // 5. get cluster membership of our traindata 115 AddCluster cfilter = new AddCluster(); 116 cfilter.setClusterer(clusterer); 117 cfilter.setInputFormat(train); 118 Instances ctrain = Filter.useFilter(train, cfilter); 119 120 // 6. for all traindata get the cluster int, if it is in our list of testdata cluster 121 // int add the traindata 122 // of this cluster to our returned traindata 123 int cnumber; 124 selected = new Instances(traindata); 125 selected.delete(); 126 127 for (int j = 0; j < ctrain.numInstances(); j++) { 128 // get the cluster number from the attributes 129 cnumber = 130 Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1) 131 .replace("cluster", "")); 132 133 // Console.traceln(Level.INFO, 134 // String.format("instance "+j+" is in cluster: "+cnumber)); 135 if (selectedCluster.contains(cnumber)) { 136 // this only works if the index does not change 137 selected.add(traindata.get(j)); 138 // check for differences, just one attribute, we are pretty sure the index does 139 // not change 140 if (traindata.get(j).value(3) != ctrain.get(j).value(3)) { 141 Console.traceln(Level.WARNING, String 142 .format("we have a difference between train an ctrain!")); 143 } 144 } 145 } 146 147 Console.traceln(Level.INFO, 148 String.format("that leaves us with: " + selected.numInstances() + 149 " traindata instances from " + traindata.numInstances())); 150 } 151 catch (Exception e) { 152 Console.traceln(Level.WARNING, String.format("ERROR")); 153 throw new RuntimeException("error in pointwise em", e); 154 } 155 156 return selected; 157 } 140 158 141 159 }
Note: See TracChangeset
for help on using the changeset viewer.