- Timestamp:
- 09/24/15 10:59:05 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 24 38 * WekaLocalEMTraining 25 39 * 26 * Local Trainer with EM Clustering for data partitioning. 27 * Currently supports only EM Clustering. 28 * 29 * 1. Cluster training data 30 * 2. for each cluster train a classifier with training data from cluster 40 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering. 41 * 42 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster 31 43 * 3. match test data instance to a cluster, then classify with classifier from the cluster 32 44 * 33 * XML configuration: 34 * <!-- because of clustering --> 35 * <preprocessor name="Normalization" param=""/> 36 * 37 * <!-- cluster trainer --> 38 * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 45 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/> 46 * 47 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining" 48 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 39 49 */ 40 50 public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy { 41 51 42 private final TraindatasetCluster classifier = new TraindatasetCluster(); 43 44 @Override 45 public Classifier getClassifier() { 46 return classifier; 47 } 48 49 @Override 50 public void apply(Instances traindata) { 51 PrintStream errStr = System.err; 52 System.setErr(new PrintStream(new NullOutputStream())); 53 try { 54 classifier.buildClassifier(traindata); 55 } catch (Exception e) { 56 throw new RuntimeException(e); 57 } finally { 58 System.setErr(errStr); 59 } 60 } 61 62 63 public class TraindatasetCluster extends AbstractClassifier { 64 65 private static final long serialVersionUID = 1L; 66 67 private EM clusterer = null; 68 69 private HashMap<Integer, Classifier> cclassifier; 70 private HashMap<Integer, Instances> ctraindata; 71 72 73 /** 74 * Helper method that gives us a clean instance copy with 75 * the values of the instancelist of the first parameter. 76 * 77 * @param instancelist with attributes 78 * @param instance with only values 79 * @return copy of the instance 80 */ 81 private Instance createInstance(Instances instances, Instance instance) { 82 // attributes for feeding instance to classifier 83 Set<String> attributeNames = new HashSet<>(); 84 for( int j=0; j<instances.numAttributes(); j++ ) { 85 attributeNames.add(instances.attribute(j).name()); 86 } 87 88 double[] values = new double[instances.numAttributes()]; 89 int index = 0; 90 for( int j=0; j<instance.numAttributes(); j++ ) { 91 if( attributeNames.contains(instance.attribute(j).name())) { 92 values[index] = instance.value(j); 93 index++; 94 } 95 } 96 97 Instances tmp = new Instances(instances); 98 tmp.clear(); 99 Instance instCopy = new DenseInstance(instance.weight(), values); 100 instCopy.setDataset(tmp); 101 102 return instCopy; 103 } 104 105 @Override 106 public double classifyInstance(Instance instance) { 107 double ret = 0; 108 try { 109 // 1. copy the instance (keep the class attribute) 110 Instances traindata = ctraindata.get(0); 111 Instance classInstance = createInstance(traindata, instance); 112 113 // 2. remove class attribute before clustering 114 Remove filter = new Remove(); 115 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 116 filter.setInputFormat(traindata); 117 traindata = Filter.useFilter(traindata, filter); 118 119 // 3. copy the instance (without the class attribute) for clustering 120 Instance clusterInstance = createInstance(traindata, instance); 121 122 // 4. match instance without class attribute to a cluster number 123 int cnum = clusterer.clusterInstance(clusterInstance); 124 125 // 5. classify instance with class attribute to the classifier of that cluster number 126 ret = cclassifier.get(cnum).classifyInstance(classInstance); 127 128 }catch( Exception e ) { 129 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 130 throw new RuntimeException(e); 131 } 132 return ret; 133 } 134 135 @Override 136 public void buildClassifier(Instances traindata) throws Exception { 137 138 // 1. copy training data 139 Instances train = new Instances(traindata); 140 141 // 2. remove class attribute for clustering 142 Remove filter = new Remove(); 143 filter.setAttributeIndices("" + (train.classIndex() + 1)); 144 filter.setInputFormat(train); 145 train = Filter.useFilter(train, filter); 146 147 // new objects 148 cclassifier = new HashMap<Integer, Classifier>(); 149 ctraindata = new HashMap<Integer, Instances>(); 150 151 Instances ctrain; 152 int maxNumClusters = train.size(); 153 boolean sufficientInstancesInEachCluster; 154 do { // while(onlyTarget) 155 sufficientInstancesInEachCluster = true; 156 clusterer = new EM(); 157 clusterer.setMaximumNumberOfClusters(maxNumClusters); 158 clusterer.buildClusterer(train); 159 160 // 4. get cluster membership of our traindata 161 //AddCluster cfilter = new AddCluster(); 162 //cfilter.setClusterer(clusterer); 163 //cfilter.setInputFormat(train); 164 //Instances ctrain = Filter.useFilter(train, cfilter); 165 166 ctrain = new Instances(train); 167 ctraindata = new HashMap<>(); 168 169 // get traindata per cluster 170 for ( int j=0; j < ctrain.numInstances(); j++ ) { 171 // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 172 //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 173 174 int cnumber = clusterer.clusterInstance(ctrain.get(j)); 175 // add training data to list of instances for this cluster number 176 if ( !ctraindata.containsKey(cnumber) ) { 177 ctraindata.put(cnumber, new Instances(traindata)); 178 ctraindata.get(cnumber).delete(); 179 } 180 ctraindata.get(cnumber).add(traindata.get(j)); 181 } 182 183 for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) { 184 Instances instances = entry.getValue(); 185 int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 186 for( int count : counts ) { 187 sufficientInstancesInEachCluster &= count>0; 188 } 189 sufficientInstancesInEachCluster &= instances.numInstances()>=5; 190 } 191 maxNumClusters = clusterer.numberOfClusters()-1; 192 } while(!sufficientInstancesInEachCluster); 193 194 // train one classifier per cluster, we get the cluster number from the training data 195 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 196 while ( clusternumber.hasNext() ) { 197 int cnumber = clusternumber.next(); 198 cclassifier.put(cnumber,setupClassifier()); 199 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 200 201 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 202 } 203 } 204 } 52 private final TraindatasetCluster classifier = new TraindatasetCluster(); 53 54 @Override 55 public Classifier getClassifier() { 56 return classifier; 57 } 58 59 @Override 60 public void apply(Instances traindata) { 61 PrintStream errStr = System.err; 62 System.setErr(new PrintStream(new NullOutputStream())); 63 try { 64 classifier.buildClassifier(traindata); 65 } 66 catch (Exception e) { 67 throw new RuntimeException(e); 68 } 69 finally { 70 System.setErr(errStr); 71 } 72 } 73 74 public class TraindatasetCluster extends AbstractClassifier { 75 76 private static final long serialVersionUID = 1L; 77 78 private EM clusterer = null; 79 80 private HashMap<Integer, Classifier> cclassifier; 81 private HashMap<Integer, Instances> ctraindata; 82 83 /** 84 * Helper method that gives us a clean instance copy with the values of the instancelist of 85 * the first parameter. 86 * 87 * @param instancelist 88 * with attributes 89 * @param instance 90 * with only values 91 * @return copy of the instance 92 */ 93 private Instance createInstance(Instances instances, Instance instance) { 94 // attributes for feeding instance to classifier 95 Set<String> attributeNames = new HashSet<>(); 96 for (int j = 0; j < instances.numAttributes(); j++) { 97 attributeNames.add(instances.attribute(j).name()); 98 } 99 100 double[] values = new double[instances.numAttributes()]; 101 int index = 0; 102 for (int j = 0; j < instance.numAttributes(); j++) { 103 if (attributeNames.contains(instance.attribute(j).name())) { 104 values[index] = instance.value(j); 105 index++; 106 } 107 } 108 109 Instances tmp = new Instances(instances); 110 tmp.clear(); 111 Instance instCopy = new DenseInstance(instance.weight(), values); 112 instCopy.setDataset(tmp); 113 114 return instCopy; 115 } 116 117 @Override 118 public double classifyInstance(Instance instance) { 119 double ret = 0; 120 try { 121 // 1. copy the instance (keep the class attribute) 122 Instances traindata = ctraindata.get(0); 123 Instance classInstance = createInstance(traindata, instance); 124 125 // 2. remove class attribute before clustering 126 Remove filter = new Remove(); 127 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 128 filter.setInputFormat(traindata); 129 traindata = Filter.useFilter(traindata, filter); 130 131 // 3. copy the instance (without the class attribute) for clustering 132 Instance clusterInstance = createInstance(traindata, instance); 133 134 // 4. match instance without class attribute to a cluster number 135 int cnum = clusterer.clusterInstance(clusterInstance); 136 137 // 5. classify instance with class attribute to the classifier of that cluster 138 // number 139 ret = cclassifier.get(cnum).classifyInstance(classInstance); 140 141 } 142 catch (Exception e) { 143 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 144 throw new RuntimeException(e); 145 } 146 return ret; 147 } 148 149 @Override 150 public void buildClassifier(Instances traindata) throws Exception { 151 152 // 1. copy training data 153 Instances train = new Instances(traindata); 154 155 // 2. remove class attribute for clustering 156 Remove filter = new Remove(); 157 filter.setAttributeIndices("" + (train.classIndex() + 1)); 158 filter.setInputFormat(train); 159 train = Filter.useFilter(train, filter); 160 161 // new objects 162 cclassifier = new HashMap<Integer, Classifier>(); 163 ctraindata = new HashMap<Integer, Instances>(); 164 165 Instances ctrain; 166 int maxNumClusters = train.size(); 167 boolean sufficientInstancesInEachCluster; 168 do { // while(onlyTarget) 169 sufficientInstancesInEachCluster = true; 170 clusterer = new EM(); 171 clusterer.setMaximumNumberOfClusters(maxNumClusters); 172 clusterer.buildClusterer(train); 173 174 // 4. get cluster membership of our traindata 175 // AddCluster cfilter = new AddCluster(); 176 // cfilter.setClusterer(clusterer); 177 // cfilter.setInputFormat(train); 178 // Instances ctrain = Filter.useFilter(train, cfilter); 179 180 ctrain = new Instances(train); 181 ctraindata = new HashMap<>(); 182 183 // get traindata per cluster 184 for (int j = 0; j < ctrain.numInstances(); j++) { 185 // get the cluster number from the attributes, subract 1 because if we 186 // clusterInstance we get 0-n, and this is 1-n 187 // cnumber = 188 // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", 189 // "")) - 1; 190 191 int cnumber = clusterer.clusterInstance(ctrain.get(j)); 192 // add training data to list of instances for this cluster number 193 if (!ctraindata.containsKey(cnumber)) { 194 ctraindata.put(cnumber, new Instances(traindata)); 195 ctraindata.get(cnumber).delete(); 196 } 197 ctraindata.get(cnumber).add(traindata.get(j)); 198 } 199 200 for (Entry<Integer, Instances> entry : ctraindata.entrySet()) { 201 Instances instances = entry.getValue(); 202 int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 203 for (int count : counts) { 204 sufficientInstancesInEachCluster &= count > 0; 205 } 206 sufficientInstancesInEachCluster &= instances.numInstances() >= 5; 207 } 208 maxNumClusters = clusterer.numberOfClusters() - 1; 209 } 210 while (!sufficientInstancesInEachCluster); 211 212 // train one classifier per cluster, we get the cluster number from the training data 213 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 214 while (clusternumber.hasNext()) { 215 int cnumber = clusternumber.next(); 216 cclassifier.put(cnumber, setupClassifier()); 217 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 218 219 // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 220 } 221 } 222 } 205 223 }
Note: See TracChangeset
for help on using the changeset viewer.