Changeset 41 for trunk/CrossPare/src/de/ugoe/cs/cpdp/training
- Timestamp: 09/24/15 10:59:05
- Location: trunk/CrossPare/src/de/ugoe/cs/cpdp/training
- Files: 12 edited
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/FixClass.java
r31 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 14 28 * @author Steffen Herbold 15 29 */ 16 public class FixClass extends AbstractClassifier implements ITrainingStrategy, IWekaCompatibleTrainer { 30 public class FixClass extends AbstractClassifier implements ITrainingStrategy, 31 IWekaCompatibleTrainer 32 { 17 33 18 34 private static final long serialVersionUID = 1L; 19 35 20 36 private double fixedClassValue = 0.0d; 21 37 22 23 24 25 26 27 28 29 30 38 /** 39 * Returns default capabilities of the classifier. 40 * 41 * @return the capabilities of this classifier 42 */ 43 @Override 44 public Capabilities getCapabilities() { 45 Capabilities result = super.getCapabilities(); 46 result.disableAll(); 31 47 32 33 34 35 36 37 38 48 // attributes 49 result.enable(Capability.NOMINAL_ATTRIBUTES); 50 result.enable(Capability.NUMERIC_ATTRIBUTES); 51 result.enable(Capability.DATE_ATTRIBUTES); 52 result.enable(Capability.STRING_ATTRIBUTES); 53 result.enable(Capability.RELATIONAL_ATTRIBUTES); 54 result.enable(Capability.MISSING_VALUES); 39 55 40 41 42 43 56 // class 57 result.enable(Capability.NOMINAL_CLASS); 58 result.enable(Capability.NUMERIC_CLASS); 59 result.enable(Capability.MISSING_CLASS_VALUES); 44 60 45 46 61 // instances 62 result.setMinimumNumberInstances(0); 47 63 48 49 64 return result; 65 } 50 66 51 52 53 54 67 @Override 68 public void setOptions(String[] options) throws Exception { 69 fixedClassValue = Double.parseDouble(Utils.getOption('C', options)); 70 } 55 71 56 57 58 59 72 @Override 73 public double classifyInstance(Instance instance) { 74 return fixedClassValue; 75 } 60 76 61 62 63 64 77 @Override 78 public void buildClassifier(Instances traindata) throws Exception { 79 // do nothing 80 } 65 81 66 @Override 67 public void setParameter(String parameters) { 68 try { 69 this.setOptions(parameters.split(" ")); 70 } catch (Exception e) { 71 e.printStackTrace(); 72 } 73 } 82 @Override 83 public void setParameter(String parameters) { 84 try { 85 this.setOptions(parameters.split(" ")); 86 } 87 catch (Exception e) { 88 e.printStackTrace(); 89 } 90 } 74 91 75 76 77 78 92 @Override 93 public void apply(Instances traindata) { 94 // do nothing! 95 } 79 96 80 81 82 83 97 @Override 98 public String getName() { 99 return "FixClass"; 100 } 84 101 85 86 87 88 102 @Override 103 public Classifier getClassifier() { 104 return this; 105 } 89 106 90 107 } -
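For orientation only (not part of the changeset): a minimal usage sketch of FixClass as it stands after r41. The demo class and the attribute names ("loc", "complexity", "bug") are invented; the point is that the trainer merely parses -C via Utils.getOption and returns that value for every instance, ignoring the training data completely.

import java.util.ArrayList;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import de.ugoe.cs.cpdp.training.FixClass;

public class FixClassDemo {
    public static void main(String[] args) throws Exception {
        // hypothetical two-metric data set with a nominal defect label
        ArrayList<Attribute> atts = new ArrayList<>();
        atts.add(new Attribute("loc"));
        atts.add(new Attribute("complexity"));
        ArrayList<String> labels = new ArrayList<>();
        labels.add("non-defective");
        labels.add("defective");
        atts.add(new Attribute("bug", labels));
        Instances data = new Instances("demo", atts, 0);
        data.setClassIndex(data.numAttributes() - 1);

        FixClass trainer = new FixClass();
        trainer.setParameter("-C 1.0");  // fixed class value, parsed with Utils.getOption('C', ...)
        trainer.apply(data);             // intentionally a no-op: nothing is learned

        Instance inst = new DenseInstance(1.0, new double[] { 120.0, 7.0, 0.0 });
        inst.setDataset(data);
        System.out.println(trainer.classifyInstance(inst));  // always prints 1.0
    }
}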
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ISetWiseTrainingStrategy.java
r2 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 7 21 // Bagging Strategy: separate models for each training data set 8 22 public interface ISetWiseTrainingStrategy extends ITrainer { 9 10 11 12 23 24 void apply(SetUniqueList<Instances> traindataSet); 25 26 String getName(); 13 27 } -
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ITrainer.java
r2 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 -
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ITrainingStrategy.java
r6 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 4 18 5 19 public interface ITrainingStrategy extends ITrainer { 6 7 8 9 20 21 void apply(Instances traindata); 22 23 String getName(); 10 24 } -
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/IWekaCompatibleTrainer.java
r24 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 4 18 5 19 public interface IWekaCompatibleTrainer extends ITrainer { 6 7 8 9 20 21 Classifier getClassifier(); 22 23 String getName(); 10 24 } -
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/QuadTree.java
r23 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 12 26 * QuadTree implementation 13 27 * 14 * QuadTree gets a list of instances and then recursively split them into 4 childs 15 * For this it usesthe median of the 2 values x,y28 * QuadTree gets a list of instances and then recursively split them into 4 childs For this it uses 29 * the median of the 2 values x,y 16 30 */ 17 31 public class QuadTree { 18 19 /* 1 parent or null */ 20 private QuadTree parent = null; 21 22 /* 4 childs, 1 per quadrant */ 23 private QuadTree child_nw; 24 private QuadTree child_ne; 25 private QuadTree child_se; 26 private QuadTree child_sw; 27 28 /* list (only helps with generation of list of childs!) */ 29 private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 30 31 /* level only used for debugging */ 32 public int level = 0; 33 34 /* size of the quadrant */ 35 private double[] x; 36 private double[] y; 37 38 public static boolean verbose = false; 39 public static int size = 0; 40 public static double alpha = 0; 41 42 /* cluster payloads */ 43 public static ArrayList<ArrayList<QuadTreePayload<Instance>>> ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 44 45 /* cluster sizes (index is cluster number, arraylist is list of boxes (x0,y0,x1,y1) */ 46 public static HashMap<Integer, ArrayList<Double[][]>> csize = new HashMap<Integer, ArrayList<Double[][]>>(); 47 48 /* payload of this instance */ 49 private ArrayList<QuadTreePayload<Instance>> payload; 50 51 52 public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { 53 this.parent = parent; 54 this.payload = payload; 55 } 56 57 58 public String toString() { 59 String n = ""; 60 if(this.parent == null) { 61 n += "rootnode "; 62 } 63 String level = new String(new char[this.level]).replace("\0", "-"); 64 n += level + " instances: " + this.getNumbers(); 65 return n; 66 } 67 68 /** 69 * Returns the payload, used for clustering 70 * in the clustering list we only have children with paylod 71 * 72 * @return payload 73 */ 74 public ArrayList<QuadTreePayload<Instance>> getPayload() { 75 return this.payload; 76 } 77 78 /** 79 * Calculate the density of this quadrant 80 * 81 * density = number of instances / global size (all instances) 82 * 83 * @return density 84 */ 85 public double getDensity() { 86 double dens = 0; 87 dens = (double)this.getNumbers() / QuadTree.size; 88 return dens; 89 } 90 91 public void setSize(double[] x, double[] y){ 92 this.x = x; 93 this.y = y; 94 } 95 96 public double[][] getSize() { 97 return new double[][] {this.x, this.y}; 98 } 99 100 public Double[][] getSizeDouble() { 101 Double[] tmpX = new Double[2]; 102 Double[] tmpY = new Double[2]; 103 104 tmpX[0] = this.x[0]; 105 tmpX[1] = this.x[1]; 106 107 tmpY[0] = this.y[0]; 108 tmpY[1] = this.y[1]; 109 110 return new Double[][] {tmpX, tmpY}; 111 } 112 113 /** 114 * TODO: DRY, median ist immer dasselbe 115 * 116 * 
@return median for x 117 */ 118 private double getMedianForX() { 119 double med_x =0 ; 120 121 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 122 @Override 123 public int compare(QuadTreePayload<Instance> x1, QuadTreePayload<Instance> x2) { 124 return Double.compare(x1.x, x2.x); 125 } 126 }); 127 128 if(this.payload.size() % 2 == 0) { 129 int mid = this.payload.size() / 2; 130 med_x = (this.payload.get(mid).x + this.payload.get(mid+1).x) / 2; 131 }else { 132 int mid = this.payload.size() / 2; 133 med_x = this.payload.get(mid).x; 134 } 135 136 if(QuadTree.verbose) { 137 System.out.println("sorted:"); 138 for(int i = 0; i < this.payload.size(); i++) { 139 System.out.print(""+this.payload.get(i).x+","); 140 } 141 System.out.println("median x: " + med_x); 142 } 143 return med_x; 144 } 145 146 private double getMedianForY() { 147 double med_y =0 ; 148 149 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 150 @Override 151 public int compare(QuadTreePayload<Instance> y1, QuadTreePayload<Instance> y2) { 152 return Double.compare(y1.y, y2.y); 153 } 154 }); 155 156 if(this.payload.size() % 2 == 0) { 157 int mid = this.payload.size() / 2; 158 med_y = (this.payload.get(mid).y + this.payload.get(mid+1).y) / 2; 159 }else { 160 int mid = this.payload.size() / 2; 161 med_y = this.payload.get(mid).y; 162 } 163 164 if(QuadTree.verbose) { 165 System.out.println("sorted:"); 166 for(int i = 0; i < this.payload.size(); i++) { 167 System.out.print(""+this.payload.get(i).y+","); 168 } 169 System.out.println("median y: " + med_y); 170 } 171 return med_y; 172 } 173 174 /** 175 * Reurns the number of instances in the payload 176 * 177 * @return int number of instances 178 */ 179 public int getNumbers() { 180 int number = 0; 181 if(this.payload != null) { 182 number = this.payload.size(); 183 } 184 return number; 185 } 186 187 /** 188 * Calculate median values of payload for x, y and split into 4 sectors 189 * 190 * @return Array of QuadTree nodes (4 childs) 191 * @throws Exception if we would run into an recursive loop 192 */ 193 public QuadTree[] split() throws Exception { 194 195 double medx = this.getMedianForX(); 196 double medy = this.getMedianForY(); 197 198 // Payload lists for each child 199 ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 200 ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); 201 ArrayList<QuadTreePayload<Instance>> ne = new ArrayList<QuadTreePayload<Instance>>(); 202 ArrayList<QuadTreePayload<Instance>> se = new ArrayList<QuadTreePayload<Instance>>(); 203 204 // sort the payloads to new payloads 205 // here we have the problem that payloads with the same values are sorted 206 // into the same slots and it could happen that medx and medy = size_x[1] and size_y[1] 207 // in that case we would have an endless loop 208 for(int i=0; i < this.payload.size(); i++) { 209 210 QuadTreePayload<Instance> item = this.payload.get(i); 211 212 // north west 213 if(item.x <= medx && item.y >= medy) { 214 nw.add(item); 215 } 216 217 // south west 218 else if(item.x <= medx && item.y <= medy) { 219 sw.add(item); 220 } 221 222 // north east 223 else if(item.x >= medx && item.y >= medy) { 224 ne.add(item); 225 } 226 227 // south east 228 else if(item.x >= medx && item.y <= medy) { 229 se.add(item); 230 } 231 } 232 233 // if we assign one child a payload equal to our own (see problem above) 234 // we throw an exceptions which stops the recursion on this node 235 
if(nw.equals(this.payload)) { 236 throw new Exception("payload equal"); 237 } 238 if(sw.equals(this.payload)) { 239 throw new Exception("payload equal"); 240 } 241 if(ne.equals(this.payload)) { 242 throw new Exception("payload equal"); 243 } 244 if(se.equals(this.payload)) { 245 throw new Exception("payload equal"); 246 } 247 248 this.child_nw = new QuadTree(this, nw); 249 this.child_nw.setSize(new double[] {this.x[0], medx}, new double[] {medy, this.y[1]}); 250 this.child_nw.level = this.level + 1; 251 252 this.child_sw = new QuadTree(this, sw); 253 this.child_sw.setSize(new double[] {this.x[0], medx}, new double[] {this.y[0], medy}); 254 this.child_sw.level = this.level + 1; 255 256 this.child_ne = new QuadTree(this, ne); 257 this.child_ne.setSize(new double[] {medx, this.x[1]}, new double[] {medy, this.y[1]}); 258 this.child_ne.level = this.level + 1; 259 260 this.child_se = new QuadTree(this, se); 261 this.child_se.setSize(new double[] {medx, this.x[1]}, new double[] {this.y[0], medy}); 262 this.child_se.level = this.level + 1; 263 264 this.payload = null; 265 return new QuadTree[] {this.child_nw, this.child_ne, this.child_se, this.child_sw}; 266 } 267 268 /** 269 * TODO: static method 270 * 271 * @param q 272 */ 273 public void recursiveSplit(QuadTree q) { 274 if(QuadTree.verbose) { 275 System.out.println("splitting: "+ q); 276 } 277 if(q.getNumbers() < QuadTree.alpha) { 278 return; 279 }else{ 280 // exception is thrown if we would run into an endless loop (see comments in split()) 281 try { 282 QuadTree[] childs = q.split(); 283 this.recursiveSplit(childs[0]); 284 this.recursiveSplit(childs[1]); 285 this.recursiveSplit(childs[2]); 286 this.recursiveSplit(childs[3]); 287 }catch(Exception e) { 288 return; 289 } 290 } 291 } 292 293 /** 294 * returns an list of childs sorted by density 295 * 296 * @param q QuadTree 297 * @return list of QuadTrees 298 */ 299 private void generateList(QuadTree q) { 300 301 // we only have all childs or none at all 302 if(q.child_ne == null) { 303 this.l.add(q); 304 } 305 306 if(q.child_ne != null) { 307 this.generateList(q.child_ne); 308 } 309 if(q.child_nw != null) { 310 this.generateList(q.child_nw); 311 } 312 if(q.child_se != null) { 313 this.generateList(q.child_se); 314 } 315 if(q.child_sw != null) { 316 this.generateList(q.child_sw); 317 } 318 } 319 320 /** 321 * Checks if passed QuadTree is neighboring to us 322 * 323 * @param q QuadTree 324 * @return true if passed QuadTree is a neighbor 325 */ 326 public boolean isNeighbour(QuadTree q) { 327 boolean is_neighbour = false; 328 329 double[][] our_size = this.getSize(); 330 double[][] new_size = q.getSize(); 331 332 // X is i=0, Y is i=1 333 for(int i =0; i < 2; i++) { 334 // we are smaller than q 335 // -------------- q 336 // ------- we 337 if(our_size[i][0] >= new_size[i][0] && our_size[i][1] <= new_size[i][1]) { 338 is_neighbour = true; 339 } 340 // we overlap with q at some point 341 //a) ---------------q 342 // ----------- we 343 //b) --------- q 344 // --------- we 345 if((our_size[i][0] >= new_size[i][0] && our_size[i][0] <= new_size[i][1]) || 346 (our_size[i][1] >= new_size[i][0] && our_size[i][1] <= new_size[i][1])) { 347 is_neighbour = true; 348 } 349 // we are larger than q 350 // ---- q 351 // ---------- we 352 if(our_size[i][1] >= new_size[i][1] && our_size[i][0] <= new_size[i][0]) { 353 is_neighbour = true; 354 } 355 } 356 357 if(is_neighbour && QuadTree.verbose) { 358 System.out.println(this + " neighbour of: " + q); 359 } 360 361 return is_neighbour; 362 } 363 364 /** 365 * Perform 
pruning and clustering of the quadtree 366 * 367 * Pruning according to: 368 * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman, 369 * Forrest Shull, Burak Turhan, Thomas Zimmermann, 370 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," 371 * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 372 * 373 * 1) get list of leaf quadrants 374 * 2) sort by their density 375 * 3) set stop_rule to 0.5 * highest Density in the list 376 * 4) merge all nodes with a density > stop_rule to the new cluster and remove all from list 377 * 5) repeat 378 * 379 * @param q List of QuadTree (children only) 380 */ 381 public void gridClustering(ArrayList<QuadTree> list) { 382 383 if(list.size() == 0) { 384 return; 385 } 386 387 double stop_rule; 388 QuadTree biggest; 389 QuadTree current; 390 391 // current clusterlist 392 ArrayList<QuadTreePayload<Instance>> current_cluster; 393 394 // remove list (for removal of items after scanning of the list) 395 ArrayList<Integer> remove = new ArrayList<Integer>(); 396 397 // 1. find biggest, and add it 398 biggest = list.get(list.size()-1); 399 stop_rule = biggest.getDensity() * 0.5; 400 401 current_cluster = new ArrayList<QuadTreePayload<Instance>>(); 402 current_cluster.addAll(biggest.getPayload()); 403 404 // remove the biggest because we are starting with it 405 remove.add(list.size()-1); 406 407 ArrayList<Double[][]> tmpSize = new ArrayList<Double[][]>(); 408 tmpSize.add(biggest.getSizeDouble()); 409 410 // check the items for their density 411 for(int i=list.size()-1; i >= 0; i--) { 412 current = list.get(i); 413 414 // 2. find neighbors with correct density 415 // if density > stop_rule and is_neighbour add to cluster and remove from list 416 if(current.getDensity() > stop_rule && !current.equals(biggest) && current.isNeighbour(biggest)) { 417 current_cluster.addAll(current.getPayload()); 418 419 // add it to remove list (we cannot remove it inside the loop because it would move the index) 420 remove.add(i); 421 422 // get the size 423 tmpSize.add(current.getSizeDouble()); 424 } 425 } 426 427 // 3. remove our removal candidates from the list 428 for(Integer item: remove) { 429 list.remove((int)item); 430 } 431 432 // 4. add to cluster 433 QuadTree.ccluster.add(current_cluster); 434 435 // 5. 
add sizes of our current (biggest) this adds a number of sizes (all QuadTree Instances belonging to this cluster) 436 // we need that to classify test instances to a cluster later 437 Integer cnumber = new Integer(QuadTree.ccluster.size()-1); 438 if(QuadTree.csize.containsKey(cnumber) == false) { 439 QuadTree.csize.put(cnumber, tmpSize); 440 } 441 442 // repeat 443 this.gridClustering(list); 444 } 445 446 public void printInfo() { 447 System.out.println("we have " + ccluster.size() + " clusters"); 448 449 for(int i=0; i < ccluster.size(); i++) { 450 System.out.println("cluster: "+i+ " size: "+ ccluster.get(i).size()); 451 } 452 } 453 454 /** 455 * Helper Method to get a sorted list (by density) for all 456 * children 457 * 458 * @param q QuadTree 459 * @return Sorted ArrayList of quadtrees 460 */ 461 public ArrayList<QuadTree> getList(QuadTree q) { 462 this.generateList(q); 463 464 Collections.sort(this.l, new Comparator<QuadTree>() { 465 @Override 466 public int compare(QuadTree x1, QuadTree x2) { 467 return Double.compare(x1.getDensity(), x2.getDensity()); 468 } 469 }); 470 471 return this.l; 472 } 32 33 /* 1 parent or null */ 34 private QuadTree parent = null; 35 36 /* 4 childs, 1 per quadrant */ 37 private QuadTree child_nw; 38 private QuadTree child_ne; 39 private QuadTree child_se; 40 private QuadTree child_sw; 41 42 /* list (only helps with generation of list of childs!) */ 43 private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 44 45 /* level only used for debugging */ 46 public int level = 0; 47 48 /* size of the quadrant */ 49 private double[] x; 50 private double[] y; 51 52 public static boolean verbose = false; 53 public static int size = 0; 54 public static double alpha = 0; 55 56 /* cluster payloads */ 57 public static ArrayList<ArrayList<QuadTreePayload<Instance>>> ccluster = 58 new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 59 60 /* cluster sizes (index is cluster number, arraylist is list of boxes (x0,y0,x1,y1) */ 61 public static HashMap<Integer, ArrayList<Double[][]>> csize = 62 new HashMap<Integer, ArrayList<Double[][]>>(); 63 64 /* payload of this instance */ 65 private ArrayList<QuadTreePayload<Instance>> payload; 66 67 public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { 68 this.parent = parent; 69 this.payload = payload; 70 } 71 72 public String toString() { 73 String n = ""; 74 if (this.parent == null) { 75 n += "rootnode "; 76 } 77 String level = new String(new char[this.level]).replace("\0", "-"); 78 n += level + " instances: " + this.getNumbers(); 79 return n; 80 } 81 82 /** 83 * Returns the payload, used for clustering in the clustering list we only have children with 84 * paylod 85 * 86 * @return payload 87 */ 88 public ArrayList<QuadTreePayload<Instance>> getPayload() { 89 return this.payload; 90 } 91 92 /** 93 * Calculate the density of this quadrant 94 * 95 * density = number of instances / global size (all instances) 96 * 97 * @return density 98 */ 99 public double getDensity() { 100 double dens = 0; 101 dens = (double) this.getNumbers() / QuadTree.size; 102 return dens; 103 } 104 105 public void setSize(double[] x, double[] y) { 106 this.x = x; 107 this.y = y; 108 } 109 110 public double[][] getSize() { 111 return new double[][] 112 { this.x, this.y }; 113 } 114 115 public Double[][] getSizeDouble() { 116 Double[] tmpX = new Double[2]; 117 Double[] tmpY = new Double[2]; 118 119 tmpX[0] = this.x[0]; 120 tmpX[1] = this.x[1]; 121 122 tmpY[0] = this.y[0]; 123 tmpY[1] = this.y[1]; 124 125 return new Double[][] 126 { 
tmpX, tmpY }; 127 } 128 129 /** 130 * TODO: DRY, median ist immer dasselbe 131 * 132 * @return median for x 133 */ 134 private double getMedianForX() { 135 double med_x = 0; 136 137 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 138 @Override 139 public int compare(QuadTreePayload<Instance> x1, QuadTreePayload<Instance> x2) { 140 return Double.compare(x1.x, x2.x); 141 } 142 }); 143 144 if (this.payload.size() % 2 == 0) { 145 int mid = this.payload.size() / 2; 146 med_x = (this.payload.get(mid).x + this.payload.get(mid + 1).x) / 2; 147 } 148 else { 149 int mid = this.payload.size() / 2; 150 med_x = this.payload.get(mid).x; 151 } 152 153 if (QuadTree.verbose) { 154 System.out.println("sorted:"); 155 for (int i = 0; i < this.payload.size(); i++) { 156 System.out.print("" + this.payload.get(i).x + ","); 157 } 158 System.out.println("median x: " + med_x); 159 } 160 return med_x; 161 } 162 163 private double getMedianForY() { 164 double med_y = 0; 165 166 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 167 @Override 168 public int compare(QuadTreePayload<Instance> y1, QuadTreePayload<Instance> y2) { 169 return Double.compare(y1.y, y2.y); 170 } 171 }); 172 173 if (this.payload.size() % 2 == 0) { 174 int mid = this.payload.size() / 2; 175 med_y = (this.payload.get(mid).y + this.payload.get(mid + 1).y) / 2; 176 } 177 else { 178 int mid = this.payload.size() / 2; 179 med_y = this.payload.get(mid).y; 180 } 181 182 if (QuadTree.verbose) { 183 System.out.println("sorted:"); 184 for (int i = 0; i < this.payload.size(); i++) { 185 System.out.print("" + this.payload.get(i).y + ","); 186 } 187 System.out.println("median y: " + med_y); 188 } 189 return med_y; 190 } 191 192 /** 193 * Reurns the number of instances in the payload 194 * 195 * @return int number of instances 196 */ 197 public int getNumbers() { 198 int number = 0; 199 if (this.payload != null) { 200 number = this.payload.size(); 201 } 202 return number; 203 } 204 205 /** 206 * Calculate median values of payload for x, y and split into 4 sectors 207 * 208 * @return Array of QuadTree nodes (4 childs) 209 * @throws Exception 210 * if we would run into an recursive loop 211 */ 212 public QuadTree[] split() throws Exception { 213 214 double medx = this.getMedianForX(); 215 double medy = this.getMedianForY(); 216 217 // Payload lists for each child 218 ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 219 ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); 220 ArrayList<QuadTreePayload<Instance>> ne = new ArrayList<QuadTreePayload<Instance>>(); 221 ArrayList<QuadTreePayload<Instance>> se = new ArrayList<QuadTreePayload<Instance>>(); 222 223 // sort the payloads to new payloads 224 // here we have the problem that payloads with the same values are sorted 225 // into the same slots and it could happen that medx and medy = size_x[1] and size_y[1] 226 // in that case we would have an endless loop 227 for (int i = 0; i < this.payload.size(); i++) { 228 229 QuadTreePayload<Instance> item = this.payload.get(i); 230 231 // north west 232 if (item.x <= medx && item.y >= medy) { 233 nw.add(item); 234 } 235 236 // south west 237 else if (item.x <= medx && item.y <= medy) { 238 sw.add(item); 239 } 240 241 // north east 242 else if (item.x >= medx && item.y >= medy) { 243 ne.add(item); 244 } 245 246 // south east 247 else if (item.x >= medx && item.y <= medy) { 248 se.add(item); 249 } 250 } 251 252 // if we assign one child a payload 
equal to our own (see problem above) 253 // we throw an exceptions which stops the recursion on this node 254 if (nw.equals(this.payload)) { 255 throw new Exception("payload equal"); 256 } 257 if (sw.equals(this.payload)) { 258 throw new Exception("payload equal"); 259 } 260 if (ne.equals(this.payload)) { 261 throw new Exception("payload equal"); 262 } 263 if (se.equals(this.payload)) { 264 throw new Exception("payload equal"); 265 } 266 267 this.child_nw = new QuadTree(this, nw); 268 this.child_nw.setSize(new double[] 269 { this.x[0], medx }, new double[] 270 { medy, this.y[1] }); 271 this.child_nw.level = this.level + 1; 272 273 this.child_sw = new QuadTree(this, sw); 274 this.child_sw.setSize(new double[] 275 { this.x[0], medx }, new double[] 276 { this.y[0], medy }); 277 this.child_sw.level = this.level + 1; 278 279 this.child_ne = new QuadTree(this, ne); 280 this.child_ne.setSize(new double[] 281 { medx, this.x[1] }, new double[] 282 { medy, this.y[1] }); 283 this.child_ne.level = this.level + 1; 284 285 this.child_se = new QuadTree(this, se); 286 this.child_se.setSize(new double[] 287 { medx, this.x[1] }, new double[] 288 { this.y[0], medy }); 289 this.child_se.level = this.level + 1; 290 291 this.payload = null; 292 return new QuadTree[] 293 { this.child_nw, this.child_ne, this.child_se, this.child_sw }; 294 } 295 296 /** 297 * TODO: static method 298 * 299 * @param q 300 */ 301 public void recursiveSplit(QuadTree q) { 302 if (QuadTree.verbose) { 303 System.out.println("splitting: " + q); 304 } 305 if (q.getNumbers() < QuadTree.alpha) { 306 return; 307 } 308 else { 309 // exception is thrown if we would run into an endless loop (see comments in split()) 310 try { 311 QuadTree[] childs = q.split(); 312 this.recursiveSplit(childs[0]); 313 this.recursiveSplit(childs[1]); 314 this.recursiveSplit(childs[2]); 315 this.recursiveSplit(childs[3]); 316 } 317 catch (Exception e) { 318 return; 319 } 320 } 321 } 322 323 /** 324 * returns an list of childs sorted by density 325 * 326 * @param q 327 * QuadTree 328 * @return list of QuadTrees 329 */ 330 private void generateList(QuadTree q) { 331 332 // we only have all childs or none at all 333 if (q.child_ne == null) { 334 this.l.add(q); 335 } 336 337 if (q.child_ne != null) { 338 this.generateList(q.child_ne); 339 } 340 if (q.child_nw != null) { 341 this.generateList(q.child_nw); 342 } 343 if (q.child_se != null) { 344 this.generateList(q.child_se); 345 } 346 if (q.child_sw != null) { 347 this.generateList(q.child_sw); 348 } 349 } 350 351 /** 352 * Checks if passed QuadTree is neighboring to us 353 * 354 * @param q 355 * QuadTree 356 * @return true if passed QuadTree is a neighbor 357 */ 358 public boolean isNeighbour(QuadTree q) { 359 boolean is_neighbour = false; 360 361 double[][] our_size = this.getSize(); 362 double[][] new_size = q.getSize(); 363 364 // X is i=0, Y is i=1 365 for (int i = 0; i < 2; i++) { 366 // we are smaller than q 367 // -------------- q 368 // ------- we 369 if (our_size[i][0] >= new_size[i][0] && our_size[i][1] <= new_size[i][1]) { 370 is_neighbour = true; 371 } 372 // we overlap with q at some point 373 // a) ---------------q 374 // ----------- we 375 // b) --------- q 376 // --------- we 377 if ((our_size[i][0] >= new_size[i][0] && our_size[i][0] <= new_size[i][1]) || 378 (our_size[i][1] >= new_size[i][0] && our_size[i][1] <= new_size[i][1])) 379 { 380 is_neighbour = true; 381 } 382 // we are larger than q 383 // ---- q 384 // ---------- we 385 if (our_size[i][1] >= new_size[i][1] && our_size[i][0] <= 
new_size[i][0]) { 386 is_neighbour = true; 387 } 388 } 389 390 if (is_neighbour && QuadTree.verbose) { 391 System.out.println(this + " neighbour of: " + q); 392 } 393 394 return is_neighbour; 395 } 396 397 /** 398 * Perform pruning and clustering of the quadtree 399 * 400 * Pruning according to: Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman, 401 * Forrest Shull, Burak Turhan, Thomas Zimmermann, 402 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions 403 * on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 404 * 405 * 1) get list of leaf quadrants 2) sort by their density 3) set stop_rule to 0.5 * highest 406 * Density in the list 4) merge all nodes with a density > stop_rule to the new cluster and 407 * remove all from list 5) repeat 408 * 409 * @param q 410 * List of QuadTree (children only) 411 */ 412 public void gridClustering(ArrayList<QuadTree> list) { 413 414 if (list.size() == 0) { 415 return; 416 } 417 418 double stop_rule; 419 QuadTree biggest; 420 QuadTree current; 421 422 // current clusterlist 423 ArrayList<QuadTreePayload<Instance>> current_cluster; 424 425 // remove list (for removal of items after scanning of the list) 426 ArrayList<Integer> remove = new ArrayList<Integer>(); 427 428 // 1. find biggest, and add it 429 biggest = list.get(list.size() - 1); 430 stop_rule = biggest.getDensity() * 0.5; 431 432 current_cluster = new ArrayList<QuadTreePayload<Instance>>(); 433 current_cluster.addAll(biggest.getPayload()); 434 435 // remove the biggest because we are starting with it 436 remove.add(list.size() - 1); 437 438 ArrayList<Double[][]> tmpSize = new ArrayList<Double[][]>(); 439 tmpSize.add(biggest.getSizeDouble()); 440 441 // check the items for their density 442 for (int i = list.size() - 1; i >= 0; i--) { 443 current = list.get(i); 444 445 // 2. find neighbors with correct density 446 // if density > stop_rule and is_neighbour add to cluster and remove from list 447 if (current.getDensity() > stop_rule && !current.equals(biggest) && 448 current.isNeighbour(biggest)) 449 { 450 current_cluster.addAll(current.getPayload()); 451 452 // add it to remove list (we cannot remove it inside the loop because it would move 453 // the index) 454 remove.add(i); 455 456 // get the size 457 tmpSize.add(current.getSizeDouble()); 458 } 459 } 460 461 // 3. remove our removal candidates from the list 462 for (Integer item : remove) { 463 list.remove((int) item); 464 } 465 466 // 4. add to cluster 467 QuadTree.ccluster.add(current_cluster); 468 469 // 5. 
add sizes of our current (biggest) this adds a number of sizes (all QuadTree Instances 470 // belonging to this cluster) 471 // we need that to classify test instances to a cluster later 472 Integer cnumber = new Integer(QuadTree.ccluster.size() - 1); 473 if (QuadTree.csize.containsKey(cnumber) == false) { 474 QuadTree.csize.put(cnumber, tmpSize); 475 } 476 477 // repeat 478 this.gridClustering(list); 479 } 480 481 public void printInfo() { 482 System.out.println("we have " + ccluster.size() + " clusters"); 483 484 for (int i = 0; i < ccluster.size(); i++) { 485 System.out.println("cluster: " + i + " size: " + ccluster.get(i).size()); 486 } 487 } 488 489 /** 490 * Helper Method to get a sorted list (by density) for all children 491 * 492 * @param q 493 * QuadTree 494 * @return Sorted ArrayList of quadtrees 495 */ 496 public ArrayList<QuadTree> getList(QuadTree q) { 497 this.generateList(q); 498 499 Collections.sort(this.l, new Comparator<QuadTree>() { 500 @Override 501 public int compare(QuadTree x1, QuadTree x2) { 502 return Double.compare(x1.getDensity(), x2.getDensity()); 503 } 504 }); 505 506 return this.l; 507 } 473 508 } -
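The gridClustering() Javadoc above describes a five-step pruning rule taken from Menzies et al. The following stripped-down sketch (not from the repository) shows only that stop rule: leaves are sorted by density, the densest remaining leaf seeds a cluster, and every leaf whose density exceeds half of that maximum is merged into it. The Leaf record and the density values are invented, and the isNeighbour() test is deliberately omitted to keep the example short.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class StopRuleDemo {
    /** hypothetical leaf record: only the density matters for this demo */
    static class Leaf {
        final double density;
        Leaf(double density) { this.density = density; }
    }

    /** merges leaves by the 0.5 * max-density stop rule described in gridClustering() */
    static List<List<Leaf>> cluster(List<Leaf> leaves) {
        List<Leaf> sorted = new ArrayList<>(leaves);
        sorted.sort(Comparator.comparingDouble((Leaf l) -> l.density));  // ascending by density
        List<List<Leaf>> clusters = new ArrayList<>();
        while (!sorted.isEmpty()) {
            Leaf biggest = sorted.remove(sorted.size() - 1);             // densest remaining leaf
            double stopRule = biggest.density * 0.5;
            List<Leaf> current = new ArrayList<>();
            current.add(biggest);
            // take every remaining leaf above the stop rule (neighbourhood check left out)
            sorted.removeIf(l -> l.density > stopRule && current.add(l));
            clusters.add(current);
        }
        return clusters;
    }

    public static void main(String[] args) {
        List<Leaf> leaves = new ArrayList<>();
        for (double d : new double[] { 0.40, 0.35, 0.15, 0.06, 0.04 }) {
            leaves.add(new Leaf(d));
        }
        cluster(leaves).forEach(c -> System.out.println("cluster with " + c.size() + " leaves"));
    }
}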
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/RandomClass.java
r38 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 11 25 * Assigns a random class label to the instance it is evaluated on. 12 26 * 13 * The range of class labels are hardcoded in fixedClassValues. 14 * This can later be extended to take values from the XML configuration.27 * The range of class labels are hardcoded in fixedClassValues. This can later be extended to take 28 * values from the XML configuration. 15 29 */ 16 public class RandomClass extends AbstractClassifier implements ITrainingStrategy, IWekaCompatibleTrainer { 30 public class RandomClass extends AbstractClassifier implements ITrainingStrategy, 31 IWekaCompatibleTrainer 32 { 17 33 18 34 private static final long serialVersionUID = 1L; 19 35 20 private double[] fixedClassValues = {0.0d, 1.0d}; 21 22 @Override 23 public void setParameter(String parameters) { 24 // do nothing, maybe take percentages for distribution later 25 } 36 private double[] fixedClassValues = 37 { 0.0d, 1.0d }; 26 38 27 28 public void buildClassifier(Instances arg0) throws Exception{29 // do nothing 30 39 @Override 40 public void setParameter(String parameters) { 41 // do nothing, maybe take percentages for distribution later 42 } 31 43 32 33 public Classifier getClassifier(){34 return this; 35 44 @Override 45 public void buildClassifier(Instances arg0) throws Exception { 46 // do nothing 47 } 36 48 37 38 public void apply(Instances traindata) {39 // nothing to do 40 49 @Override 50 public Classifier getClassifier() { 51 return this; 52 } 41 53 42 @Override 43 public String getName() { 44 return "RandomClass"; 45 } 46 47 @Override 48 public double classifyInstance(Instance instance) { 49 Random rand = new Random(); 50 int randomNum = rand.nextInt(this.fixedClassValues.length); 51 return this.fixedClassValues[randomNum]; 52 } 54 @Override 55 public void apply(Instances traindata) { 56 // nothing to do 57 } 58 59 @Override 60 public String getName() { 61 return "RandomClass"; 62 } 63 64 @Override 65 public double classifyInstance(Instance instance) { 66 Random rand = new Random(); 67 int randomNum = rand.nextInt(this.fixedClassValues.length); 68 return this.fixedClassValues[randomNum]; 69 } 53 70 } -
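Again only for illustration (an assumed demo, not project code): the prediction logic of RandomClass boils down to a uniform draw over the hard-coded values { 0.0, 1.0 }, exactly as in classifyInstance() above.

import java.util.Random;

public class RandomClassDemo {
    public static void main(String[] args) {
        double[] fixedClassValues = { 0.0d, 1.0d };  // same hard-coded range as RandomClass
        Random rand = new Random();
        for (int i = 0; i < 5; i++) {
            System.out.println(fixedClassValues[rand.nextInt(fixedClassValues.length)]);
        }
    }
}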
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 18 32 /** 19 33 * Programmatic WekaBaggingTraining 20 *21 * first parameter is Trainer Name.22 * second parameter is class name23 34 * 24 * all subsequent parameters are configuration params (for example for trees) 25 * Cross Validation params always come last and are prepended with -CVPARAM 35 * first parameter is Trainer Name. second parameter is class name 36 * 37 * all subsequent parameters are configuration params (for example for trees) Cross Validation 38 * params always come last and are prepended with -CVPARAM 26 39 * 27 40 * XML Configurations for Weka Classifiers: 41 * 28 42 * <pre> 29 43 * {@code … … 37 51 public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy { 38 52 39 private final TraindatasetBagging classifier = new TraindatasetBagging(); 40 41 @Override 42 public Classifier getClassifier() { 43 return classifier; 44 } 45 46 @Override 47 public void apply(SetUniqueList<Instances> traindataSet) { 48 PrintStream errStr = System.err; 49 System.setErr(new PrintStream(new NullOutputStream())); 50 try { 51 classifier.buildClassifier(traindataSet); 52 } catch (Exception e) { 53 throw new RuntimeException(e); 54 } finally { 55 System.setErr(errStr); 56 } 57 } 58 59 public class TraindatasetBagging extends AbstractClassifier { 60 61 private static final long serialVersionUID = 1L; 53 private final TraindatasetBagging classifier = new TraindatasetBagging(); 62 54 63 private List<Instances> trainingData = null; 64 65 private List<Classifier> classifiers = null; 66 67 @Override 68 public double classifyInstance(Instance instance) { 69 if( classifiers==null ) { 70 return 0.0; 71 } 72 73 double classification = 0.0; 74 for( int i=0 ; i<classifiers.size(); i++ ) { 75 Classifier classifier = classifiers.get(i); 76 Instances traindata = trainingData.get(i); 77 78 Set<String> attributeNames = new HashSet<>(); 79 for( int j=0; j<traindata.numAttributes(); j++ ) { 80 attributeNames.add(traindata.attribute(j).name()); 81 } 82 83 double[] values = new double[traindata.numAttributes()]; 84 int index = 0; 85 for( int j=0; j<instance.numAttributes(); j++ ) { 86 if( attributeNames.contains(instance.attribute(j).name())) { 87 values[index] = instance.value(j); 88 index++; 89 } 90 } 91 92 Instances tmp = new Instances(traindata); 93 tmp.clear(); 94 Instance instCopy = new DenseInstance(instance.weight(), values); 95 instCopy.setDataset(tmp); 96 try { 97 classification += classifier.classifyInstance(instCopy); 98 } catch (Exception e) { 99 throw new RuntimeException("bagging classifier could not classify an instance", e); 100 } 101 } 102 classification /= classifiers.size(); 103 return (classification>=0.5) ? 
1.0 : 0.0; 104 } 105 106 public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception { 107 classifiers = new LinkedList<>(); 108 trainingData = new LinkedList<>(); 109 for( Instances traindata : traindataSet ) { 110 Classifier classifier = setupClassifier(); 111 classifier.buildClassifier(traindata); 112 classifiers.add(classifier); 113 trainingData.add(new Instances(traindata)); 114 } 115 } 116 117 @Override 118 public void buildClassifier(Instances traindata) throws Exception { 119 classifiers = new LinkedList<>(); 120 trainingData = new LinkedList<>(); 121 final Classifier classifier = setupClassifier(); 122 classifier.buildClassifier(traindata); 123 classifiers.add(classifier); 124 trainingData.add(new Instances(traindata)); 125 } 126 } 55 @Override 56 public Classifier getClassifier() { 57 return classifier; 58 } 59 60 @Override 61 public void apply(SetUniqueList<Instances> traindataSet) { 62 PrintStream errStr = System.err; 63 System.setErr(new PrintStream(new NullOutputStream())); 64 try { 65 classifier.buildClassifier(traindataSet); 66 } 67 catch (Exception e) { 68 throw new RuntimeException(e); 69 } 70 finally { 71 System.setErr(errStr); 72 } 73 } 74 75 public class TraindatasetBagging extends AbstractClassifier { 76 77 private static final long serialVersionUID = 1L; 78 79 private List<Instances> trainingData = null; 80 81 private List<Classifier> classifiers = null; 82 83 @Override 84 public double classifyInstance(Instance instance) { 85 if (classifiers == null) { 86 return 0.0; 87 } 88 89 double classification = 0.0; 90 for (int i = 0; i < classifiers.size(); i++) { 91 Classifier classifier = classifiers.get(i); 92 Instances traindata = trainingData.get(i); 93 94 Set<String> attributeNames = new HashSet<>(); 95 for (int j = 0; j < traindata.numAttributes(); j++) { 96 attributeNames.add(traindata.attribute(j).name()); 97 } 98 99 double[] values = new double[traindata.numAttributes()]; 100 int index = 0; 101 for (int j = 0; j < instance.numAttributes(); j++) { 102 if (attributeNames.contains(instance.attribute(j).name())) { 103 values[index] = instance.value(j); 104 index++; 105 } 106 } 107 108 Instances tmp = new Instances(traindata); 109 tmp.clear(); 110 Instance instCopy = new DenseInstance(instance.weight(), values); 111 instCopy.setDataset(tmp); 112 try { 113 classification += classifier.classifyInstance(instCopy); 114 } 115 catch (Exception e) { 116 throw new RuntimeException("bagging classifier could not classify an instance", 117 e); 118 } 119 } 120 classification /= classifiers.size(); 121 return (classification >= 0.5) ? 1.0 : 0.0; 122 } 123 124 public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception { 125 classifiers = new LinkedList<>(); 126 trainingData = new LinkedList<>(); 127 for (Instances traindata : traindataSet) { 128 Classifier classifier = setupClassifier(); 129 classifier.buildClassifier(traindata); 130 classifiers.add(classifier); 131 trainingData.add(new Instances(traindata)); 132 } 133 } 134 135 @Override 136 public void buildClassifier(Instances traindata) throws Exception { 137 classifiers = new LinkedList<>(); 138 trainingData = new LinkedList<>(); 139 final Classifier classifier = setupClassifier(); 140 classifier.buildClassifier(traindata); 141 classifiers.add(classifier); 142 trainingData.add(new Instances(traindata)); 143 } 144 } 127 145 } -
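A small sketch of the voting step in TraindatasetBagging.classifyInstance() (the vote values and the demo class are made up): each per-training-set classifier contributes one prediction, the predictions are averaged, and the ensemble answers 1.0 once the average reaches 0.5.

public class BaggingVoteDemo {
    public static void main(String[] args) {
        double[] votes = { 1.0, 0.0, 1.0 };  // hypothetical predictions, one per training set
        double classification = 0.0;
        for (double v : votes) {
            classification += v;
        }
        classification /= votes.length;
        System.out.println((classification >= 0.5) ? 1.0 : 0.0);  // majority vote: prints 1.0
    }
}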
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 15 29 * Allows specification of the Weka classifier and its params in the XML experiment configuration. 16 30 * 17 * Important conventions of the XML format: 18 * Cross Validation params always come last and are prepended with -CVPARAM19 * Example: <trainer name="WekaTraining"param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/>31 * Important conventions of the XML format: Cross Validation params always come last and are 32 * prepended with -CVPARAM Example: <trainer name="WekaTraining" 33 * param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/> 20 34 */ 21 35 public abstract class WekaBaseTraining implements IWekaCompatibleTrainer { 22 23 protected Classifier classifier = null;24 protected String classifierClassName;25 protected String classifierName;26 protected String[] classifierParams;27 28 @Override29 public void setParameter(String parameters) {30 String[] params = parameters.split(" ");31 36 32 // first part of the params is the classifierName (e.g. SMORBF) 33 classifierName = params[0]; 34 35 // the following parameters can be copied from weka! 36 37 // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 38 classifierClassName = params[1]; 39 40 // rest are params to the specified classifier (e.g. 
-K weka.classifiers.functions.supportVector.RBFKernel) 41 classifierParams = Arrays.copyOfRange(params, 2, params.length); 42 43 classifier = setupClassifier(); 44 } 37 protected Classifier classifier = null; 38 protected String classifierClassName; 39 protected String classifierName; 40 protected String[] classifierParams; 45 41 46 @Override 47 public Classifier getClassifier() { 48 return classifier; 49 } 42 @Override 43 public void setParameter(String parameters) { 44 String[] params = parameters.split(" "); 50 45 51 public Classifier setupClassifier() { 52 Classifier cl = null; 53 try{ 54 @SuppressWarnings("rawtypes") 55 Class c = Class.forName(classifierClassName); 56 Classifier obj = (Classifier) c.newInstance(); 57 58 // Filter out -CVPARAM, these are special because they do not belong to the Weka classifier class as parameters 59 String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 60 String[] cvparam = {}; 61 boolean cv = false; 62 for ( int i=0; i < classifierParams.length; i++ ) { 63 if(classifierParams[i].equals("-CVPARAM")) { 64 // rest of array are cvparam 65 cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length); 66 67 // before this we have normal params 68 param = Arrays.copyOfRange(classifierParams, 0, i); 69 70 cv = true; 71 break; 72 } 73 } 74 75 // set classifier params 76 ((OptionHandler)obj).setOptions(param); 77 cl = obj; 78 79 // we have cross val params 80 // cant check on cvparam.length here, it may not be initialized 81 if(cv) { 82 final CVParameterSelection ps = new CVParameterSelection(); 83 ps.setClassifier(obj); 84 ps.setNumFolds(5); 85 //ps.addCVParameter("I 5 25 5"); 86 for( int i=1 ; i<cvparam.length/4 ; i++ ) { 87 ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4*i)).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 88 } 89 90 cl = ps; 91 } 46 // first part of the params is the classifierName (e.g. SMORBF) 47 classifierName = params[0]; 92 48 93 }catch(ClassNotFoundException e) { 94 Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 95 e.printStackTrace(); 96 } catch (InstantiationException e) { 97 Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString())); 98 e.printStackTrace(); 99 } catch (IllegalAccessException e) { 100 Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString())); 101 e.printStackTrace(); 102 } catch (Exception e) { 103 Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 104 e.printStackTrace(); 105 } 106 107 return cl; 108 } 49 // the following parameters can be copied from weka! 109 50 110 @Override 111 public String getName() { 112 return classifierName; 113 } 114 51 // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 52 classifierClassName = params[1]; 53 54 // rest are params to the specified classifier (e.g. 
-K 55 // weka.classifiers.functions.supportVector.RBFKernel) 56 classifierParams = Arrays.copyOfRange(params, 2, params.length); 57 58 classifier = setupClassifier(); 59 } 60 61 @Override 62 public Classifier getClassifier() { 63 return classifier; 64 } 65 66 public Classifier setupClassifier() { 67 Classifier cl = null; 68 try { 69 @SuppressWarnings("rawtypes") 70 Class c = Class.forName(classifierClassName); 71 Classifier obj = (Classifier) c.newInstance(); 72 73 // Filter out -CVPARAM, these are special because they do not belong to the Weka 74 // classifier class as parameters 75 String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 76 String[] cvparam = { }; 77 boolean cv = false; 78 for (int i = 0; i < classifierParams.length; i++) { 79 if (classifierParams[i].equals("-CVPARAM")) { 80 // rest of array are cvparam 81 cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length); 82 83 // before this we have normal params 84 param = Arrays.copyOfRange(classifierParams, 0, i); 85 86 cv = true; 87 break; 88 } 89 } 90 91 // set classifier params 92 ((OptionHandler) obj).setOptions(param); 93 cl = obj; 94 95 // we have cross val params 96 // cant check on cvparam.length here, it may not be initialized 97 if (cv) { 98 final CVParameterSelection ps = new CVParameterSelection(); 99 ps.setClassifier(obj); 100 ps.setNumFolds(5); 101 // ps.addCVParameter("I 5 25 5"); 102 for (int i = 1; i < cvparam.length / 4; i++) { 103 ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i)) 104 .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 105 } 106 107 cl = ps; 108 } 109 110 } 111 catch (ClassNotFoundException e) { 112 Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 113 e.printStackTrace(); 114 } 115 catch (InstantiationException e) { 116 Console.traceln(Level.WARNING, 117 String.format("Instantiation Exception: %s", e.toString())); 118 e.printStackTrace(); 119 } 120 catch (IllegalAccessException e) { 121 Console.traceln(Level.WARNING, 122 String.format("Illegal Access Exception: %s", e.toString())); 123 e.printStackTrace(); 124 } 125 catch (Exception e) { 126 Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 127 e.printStackTrace(); 128 } 129 130 return cl; 131 } 132 133 @Override 134 public String getName() { 135 return classifierName; 136 } 137 115 138 } -
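A hedged usage sketch of the parameter convention documented in the WekaBaseTraining Javadoc, using the Javadoc's own example string. WekaBaggingTraining is used here only because it is a concrete WekaBaseTraining subclass from this changeset; any other subclass would do, and the demo class name is invented. The first token becomes the trainer name, the second the Weka classifier class to instantiate reflectively, the remaining tokens are classifier options, and everything after -CVPARAM is handed to CVParameterSelection.

import de.ugoe.cs.cpdp.training.WekaBaggingTraining;

public class WekaBaseTrainingDemo {
    public static void main(String[] args) {
        WekaBaggingTraining trainer = new WekaBaggingTraining();
        // name, Weka classifier class, then the -CVPARAM block (example from the Javadoc)
        trainer.setParameter("RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5");
        System.out.println(trainer.getName());                   // "RandomForestLocal"
        System.out.println(trainer.getClassifier().getClass());  // the reflectively built classifier
    }
}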
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 24 38 * WekaLocalEMTraining 25 39 * 26 * Local Trainer with EM Clustering for data partitioning. 27 * Currently supports only EM Clustering. 28 * 29 * 1. Cluster training data 30 * 2. for each cluster train a classifier with training data from cluster 40 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering. 41 * 42 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster 31 43 * 3. match test data instance to a cluster, then classify with classifier from the cluster 32 44 * 33 * XML configuration: 34 * <!-- because of clustering --> 35 * <preprocessor name="Normalization" param=""/> 36 * 37 * <!-- cluster trainer --> 38 * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 45 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/> 46 * 47 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining" 48 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 39 49 */ 40 50 public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy { 41 51 42 private final TraindatasetCluster classifier = new TraindatasetCluster(); 43 44 @Override 45 public Classifier getClassifier() { 46 return classifier; 47 } 48 49 @Override 50 public void apply(Instances traindata) { 51 PrintStream errStr = System.err; 52 System.setErr(new PrintStream(new NullOutputStream())); 53 try { 54 classifier.buildClassifier(traindata); 55 } catch (Exception e) { 56 throw new RuntimeException(e); 57 } finally { 58 System.setErr(errStr); 59 } 60 } 61 62 63 public class TraindatasetCluster extends AbstractClassifier { 64 65 private static final long serialVersionUID = 1L; 66 67 private EM clusterer = null; 68 69 private HashMap<Integer, Classifier> cclassifier; 70 private HashMap<Integer, Instances> ctraindata; 71 72 73 /** 74 * Helper method that gives us a clean instance copy with 75 * the values of the instancelist of the first parameter. 
76 * 77 * @param instancelist with attributes 78 * @param instance with only values 79 * @return copy of the instance 80 */ 81 private Instance createInstance(Instances instances, Instance instance) { 82 // attributes for feeding instance to classifier 83 Set<String> attributeNames = new HashSet<>(); 84 for( int j=0; j<instances.numAttributes(); j++ ) { 85 attributeNames.add(instances.attribute(j).name()); 86 } 87 88 double[] values = new double[instances.numAttributes()]; 89 int index = 0; 90 for( int j=0; j<instance.numAttributes(); j++ ) { 91 if( attributeNames.contains(instance.attribute(j).name())) { 92 values[index] = instance.value(j); 93 index++; 94 } 95 } 96 97 Instances tmp = new Instances(instances); 98 tmp.clear(); 99 Instance instCopy = new DenseInstance(instance.weight(), values); 100 instCopy.setDataset(tmp); 101 102 return instCopy; 103 } 104 105 @Override 106 public double classifyInstance(Instance instance) { 107 double ret = 0; 108 try { 109 // 1. copy the instance (keep the class attribute) 110 Instances traindata = ctraindata.get(0); 111 Instance classInstance = createInstance(traindata, instance); 112 113 // 2. remove class attribute before clustering 114 Remove filter = new Remove(); 115 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 116 filter.setInputFormat(traindata); 117 traindata = Filter.useFilter(traindata, filter); 118 119 // 3. copy the instance (without the class attribute) for clustering 120 Instance clusterInstance = createInstance(traindata, instance); 121 122 // 4. match instance without class attribute to a cluster number 123 int cnum = clusterer.clusterInstance(clusterInstance); 124 125 // 5. classify instance with class attribute to the classifier of that cluster number 126 ret = cclassifier.get(cnum).classifyInstance(classInstance); 127 128 }catch( Exception e ) { 129 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 130 throw new RuntimeException(e); 131 } 132 return ret; 133 } 134 135 @Override 136 public void buildClassifier(Instances traindata) throws Exception { 137 138 // 1. copy training data 139 Instances train = new Instances(traindata); 140 141 // 2. remove class attribute for clustering 142 Remove filter = new Remove(); 143 filter.setAttributeIndices("" + (train.classIndex() + 1)); 144 filter.setInputFormat(train); 145 train = Filter.useFilter(train, filter); 146 147 // new objects 148 cclassifier = new HashMap<Integer, Classifier>(); 149 ctraindata = new HashMap<Integer, Instances>(); 150 151 Instances ctrain; 152 int maxNumClusters = train.size(); 153 boolean sufficientInstancesInEachCluster; 154 do { // while(onlyTarget) 155 sufficientInstancesInEachCluster = true; 156 clusterer = new EM(); 157 clusterer.setMaximumNumberOfClusters(maxNumClusters); 158 clusterer.buildClusterer(train); 159 160 // 4. 
get cluster membership of our traindata 161 //AddCluster cfilter = new AddCluster(); 162 //cfilter.setClusterer(clusterer); 163 //cfilter.setInputFormat(train); 164 //Instances ctrain = Filter.useFilter(train, cfilter); 165 166 ctrain = new Instances(train); 167 ctraindata = new HashMap<>(); 168 169 // get traindata per cluster 170 for ( int j=0; j < ctrain.numInstances(); j++ ) { 171 // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 172 //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 173 174 int cnumber = clusterer.clusterInstance(ctrain.get(j)); 175 // add training data to list of instances for this cluster number 176 if ( !ctraindata.containsKey(cnumber) ) { 177 ctraindata.put(cnumber, new Instances(traindata)); 178 ctraindata.get(cnumber).delete(); 179 } 180 ctraindata.get(cnumber).add(traindata.get(j)); 181 } 182 183 for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) { 184 Instances instances = entry.getValue(); 185 int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 186 for( int count : counts ) { 187 sufficientInstancesInEachCluster &= count>0; 188 } 189 sufficientInstancesInEachCluster &= instances.numInstances()>=5; 190 } 191 maxNumClusters = clusterer.numberOfClusters()-1; 192 } while(!sufficientInstancesInEachCluster); 193 194 // train one classifier per cluster, we get the cluster number from the training data 195 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 196 while ( clusternumber.hasNext() ) { 197 int cnumber = clusternumber.next(); 198 cclassifier.put(cnumber,setupClassifier()); 199 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 200 201 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 202 } 203 } 204 } 52 private final TraindatasetCluster classifier = new TraindatasetCluster(); 53 54 @Override 55 public Classifier getClassifier() { 56 return classifier; 57 } 58 59 @Override 60 public void apply(Instances traindata) { 61 PrintStream errStr = System.err; 62 System.setErr(new PrintStream(new NullOutputStream())); 63 try { 64 classifier.buildClassifier(traindata); 65 } 66 catch (Exception e) { 67 throw new RuntimeException(e); 68 } 69 finally { 70 System.setErr(errStr); 71 } 72 } 73 74 public class TraindatasetCluster extends AbstractClassifier { 75 76 private static final long serialVersionUID = 1L; 77 78 private EM clusterer = null; 79 80 private HashMap<Integer, Classifier> cclassifier; 81 private HashMap<Integer, Instances> ctraindata; 82 83 /** 84 * Helper method that gives us a clean instance copy with the values of the instancelist of 85 * the first parameter. 
86 * 87 * @param instancelist 88 * with attributes 89 * @param instance 90 * with only values 91 * @return copy of the instance 92 */ 93 private Instance createInstance(Instances instances, Instance instance) { 94 // attributes for feeding instance to classifier 95 Set<String> attributeNames = new HashSet<>(); 96 for (int j = 0; j < instances.numAttributes(); j++) { 97 attributeNames.add(instances.attribute(j).name()); 98 } 99 100 double[] values = new double[instances.numAttributes()]; 101 int index = 0; 102 for (int j = 0; j < instance.numAttributes(); j++) { 103 if (attributeNames.contains(instance.attribute(j).name())) { 104 values[index] = instance.value(j); 105 index++; 106 } 107 } 108 109 Instances tmp = new Instances(instances); 110 tmp.clear(); 111 Instance instCopy = new DenseInstance(instance.weight(), values); 112 instCopy.setDataset(tmp); 113 114 return instCopy; 115 } 116 117 @Override 118 public double classifyInstance(Instance instance) { 119 double ret = 0; 120 try { 121 // 1. copy the instance (keep the class attribute) 122 Instances traindata = ctraindata.get(0); 123 Instance classInstance = createInstance(traindata, instance); 124 125 // 2. remove class attribute before clustering 126 Remove filter = new Remove(); 127 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 128 filter.setInputFormat(traindata); 129 traindata = Filter.useFilter(traindata, filter); 130 131 // 3. copy the instance (without the class attribute) for clustering 132 Instance clusterInstance = createInstance(traindata, instance); 133 134 // 4. match instance without class attribute to a cluster number 135 int cnum = clusterer.clusterInstance(clusterInstance); 136 137 // 5. classify instance with class attribute to the classifier of that cluster 138 // number 139 ret = cclassifier.get(cnum).classifyInstance(classInstance); 140 141 } 142 catch (Exception e) { 143 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 144 throw new RuntimeException(e); 145 } 146 return ret; 147 } 148 149 @Override 150 public void buildClassifier(Instances traindata) throws Exception { 151 152 // 1. copy training data 153 Instances train = new Instances(traindata); 154 155 // 2. remove class attribute for clustering 156 Remove filter = new Remove(); 157 filter.setAttributeIndices("" + (train.classIndex() + 1)); 158 filter.setInputFormat(train); 159 train = Filter.useFilter(train, filter); 160 161 // new objects 162 cclassifier = new HashMap<Integer, Classifier>(); 163 ctraindata = new HashMap<Integer, Instances>(); 164 165 Instances ctrain; 166 int maxNumClusters = train.size(); 167 boolean sufficientInstancesInEachCluster; 168 do { // while(onlyTarget) 169 sufficientInstancesInEachCluster = true; 170 clusterer = new EM(); 171 clusterer.setMaximumNumberOfClusters(maxNumClusters); 172 clusterer.buildClusterer(train); 173 174 // 4. 
get cluster membership of our traindata 175 // AddCluster cfilter = new AddCluster(); 176 // cfilter.setClusterer(clusterer); 177 // cfilter.setInputFormat(train); 178 // Instances ctrain = Filter.useFilter(train, cfilter); 179 180 ctrain = new Instances(train); 181 ctraindata = new HashMap<>(); 182 183 // get traindata per cluster 184 for (int j = 0; j < ctrain.numInstances(); j++) { 185 // get the cluster number from the attributes, subract 1 because if we 186 // clusterInstance we get 0-n, and this is 1-n 187 // cnumber = 188 // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", 189 // "")) - 1; 190 191 int cnumber = clusterer.clusterInstance(ctrain.get(j)); 192 // add training data to list of instances for this cluster number 193 if (!ctraindata.containsKey(cnumber)) { 194 ctraindata.put(cnumber, new Instances(traindata)); 195 ctraindata.get(cnumber).delete(); 196 } 197 ctraindata.get(cnumber).add(traindata.get(j)); 198 } 199 200 for (Entry<Integer, Instances> entry : ctraindata.entrySet()) { 201 Instances instances = entry.getValue(); 202 int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 203 for (int count : counts) { 204 sufficientInstancesInEachCluster &= count > 0; 205 } 206 sufficientInstancesInEachCluster &= instances.numInstances() >= 5; 207 } 208 maxNumClusters = clusterer.numberOfClusters() - 1; 209 } 210 while (!sufficientInstancesInEachCluster); 211 212 // train one classifier per cluster, we get the cluster number from the training data 213 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 214 while (clusternumber.hasNext()) { 215 int cnumber = clusternumber.next(); 216 cclassifier.put(cnumber, setupClassifier()); 217 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 218 219 // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 220 } 221 } 222 } 205 223 } -
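Note: the core idea behind WekaLocalEMTraining is to cluster the training data with EM (after removing the class attribute), train one base classifier per cluster, and at prediction time route each instance to the classifier of its cluster. The following is a minimal, self-contained sketch of that pattern, not the CrossPare class itself; it assumes Weka is on the classpath, that the class index of the training data is already set, and it uses NaiveBayes as the base learner (as in the XML example above) while omitting the retry loop that shrinks the maximum number of clusters until every cluster has enough instances of each class.

// Minimal sketch of EM-based local training (illustrative, not the CrossPare implementation).
import java.util.HashMap;
import java.util.Map;

import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.clusterers.EM;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

public class LocalEMSketch {

    private EM clusterer;
    private Instances clusterHeader; // header of the class-free data, needed at prediction time
    private final Map<Integer, Classifier> classifiers = new HashMap<>();

    public void build(Instances traindata) throws Exception {
        // remove the class attribute before clustering
        Remove remove = new Remove();
        remove.setAttributeIndices("" + (traindata.classIndex() + 1));
        remove.setInputFormat(traindata);
        Instances clusterData = Filter.useFilter(traindata, remove);
        clusterHeader = new Instances(clusterData, 0);

        clusterer = new EM();
        clusterer.buildClusterer(clusterData);

        // partition the original (labelled) data by cluster membership
        Map<Integer, Instances> partitions = new HashMap<>();
        for (int i = 0; i < clusterData.numInstances(); i++) {
            int c = clusterer.clusterInstance(clusterData.get(i));
            partitions.computeIfAbsent(c, k -> new Instances(traindata, 0)).add(traindata.get(i));
        }

        // one base classifier per cluster
        for (Map.Entry<Integer, Instances> entry : partitions.entrySet()) {
            Classifier c = new NaiveBayes();
            c.buildClassifier(entry.getValue());
            classifiers.put(entry.getKey(), c);
        }
    }

    public double classify(Instance instance) throws Exception {
        // copy all attribute values except the class so EM can assign a cluster
        double[] values = new double[instance.numAttributes() - 1];
        int k = 0;
        for (int j = 0; j < instance.numAttributes(); j++) {
            if (j != instance.classIndex()) {
                values[k++] = instance.value(j);
            }
        }
        Instance unlabelled = new DenseInstance(1.0, values);
        unlabelled.setDataset(clusterHeader);

        int c = clusterer.clusterInstance(unlabelled);
        return classifiers.get(c).classifyInstance(instance);
    }
}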
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 24 38 25 39 /** 26 * Trainer with reimplementation of WHERE clustering algorithm from: 27 * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman, 28 * Forrest Shull, Burak Turhan, Thomas Zimmermann, 29 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," 30 * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 40 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher, 41 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann, 42 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on 43 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 31 44 * 32 * With WekaLocalFQTraining we do the following: 33 * 1) Run the Fastmap algorithm on all training data, let it calculate the 2 most significant34 * dimensions and projections of each instance to these dimensions35 * 2) With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and median(y) values.36 * 3) We cluster the QuadTree nodes together if they have similar density (50%)37 * 4) We save the clusters and their training data38 * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster39 * 6) We train a Weka classifier for each cluster with the clusters training data40 * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance.41 * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster.42 * 8) We classify the Instance with theclassifier and traindata from the Cluster we found in 7.45 * With WekaLocalFQTraining we do the following: 1) Run the Fastmap algorithm on all training data, 46 * let it calculate the 2 most significant dimensions and projections of each instance to these 47 * dimensions 2) With these 2 dimensions we span a QuadTree which gets recursively split on 48 * median(x) and median(y) values. 3) We cluster the QuadTree nodes together if they have similar 49 * density (50%) 4) We save the clusters and their training data 5) We only use clusters with > 50 * ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this 51 * cluster 6) We train a Weka classifier for each cluster with the clusters training data 7) We 52 * recalculate Fastmap distances for a single instance with the old pivots and then try to find a 53 * cluster containing the coords of the instance. 7.1.) If we can not find a cluster (due to coords 54 * outside of all clusters) we find the nearest cluster. 8) We classify the Instance with the 55 * classifier and traindata from the Cluster we found in 7. 
43 56 */ 44 57 public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy { 45 46 private final TraindatasetCluster classifier = new TraindatasetCluster(); 47 48 @Override 49 public Classifier getClassifier() { 50 return classifier; 51 } 52 53 @Override 54 public void apply(Instances traindata) { 55 PrintStream errStr = System.err; 56 System.setErr(new PrintStream(new NullOutputStream())); 57 try { 58 classifier.buildClassifier(traindata); 59 } catch (Exception e) { 60 throw new RuntimeException(e); 61 } finally { 62 System.setErr(errStr); 63 } 64 } 65 66 67 public class TraindatasetCluster extends AbstractClassifier { 68 69 private static final long serialVersionUID = 1L; 70 71 /* classifier per cluster */ 72 private HashMap<Integer, Classifier> cclassifier; 73 74 /* instances per cluster */ 75 private HashMap<Integer, Instances> ctraindata; 76 77 /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/ 78 private HashMap<Integer, Instance> cpivots; 79 80 /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/ 81 private int[][] cpivotindices; 82 83 /* holds the sizes of the cluster multiple "boxes" per cluster */ 84 private HashMap<Integer, ArrayList<Double[][]>> csize; 85 86 /* debug vars */ 87 @SuppressWarnings("unused") 88 private boolean show_biggest = true; 89 90 @SuppressWarnings("unused") 91 private int CFOUND = 0; 92 @SuppressWarnings("unused") 93 private int CNOTFOUND = 0; 94 95 96 private Instance createInstance(Instances instances, Instance instance) { 97 // attributes for feeding instance to classifier 98 Set<String> attributeNames = new HashSet<>(); 99 for( int j=0; j<instances.numAttributes(); j++ ) { 100 attributeNames.add(instances.attribute(j).name()); 101 } 102 103 double[] values = new double[instances.numAttributes()]; 104 int index = 0; 105 for( int j=0; j<instance.numAttributes(); j++ ) { 106 if( attributeNames.contains(instance.attribute(j).name())) { 107 values[index] = instance.value(j); 108 index++; 109 } 110 } 111 112 Instances tmp = new Instances(instances); 113 tmp.clear(); 114 Instance instCopy = new DenseInstance(instance.weight(), values); 115 instCopy.setDataset(tmp); 116 117 return instCopy; 118 } 119 120 /** 121 * Because Fastmap saves only the image not the values of the attributes it used 122 * we can not use the old data directly to classify single instances to clusters. 123 * 124 * To classify a single instance we do a new fastmap computation with only the instance and 125 * the old pivot elements. 126 * 127 * After that we find the cluster with our fastmap result for x and y. 
128 */ 129 @Override 130 public double classifyInstance(Instance instance) { 131 132 double ret = 0; 133 try { 134 // classinstance gets passed to classifier 135 Instances traindata = ctraindata.get(0); 136 Instance classInstance = createInstance(traindata, instance); 137 138 // this one keeps the class attribute 139 Instances traindata2 = ctraindata.get(1); 140 141 // remove class attribute before clustering 142 Remove filter = new Remove(); 143 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 144 filter.setInputFormat(traindata); 145 traindata = Filter.useFilter(traindata, filter); 146 Instance clusterInstance = createInstance(traindata, instance); 147 148 Fastmap FMAP = new Fastmap(2); 149 EuclideanDistance dist = new EuclideanDistance(traindata); 150 151 // we set our pivot indices [x=0,y=1][dimension] 152 int[][] npivotindices = new int[2][2]; 153 npivotindices[0][0] = 1; 154 npivotindices[1][0] = 2; 155 npivotindices[0][1] = 3; 156 npivotindices[1][1] = 4; 157 158 // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 159 // the instance we want to classify comes first after that the pivot elements in the order defined above 160 double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1]; 161 distmat[0][0] = 0; 162 distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0])); 163 distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0])); 164 distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1])); 165 distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1])); 166 167 distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance); 168 distmat[1][1] = 0; 169 distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 170 distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 171 distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 172 173 distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance); 174 distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 175 distmat[2][2] = 0; 176 distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 177 distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 178 179 distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance); 180 distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 181 distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 182 distmat[3][3] = 0; 183 distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 184 185 distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance); 186 distmat[4][1] = 
dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 187 distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 188 distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 189 distmat[4][4] = 0; 190 191 192 /* debug output: show biggest distance found within the new distance matrix 193 double biggest = 0; 194 for(int i=0; i < distmat.length; i++) { 195 for(int j=0; j < distmat[0].length; j++) { 196 if(biggest < distmat[i][j]) { 197 biggest = distmat[i][j]; 198 } 199 } 200 } 201 if(this.show_biggest) { 202 Console.traceln(Level.INFO, String.format(""+clusterInstance)); 203 Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 204 this.show_biggest = false; 205 } 206 */ 207 208 FMAP.setDistmat(distmat); 209 FMAP.setPivots(npivotindices); 210 FMAP.calculate(); 211 double[][] x = FMAP.getX(); 212 double[] proj = x[0]; 213 214 // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix 215 /* 216 Console.traceln(Level.INFO, "distmat:"); 217 for(int i=0; i<distmat.length; i++){ 218 for(int j=0; j<distmat[0].length; j++){ 219 Console.trace(Level.INFO, String.format("%20s", distmat[i][j])); 220 } 221 Console.traceln(Level.INFO, ""); 222 } 223 224 Console.traceln(Level.INFO, "vector:"); 225 for(int i=0; i < proj.length; i++) { 226 Console.trace(Level.INFO, String.format("%20s", proj[i])); 227 } 228 Console.traceln(Level.INFO, ""); 229 230 Console.traceln(Level.INFO, "resultmat:"); 231 for(int i=0; i<x.length; i++){ 232 for(int j=0; j<x[0].length; j++){ 233 Console.trace(Level.INFO, String.format("%20s", x[i][j])); 234 } 235 Console.traceln(Level.INFO, ""); 236 } 237 */ 238 239 // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the 240 // cluster in which we are 241 int cnumber; 242 int found_cnumber = -1; 243 Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 244 while ( clusternumber.hasNext() && found_cnumber == -1) { 245 cnumber = clusternumber.next(); 246 247 // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed) 248 // or we are too far away from any cluster because of the fastmap calculation with the initial pivot objects 249 for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) { 250 Double[][] current = this.csize.get(cnumber).get(box); 251 252 if(proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x 253 proj[1] >= current[1][0] && proj[1] <= current[1][1]) { // y 254 found_cnumber = cnumber; 255 } 256 } 257 } 258 259 // we want to count how often we are really inside a cluster 260 //if ( found_cnumber == -1 ) { 261 // CNOTFOUND += 1; 262 //}else { 263 // CFOUND += 1; 264 //} 265 266 // now it can happen that we do not find a cluster because we deleted it previously (too few instances) 267 // or we get bigger distance measures from weka so that we are completely outside of our clusters. 268 // in these cases we just find the nearest cluster to our instance and use it for classification. 
269 // to do that we use the EuclideanDistance again to compare our distance to all other Instances 270 // then we take the cluster of the closest weka instance 271 dist = new EuclideanDistance(traindata2); 272 if( !this.ctraindata.containsKey(found_cnumber) ) { 273 double min_distance = Double.MAX_VALUE; 274 clusternumber = ctraindata.keySet().iterator(); 275 while ( clusternumber.hasNext() ) { 276 cnumber = clusternumber.next(); 277 for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 278 if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) { 279 found_cnumber = cnumber; 280 min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i)); 281 } 282 } 283 } 284 } 285 286 // here we have the cluster where an instance has the minimum distance between itself and the 287 // instance we want to classify 288 // if we still have not found a cluster we exit because something is really wrong 289 if( found_cnumber == -1 ) { 290 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!")); 291 throw new RuntimeException("cluster not found with full search"); 292 } 293 294 // classify the passed instance with the cluster we found and its training data 295 ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 296 297 }catch( Exception e ) { 298 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 299 throw new RuntimeException(e); 300 } 301 return ret; 302 } 303 304 @Override 305 public void buildClassifier(Instances traindata) throws Exception { 306 307 //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND)); 308 this.show_biggest = true; 309 310 cclassifier = new HashMap<Integer, Classifier>(); 311 ctraindata = new HashMap<Integer, Instances>(); 312 cpivots = new HashMap<Integer, Instance>(); 313 cpivotindices = new int[2][2]; 314 315 // 1. copy traindata 316 Instances train = new Instances(traindata); 317 Instances train2 = new Instances(traindata); // this one keeps the class attribute 318 319 // 2. remove class attribute for clustering 320 Remove filter = new Remove(); 321 filter.setAttributeIndices("" + (train.classIndex() + 1)); 322 filter.setInputFormat(train); 323 train = Filter.useFilter(train, filter); 324 325 // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 326 double biggest = 0; 327 EuclideanDistance dist = new EuclideanDistance(train); 328 double[][] distmat = new double[train.size()][train.size()]; 329 for( int i=0; i < train.size(); i++ ) { 330 for( int j=0; j < train.size(); j++ ) { 331 distmat[i][j] = dist.distance(train.get(i), train.get(j)); 332 if( distmat[i][j] > biggest ) { 333 biggest = distmat[i][j]; 334 } 335 } 336 } 337 //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 338 339 // 4. 
run fastmap for 2 dimensions on the distance matrix 340 Fastmap FMAP = new Fastmap(2); 341 FMAP.setDistmat(distmat); 342 FMAP.calculate(); 343 344 cpivotindices = FMAP.getPivots(); 345 346 double[][] X = FMAP.getX(); 347 distmat = new double[0][0]; 348 System.gc(); 349 350 // quadtree payload generation 351 ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 352 353 // we need these for the sizes of the quadrants 354 double[] big = {0,0}; 355 double[] small = {Double.MAX_VALUE,Double.MAX_VALUE}; 356 357 // set quadtree payload values and get max and min x and y values for size 358 for( int i=0; i<X.length; i++ ){ 359 if(X[i][0] >= big[0]) { 360 big[0] = X[i][0]; 361 } 362 if(X[i][1] >= big[1]) { 363 big[1] = X[i][1]; 364 } 365 if(X[i][0] <= small[0]) { 366 small[0] = X[i][0]; 367 } 368 if(X[i][1] <= small[1]) { 369 small[1] = X[i][1]; 370 } 371 QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 372 qtp.add(tmp); 373 } 374 375 //Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 376 377 // 5. generate quadtree 378 QuadTree TREE = new QuadTree(null, qtp); 379 QuadTree.size = train.size(); 380 QuadTree.alpha = Math.sqrt(train.size()); 381 QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 382 QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 383 384 //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ "")); 385 386 // set the size and then split the tree recursively at the median value for x, y 387 TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 388 389 // recursive split und grid clustering eher static 390 TREE.recursiveSplit(TREE); 391 392 // generate list of nodes sorted by density (childs only) 393 ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 394 395 // recursive grid clustering (tree pruning), the values are stored in ccluster 396 TREE.gridClustering(l); 397 398 // wir iterieren durch die cluster und sammeln uns die instanzen daraus 399 //ctraindata.clear(); 400 for( int i=0; i < QuadTree.ccluster.size(); i++ ) { 401 ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 402 403 // i is the clusternumber 404 // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
405 //if(current.size() > QuadTree.alpha) { 406 if( current.size() > 4 ) { 407 for( int j=0; j < current.size(); j++ ) { 408 if( !ctraindata.containsKey(i) ) { 409 ctraindata.put(i, new Instances(train2)); 410 ctraindata.get(i).delete(); 411 } 412 ctraindata.get(i).add(current.get(j).getInst()); 413 } 414 }else{ 415 Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances")); 416 } 417 } 418 419 // here we keep things we need later on 420 // QuadTree sizes for later use (matching new instances) 421 this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 422 423 // pivot elements 424 //this.cpivots.clear(); 425 for( int i=0; i < FMAP.PA[0].length; i++ ) { 426 this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy()); 427 } 428 for( int j=0; j < FMAP.PA[0].length; j++ ) { 429 this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy()); 430 } 431 432 433 /* debug output 434 int pnumber; 435 Iterator<Integer> pivotnumber = cpivots.keySet().iterator(); 436 while ( pivotnumber.hasNext() ) { 437 pnumber = pivotnumber.next(); 438 Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber))); 439 } 440 */ 441 442 // train one classifier per cluster, we get the cluster number from the traindata 443 int cnumber; 444 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 445 //cclassifier.clear(); 446 447 //int traindata_count = 0; 448 while ( clusternumber.hasNext() ) { 449 cnumber = clusternumber.next(); 450 cclassifier.put(cnumber,setupClassifier()); // this is the classifier used for the cluster 451 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 452 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 453 //traindata_count += ctraindata.get(cnumber).size(); 454 //Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber +" with "+ ctraindata.get(cnumber).size() +" traindata instances")); 455 } 456 457 // add all traindata 458 //Console.traceln(Level.INFO, String.format("traindata in all clusters: " + traindata_count)); 459 } 460 } 461 462 463 /** 464 * Payload for the QuadTree. 465 * x and y are the calculated Fastmap values. 466 * T is a weka instance. 467 */ 468 public class QuadTreePayload<T> { 469 470 public double x; 471 public double y; 472 private T inst; 473 474 public QuadTreePayload(double x, double y, T value) { 475 this.x = x; 476 this.y = y; 477 this.inst = value; 478 } 479 480 public T getInst() { 481 return this.inst; 482 } 483 } 484 485 486 /** 487 * Fastmap implementation 488 * 489 * Faloutsos, C., & Lin, K. I. (1995). 490 * FastMap: A fast algorithm for indexing, data-mining and visualization of traditional and multimedia datasets 491 * (Vol. 24, No. 2, pp. 163-174). ACM. 
492 */ 493 public class Fastmap { 494 495 /*N x k Array, at the end, the i-th row will be the image of the i-th object*/ 496 private double[][] X; 497 498 /*2 x k pivot Array one pair per recursive call*/ 499 private int[][] PA; 500 501 /*Objects we got (distance matrix)*/ 502 private double[][] O; 503 504 /*column of X currently updated (also the dimension)*/ 505 private int col = 0; 506 507 /*number of dimensions we want*/ 508 private int target_dims = 0; 509 510 // if we already have the pivot elements 511 private boolean pivot_set = false; 512 513 514 public Fastmap(int k) { 515 this.target_dims = k; 516 } 517 518 /** 519 * Sets the distance matrix 520 * and params that depend on this 521 * @param O 522 */ 523 public void setDistmat(double[][] O) { 524 this.O = O; 525 int N = O.length; 526 this.X = new double[N][this.target_dims]; 527 this.PA = new int[2][this.target_dims]; 528 } 529 530 /** 531 * Set pivot elements, we need that to classify instances 532 * after the calculation is complete (because we then want to reuse 533 * only the pivot elements). 534 * 535 * @param pi 536 */ 537 public void setPivots(int[][] pi) { 538 this.pivot_set = true; 539 this.PA = pi; 540 } 541 542 /** 543 * Return the pivot elements that were chosen during the calculation 544 * 545 * @return 546 */ 547 public int[][] getPivots() { 548 return this.PA; 549 } 550 551 /** 552 * The distance function for euclidean distance 553 * 554 * Acts according to equation 4 of the fastmap paper 555 * 556 * @param x x index of x image (if k==0 x object) 557 * @param y y index of y image (if k==0 y object) 558 * @param kdimensionality 559 * @return distance 560 */ 561 private double dist(int x, int y, int k) { 562 563 // basis is object distance, we get this from our distance matrix 564 double tmp = this.O[x][y] * this.O[x][y]; 565 566 // decrease by projections 567 for( int i=0; i < k; i++ ) { 568 double tmp2 = (this.X[x][i] - this.X[y][i]); 569 tmp -= tmp2 * tmp2; 570 } 571 572 return Math.abs(tmp); 573 } 574 575 /** 576 * Find the object farthest from the given index 577 * This method is a helper Method for findDistandObjects 578 * 579 * @param index of the object 580 * @return index of the farthest object from the given index 581 */ 582 private int findFarthest(int index) { 583 double furthest = Double.MIN_VALUE; 584 int ret = 0; 585 586 for( int i=0; i < O.length; i++ ) { 587 double dist = this.dist(i, index, this.col); 588 if( i != index && dist > furthest ) { 589 furthest = dist; 590 ret = i; 591 } 592 } 593 return ret; 594 } 595 596 /** 597 * Finds the pivot objects 598 * 599 * This method is basically algorithm 1 of the fastmap paper. 600 * 601 * @return 2 indexes of the choosen pivot objects 602 */ 603 private int[] findDistantObjects() { 604 // 1. choose object randomly 605 Random r = new Random(); 606 int obj = r.nextInt(this.O.length); 607 608 // 2. find farthest object from randomly chosen object 609 int idx1 = this.findFarthest(obj); 610 611 // 3. find farthest object from previously farthest object 612 int idx2 = this.findFarthest(idx1); 613 614 return new int[] {idx1, idx2}; 615 } 616 617 /** 618 * Calculates the new k-vector values (projections) 619 * 620 * This is basically algorithm 2 of the fastmap paper. 621 * We just added the possibility to pre-set the pivot elements because 622 * we need to classify single instances after the computation is already done. 
623 * 624 * @param dims dimensionality 625 */ 626 public void calculate() { 627 628 for( int k=0; k < this.target_dims; k++ ) { 629 // 2) choose pivot objects 630 if ( !this.pivot_set ) { 631 int[] pivots = this.findDistantObjects(); 632 633 // 3) record ids of pivot objects 634 this.PA[0][this.col] = pivots[0]; 635 this.PA[1][this.col] = pivots[1]; 636 } 637 638 // 4) inter object distances are zero (this.X is initialized with 0 so we just continue) 639 if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) { 640 continue; 641 } 642 643 // 5) project the objects on the line between the pivots 644 double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 645 for( int i=0; i < this.O.length; i++ ) { 646 647 double dix = this.dist(i, this.PA[0][this.col], this.col); 648 double diy = this.dist(i, this.PA[1][this.col], this.col); 649 650 double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 651 652 // save the projection 653 this.X[i][this.col] = tmp; 654 } 655 656 this.col += 1; 657 } 658 } 659 660 /** 661 * returns the result matrix of the projections 662 * 663 * @return calculated result 664 */ 665 public double[][] getX() { 666 return this.X; 667 } 668 } 58 59 private final TraindatasetCluster classifier = new TraindatasetCluster(); 60 61 @Override 62 public Classifier getClassifier() { 63 return classifier; 64 } 65 66 @Override 67 public void apply(Instances traindata) { 68 PrintStream errStr = System.err; 69 System.setErr(new PrintStream(new NullOutputStream())); 70 try { 71 classifier.buildClassifier(traindata); 72 } 73 catch (Exception e) { 74 throw new RuntimeException(e); 75 } 76 finally { 77 System.setErr(errStr); 78 } 79 } 80 81 public class TraindatasetCluster extends AbstractClassifier { 82 83 private static final long serialVersionUID = 1L; 84 85 /* classifier per cluster */ 86 private HashMap<Integer, Classifier> cclassifier; 87 88 /* instances per cluster */ 89 private HashMap<Integer, Instances> ctraindata; 90 91 /* 92 * holds the instances and indices of the pivot objects of the Fastmap calculation in 93 * buildClassifier 94 */ 95 private HashMap<Integer, Instance> cpivots; 96 97 /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */ 98 private int[][] cpivotindices; 99 100 /* holds the sizes of the cluster multiple "boxes" per cluster */ 101 private HashMap<Integer, ArrayList<Double[][]>> csize; 102 103 /* debug vars */ 104 @SuppressWarnings("unused") 105 private boolean show_biggest = true; 106 107 @SuppressWarnings("unused") 108 private int CFOUND = 0; 109 @SuppressWarnings("unused") 110 private int CNOTFOUND = 0; 111 112 private Instance createInstance(Instances instances, Instance instance) { 113 // attributes for feeding instance to classifier 114 Set<String> attributeNames = new HashSet<>(); 115 for (int j = 0; j < instances.numAttributes(); j++) { 116 attributeNames.add(instances.attribute(j).name()); 117 } 118 119 double[] values = new double[instances.numAttributes()]; 120 int index = 0; 121 for (int j = 0; j < instance.numAttributes(); j++) { 122 if (attributeNames.contains(instance.attribute(j).name())) { 123 values[index] = instance.value(j); 124 index++; 125 } 126 } 127 128 Instances tmp = new Instances(instances); 129 tmp.clear(); 130 Instance instCopy = new DenseInstance(instance.weight(), values); 131 instCopy.setDataset(tmp); 132 133 return instCopy; 134 } 135 136 /** 137 * Because Fastmap saves only the image not the values of the attributes it used we can not 138 * use the old data 
directly to classify single instances to clusters. 139 * 140 * To classify a single instance we do a new fastmap computation with only the instance and 141 * the old pivot elements. 142 * 143 * After that we find the cluster with our fastmap result for x and y. 144 */ 145 @Override 146 public double classifyInstance(Instance instance) { 147 148 double ret = 0; 149 try { 150 // classinstance gets passed to classifier 151 Instances traindata = ctraindata.get(0); 152 Instance classInstance = createInstance(traindata, instance); 153 154 // this one keeps the class attribute 155 Instances traindata2 = ctraindata.get(1); 156 157 // remove class attribute before clustering 158 Remove filter = new Remove(); 159 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 160 filter.setInputFormat(traindata); 161 traindata = Filter.useFilter(traindata, filter); 162 Instance clusterInstance = createInstance(traindata, instance); 163 164 Fastmap FMAP = new Fastmap(2); 165 EuclideanDistance dist = new EuclideanDistance(traindata); 166 167 // we set our pivot indices [x=0,y=1][dimension] 168 int[][] npivotindices = new int[2][2]; 169 npivotindices[0][0] = 1; 170 npivotindices[1][0] = 2; 171 npivotindices[0][1] = 3; 172 npivotindices[1][1] = 4; 173 174 // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 175 // the instance we want to classify comes first after that the pivot elements in the 176 // order defined above 177 double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1]; 178 distmat[0][0] = 0; 179 distmat[0][1] = 180 dist.distance(clusterInstance, 181 this.cpivots.get((Integer) this.cpivotindices[0][0])); 182 distmat[0][2] = 183 dist.distance(clusterInstance, 184 this.cpivots.get((Integer) this.cpivotindices[1][0])); 185 distmat[0][3] = 186 dist.distance(clusterInstance, 187 this.cpivots.get((Integer) this.cpivotindices[0][1])); 188 distmat[0][4] = 189 dist.distance(clusterInstance, 190 this.cpivots.get((Integer) this.cpivotindices[1][1])); 191 192 distmat[1][0] = 193 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 194 clusterInstance); 195 distmat[1][1] = 0; 196 distmat[1][2] = 197 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 198 this.cpivots.get((Integer) this.cpivotindices[1][0])); 199 distmat[1][3] = 200 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 201 this.cpivots.get((Integer) this.cpivotindices[0][1])); 202 distmat[1][4] = 203 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 204 this.cpivots.get((Integer) this.cpivotindices[1][1])); 205 206 distmat[2][0] = 207 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 208 clusterInstance); 209 distmat[2][1] = 210 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 211 this.cpivots.get((Integer) this.cpivotindices[0][0])); 212 distmat[2][2] = 0; 213 distmat[2][3] = 214 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 215 this.cpivots.get((Integer) this.cpivotindices[0][1])); 216 distmat[2][4] = 217 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 218 this.cpivots.get((Integer) this.cpivotindices[1][1])); 219 220 distmat[3][0] = 221 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 222 clusterInstance); 223 distmat[3][1] = 224 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 225 this.cpivots.get((Integer) this.cpivotindices[0][0])); 226 distmat[3][2] = 227 dist.distance(this.cpivots.get((Integer) 
this.cpivotindices[0][1]), 228 this.cpivots.get((Integer) this.cpivotindices[1][0])); 229 distmat[3][3] = 0; 230 distmat[3][4] = 231 dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 232 this.cpivots.get((Integer) this.cpivotindices[1][1])); 233 234 distmat[4][0] = 235 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 236 clusterInstance); 237 distmat[4][1] = 238 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 239 this.cpivots.get((Integer) this.cpivotindices[0][0])); 240 distmat[4][2] = 241 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 242 this.cpivots.get((Integer) this.cpivotindices[1][0])); 243 distmat[4][3] = 244 dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 245 this.cpivots.get((Integer) this.cpivotindices[0][1])); 246 distmat[4][4] = 0; 247 248 /* 249 * debug output: show biggest distance found within the new distance matrix double 250 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j < 251 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j]; 252 * } } } if(this.show_biggest) { Console.traceln(Level.INFO, 253 * String.format(""+clusterInstance)); Console.traceln(Level.INFO, 254 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; } 255 */ 256 257 FMAP.setDistmat(distmat); 258 FMAP.setPivots(npivotindices); 259 FMAP.calculate(); 260 double[][] x = FMAP.getX(); 261 double[] proj = x[0]; 262 263 // debug output: show the calculated distance matrix, our result vektor for the 264 // instance and the complete result matrix 265 /* 266 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){ 267 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO, 268 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); } 269 * 270 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) { 271 * Console.trace(Level.INFO, String.format("%20s", proj[i])); } 272 * Console.traceln(Level.INFO, ""); 273 * 274 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int 275 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s", 276 * x[i][j])); } Console.traceln(Level.INFO, ""); } 277 */ 278 279 // now we iterate over all clusters (well, boxes of sizes per cluster really) and 280 // save the number of the 281 // cluster in which we are 282 int cnumber; 283 int found_cnumber = -1; 284 Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 285 while (clusternumber.hasNext() && found_cnumber == -1) { 286 cnumber = clusternumber.next(); 287 288 // now iterate over the boxes of the cluster and hope we find one (cluster could 289 // have been removed) 290 // or we are too far away from any cluster because of the fastmap calculation 291 // with the initial pivot objects 292 for (int box = 0; box < this.csize.get(cnumber).size(); box++) { 293 Double[][] current = this.csize.get(cnumber).get(box); 294 295 if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x 296 proj[1] >= current[1][0] && proj[1] <= current[1][1]) 297 { // y 298 found_cnumber = cnumber; 299 } 300 } 301 } 302 303 // we want to count how often we are really inside a cluster 304 // if ( found_cnumber == -1 ) { 305 // CNOTFOUND += 1; 306 // }else { 307 // CFOUND += 1; 308 // } 309 310 // now it can happen that we do not find a cluster because we deleted it previously 311 // (too few instances) 312 // or we get bigger distance measures from weka 
so that we are completely outside of 313 // our clusters. 314 // in these cases we just find the nearest cluster to our instance and use it for 315 // classification. 316 // to do that we use the EuclideanDistance again to compare our distance to all 317 // other Instances 318 // then we take the cluster of the closest weka instance 319 dist = new EuclideanDistance(traindata2); 320 if (!this.ctraindata.containsKey(found_cnumber)) { 321 double min_distance = Double.MAX_VALUE; 322 clusternumber = ctraindata.keySet().iterator(); 323 while (clusternumber.hasNext()) { 324 cnumber = clusternumber.next(); 325 for (int i = 0; i < ctraindata.get(cnumber).size(); i++) { 326 if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) 327 { 328 found_cnumber = cnumber; 329 min_distance = 330 dist.distance(instance, ctraindata.get(cnumber).get(i)); 331 } 332 } 333 } 334 } 335 336 // here we have the cluster where an instance has the minimum distance between 337 // itself and the 338 // instance we want to classify 339 // if we still have not found a cluster we exit because something is really wrong 340 if (found_cnumber == -1) { 341 Console.traceln(Level.INFO, String 342 .format("ERROR matching instance to cluster with full search!")); 343 throw new RuntimeException("cluster not found with full search"); 344 } 345 346 // classify the passed instance with the cluster we found and its training data 347 ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 348 349 } 350 catch (Exception e) { 351 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 352 throw new RuntimeException(e); 353 } 354 return ret; 355 } 356 357 @Override 358 public void buildClassifier(Instances traindata) throws Exception { 359 360 // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + 361 // CNOTFOUND)); 362 this.show_biggest = true; 363 364 cclassifier = new HashMap<Integer, Classifier>(); 365 ctraindata = new HashMap<Integer, Instances>(); 366 cpivots = new HashMap<Integer, Instance>(); 367 cpivotindices = new int[2][2]; 368 369 // 1. copy traindata 370 Instances train = new Instances(traindata); 371 Instances train2 = new Instances(traindata); // this one keeps the class attribute 372 373 // 2. remove class attribute for clustering 374 Remove filter = new Remove(); 375 filter.setAttributeIndices("" + (train.classIndex() + 1)); 376 filter.setInputFormat(train); 377 train = Filter.useFilter(train, filter); 378 379 // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 380 double biggest = 0; 381 EuclideanDistance dist = new EuclideanDistance(train); 382 double[][] distmat = new double[train.size()][train.size()]; 383 for (int i = 0; i < train.size(); i++) { 384 for (int j = 0; j < train.size(); j++) { 385 distmat[i][j] = dist.distance(train.get(i), train.get(j)); 386 if (distmat[i][j] > biggest) { 387 biggest = distmat[i][j]; 388 } 389 } 390 } 391 // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 392 393 // 4. 
run fastmap for 2 dimensions on the distance matrix 394 Fastmap FMAP = new Fastmap(2); 395 FMAP.setDistmat(distmat); 396 FMAP.calculate(); 397 398 cpivotindices = FMAP.getPivots(); 399 400 double[][] X = FMAP.getX(); 401 distmat = new double[0][0]; 402 System.gc(); 403 404 // quadtree payload generation 405 ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 406 407 // we need these for the sizes of the quadrants 408 double[] big = 409 { 0, 0 }; 410 double[] small = 411 { Double.MAX_VALUE, Double.MAX_VALUE }; 412 413 // set quadtree payload values and get max and min x and y values for size 414 for (int i = 0; i < X.length; i++) { 415 if (X[i][0] >= big[0]) { 416 big[0] = X[i][0]; 417 } 418 if (X[i][1] >= big[1]) { 419 big[1] = X[i][1]; 420 } 421 if (X[i][0] <= small[0]) { 422 small[0] = X[i][0]; 423 } 424 if (X[i][1] <= small[1]) { 425 small[1] = X[i][1]; 426 } 427 QuadTreePayload<Instance> tmp = 428 new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 429 qtp.add(tmp); 430 } 431 432 // Console.traceln(Level.INFO, 433 // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 434 435 // 5. generate quadtree 436 QuadTree TREE = new QuadTree(null, qtp); 437 QuadTree.size = train.size(); 438 QuadTree.alpha = Math.sqrt(train.size()); 439 QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 440 QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 441 442 // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + 443 // " size, Alpha: "+ QuadTree.alpha+ "")); 444 445 // set the size and then split the tree recursively at the median value for x, y 446 TREE.setSize(new double[] 447 { small[0], big[0] }, new double[] 448 { small[1], big[1] }); 449 450 // recursive split und grid clustering eher static 451 TREE.recursiveSplit(TREE); 452 453 // generate list of nodes sorted by density (childs only) 454 ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 455 456 // recursive grid clustering (tree pruning), the values are stored in ccluster 457 TREE.gridClustering(l); 458 459 // wir iterieren durch die cluster und sammeln uns die instanzen daraus 460 // ctraindata.clear(); 461 for (int i = 0; i < QuadTree.ccluster.size(); i++) { 462 ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 463 464 // i is the clusternumber 465 // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
466 // if(current.size() > QuadTree.alpha) { 467 if (current.size() > 4) { 468 for (int j = 0; j < current.size(); j++) { 469 if (!ctraindata.containsKey(i)) { 470 ctraindata.put(i, new Instances(train2)); 471 ctraindata.get(i).delete(); 472 } 473 ctraindata.get(i).add(current.get(j).getInst()); 474 } 475 } 476 else { 477 Console.traceln(Level.INFO, 478 String.format("drop cluster, only: " + current.size() + 479 " instances")); 480 } 481 } 482 483 // here we keep things we need later on 484 // QuadTree sizes for later use (matching new instances) 485 this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 486 487 // pivot elements 488 // this.cpivots.clear(); 489 for (int i = 0; i < FMAP.PA[0].length; i++) { 490 this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy()); 491 } 492 for (int j = 0; j < FMAP.PA[0].length; j++) { 493 this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy()); 494 } 495 496 /* 497 * debug output int pnumber; Iterator<Integer> pivotnumber = 498 * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber = 499 * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ 500 * " inst: "+cpivots.get(pnumber))); } 501 */ 502 503 // train one classifier per cluster, we get the cluster number from the traindata 504 int cnumber; 505 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 506 // cclassifier.clear(); 507 508 // int traindata_count = 0; 509 while (clusternumber.hasNext()) { 510 cnumber = clusternumber.next(); 511 cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the 512 // cluster 513 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 514 // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 515 // traindata_count += ctraindata.get(cnumber).size(); 516 // Console.traceln(Level.INFO, 517 // String.format("building classifier in cluster "+cnumber +" with "+ 518 // ctraindata.get(cnumber).size() +" traindata instances")); 519 } 520 521 // add all traindata 522 // Console.traceln(Level.INFO, String.format("traindata in all clusters: " + 523 // traindata_count)); 524 } 525 } 526 527 /** 528 * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance. 529 */ 530 public class QuadTreePayload<T> { 531 532 public double x; 533 public double y; 534 private T inst; 535 536 public QuadTreePayload(double x, double y, T value) { 537 this.x = x; 538 this.y = y; 539 this.inst = value; 540 } 541 542 public T getInst() { 543 return this.inst; 544 } 545 } 546 547 /** 548 * Fastmap implementation 549 * 550 * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and 551 * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM. 
552 */ 553 public class Fastmap { 554 555 /* N x k Array, at the end, the i-th row will be the image of the i-th object */ 556 private double[][] X; 557 558 /* 2 x k pivot Array one pair per recursive call */ 559 private int[][] PA; 560 561 /* Objects we got (distance matrix) */ 562 private double[][] O; 563 564 /* column of X currently updated (also the dimension) */ 565 private int col = 0; 566 567 /* number of dimensions we want */ 568 private int target_dims = 0; 569 570 // if we already have the pivot elements 571 private boolean pivot_set = false; 572 573 public Fastmap(int k) { 574 this.target_dims = k; 575 } 576 577 /** 578 * Sets the distance matrix and params that depend on this 579 * 580 * @param O 581 */ 582 public void setDistmat(double[][] O) { 583 this.O = O; 584 int N = O.length; 585 this.X = new double[N][this.target_dims]; 586 this.PA = new int[2][this.target_dims]; 587 } 588 589 /** 590 * Set pivot elements, we need that to classify instances after the calculation is complete 591 * (because we then want to reuse only the pivot elements). 592 * 593 * @param pi 594 */ 595 public void setPivots(int[][] pi) { 596 this.pivot_set = true; 597 this.PA = pi; 598 } 599 600 /** 601 * Return the pivot elements that were chosen during the calculation 602 * 603 * @return 604 */ 605 public int[][] getPivots() { 606 return this.PA; 607 } 608 609 /** 610 * The distance function for euclidean distance 611 * 612 * Acts according to equation 4 of the fastmap paper 613 * 614 * @param x 615 * x index of x image (if k==0 x object) 616 * @param y 617 * y index of y image (if k==0 y object) 618 * @param kdimensionality 619 * @return distance 620 */ 621 private double dist(int x, int y, int k) { 622 623 // basis is object distance, we get this from our distance matrix 624 double tmp = this.O[x][y] * this.O[x][y]; 625 626 // decrease by projections 627 for (int i = 0; i < k; i++) { 628 double tmp2 = (this.X[x][i] - this.X[y][i]); 629 tmp -= tmp2 * tmp2; 630 } 631 632 return Math.abs(tmp); 633 } 634 635 /** 636 * Find the object farthest from the given index This method is a helper Method for 637 * findDistandObjects 638 * 639 * @param index 640 * of the object 641 * @return index of the farthest object from the given index 642 */ 643 private int findFarthest(int index) { 644 double furthest = Double.MIN_VALUE; 645 int ret = 0; 646 647 for (int i = 0; i < O.length; i++) { 648 double dist = this.dist(i, index, this.col); 649 if (i != index && dist > furthest) { 650 furthest = dist; 651 ret = i; 652 } 653 } 654 return ret; 655 } 656 657 /** 658 * Finds the pivot objects 659 * 660 * This method is basically algorithm 1 of the fastmap paper. 661 * 662 * @return 2 indexes of the choosen pivot objects 663 */ 664 private int[] findDistantObjects() { 665 // 1. choose object randomly 666 Random r = new Random(); 667 int obj = r.nextInt(this.O.length); 668 669 // 2. find farthest object from randomly chosen object 670 int idx1 = this.findFarthest(obj); 671 672 // 3. find farthest object from previously farthest object 673 int idx2 = this.findFarthest(idx1); 674 675 return new int[] 676 { idx1, idx2 }; 677 } 678 679 /** 680 * Calculates the new k-vector values (projections) 681 * 682 * This is basically algorithm 2 of the fastmap paper. We just added the possibility to 683 * pre-set the pivot elements because we need to classify single instances after the 684 * computation is already done. 
685 * 686 * @param dims 687 * dimensionality 688 */ 689 public void calculate() { 690 691 for (int k = 0; k < this.target_dims; k++) { 692 // 2) choose pivot objects 693 if (!this.pivot_set) { 694 int[] pivots = this.findDistantObjects(); 695 696 // 3) record ids of pivot objects 697 this.PA[0][this.col] = pivots[0]; 698 this.PA[1][this.col] = pivots[1]; 699 } 700 701 // 4) inter object distances are zero (this.X is initialized with 0 so we just 702 // continue) 703 if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) { 704 continue; 705 } 706 707 // 5) project the objects on the line between the pivots 708 double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 709 for (int i = 0; i < this.O.length; i++) { 710 711 double dix = this.dist(i, this.PA[0][this.col], this.col); 712 double diy = this.dist(i, this.PA[1][this.col], this.col); 713 714 double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 715 716 // save the projection 717 this.X[i][this.col] = tmp; 718 } 719 720 this.col += 1; 721 } 722 } 723 724 /** 725 * returns the result matrix of the projections 726 * 727 * @return calculated result 728 */ 729 public double[][] getX() { 730 return this.X; 731 } 732 } 669 733 } -
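Note: since the Fastmap projection is the piece that is reused when single instances are matched to clusters, a tiny worked example may help. The sketch below is illustrative only and is not the Fastmap class above: it computes one projection dimension from a matrix of squared pairwise distances and two given pivot indices, following the FastMap projection formula of Faloutsos & Lin (1995); pivot selection and the residual-distance recursion for further dimensions are omitted.

// Minimal sketch of one FastMap projection step (illustrative only).
public final class FastmapStep {

    /**
     * Projects every object onto the line spanned by the two pivot objects a and b.
     *
     * @param d squared pairwise distance matrix of the objects
     * @param a index of the first pivot
     * @param b index of the second pivot
     * @return one coordinate per object (a single Fastmap dimension)
     */
    static double[] project(double[][] d, int a, int b) {
        double dab = d[a][b]; // squared pivot distance
        double[] x = new double[d.length];
        if (dab == 0) {
            return x; // all objects coincide on this axis
        }
        for (int i = 0; i < d.length; i++) {
            // projection on squared distances: (d(a,i)^2 + d(a,b)^2 - d(b,i)^2) / (2 * d(a,b))
            x[i] = (d[a][i] + dab - d[b][i]) / (2 * Math.sqrt(dab));
        }
        return x;
    }

    public static void main(String[] args) {
        // three points on a line at positions 0, 1, 2; entries are squared distances
        double[][] d = {
            { 0, 1, 4 },
            { 1, 0, 1 },
            { 4, 1, 0 }
        };
        double[] proj = project(d, 0, 2);
        for (double v : proj) {
            System.out.println(v); // expected: 0.0, 1.0, 2.0
        }
    }
}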
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaTraining.java
r25 r41 1 // Copyright 2015 Georg-August-Universität Göttingen, Germany 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 1 15 package de.ugoe.cs.cpdp.training; 2 16 … … 11 25 /** 12 26 * Programmatic WekaTraining 13 *14 * first parameter is Trainer Name.15 * second parameter is class name16 27 * 17 * all subsequent parameters are configuration params (for example for trees) 18 * Cross Validation params always come last and are prepended with -CVPARAM 28 * first parameter is Trainer Name. second parameter is class name 29 * 30 * all subsequent parameters are configuration params (for example for trees) Cross Validation 31 * params always come last and are prepended with -CVPARAM 19 32 * 20 33 * XML Configurations for Weka Classifiers: 34 * 21 35 * <pre> 22 36 * {@code … … 30 44 public class WekaTraining extends WekaBaseTraining implements ITrainingStrategy { 31 45 32 @Override 33 public void apply(Instances traindata) { 34 PrintStream errStr = System.err; 35 System.setErr(new PrintStream(new NullOutputStream())); 36 try { 37 if(classifier == null) { 38 Console.traceln(Level.WARNING, String.format("classifier null!")); 39 } 40 classifier.buildClassifier(traindata); 41 } catch (Exception e) { 42 throw new RuntimeException(e); 43 } finally { 44 System.setErr(errStr); 45 } 46 } 46 @Override 47 public void apply(Instances traindata) { 48 PrintStream errStr = System.err; 49 System.setErr(new PrintStream(new NullOutputStream())); 50 try { 51 if (classifier == null) { 52 Console.traceln(Level.WARNING, String.format("classifier null!")); 53 } 54 classifier.buildClassifier(traindata); 55 } 56 catch (Exception e) { 57 throw new RuntimeException(e); 58 } 59 finally { 60 System.setErr(errStr); 61 } 62 } 47 63 }
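Note: the programmatic setup described in the WekaTraining Javadoc boils down to instantiating a Weka classifier from a class name plus an option string and then training it. The sketch below illustrates that pattern under stated assumptions: the parameter string, the RandomForest choice, and the training.arff file are hypothetical, and the real parameter parsing (including the -CVPARAM handling) is done by WekaBaseTraining, which is not shown here.

// Minimal sketch of instantiating and training a Weka classifier from a parameter string.
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;

public class WekaTrainingSketch {

    public static void main(String[] args) throws Exception {
        // hypothetical parameter string: "<name> <class name> <options>"
        String param = "RandomForest weka.classifiers.trees.RandomForest -I 100";
        String[] tokens = param.split(" ", 3);
        String className = tokens[1];
        String[] options = tokens.length > 2 ? Utils.splitOptions(tokens[2]) : new String[0];

        // instantiate the classifier by reflection, as Weka does internally
        Classifier classifier = AbstractClassifier.forName(className, options);

        // hypothetical ARFF file; the last attribute is assumed to be the class
        Instances traindata = DataSource.read("training.arff");
        traindata.setClassIndex(traindata.numAttributes() - 1);
        classifier.buildClassifier(traindata);

        System.out.println(classifier);
    }
}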