Changeset 12
- Timestamp:
- 08/25/14 13:27:30 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalTraining2.java
r9 r12 17 17 import weka.classifiers.AbstractClassifier; 18 18 import weka.classifiers.Classifier; 19 import weka.clusterers.EM;20 19 import weka.core.DenseInstance; 21 20 import weka.core.EuclideanDistance; … … 28 27 * ACHTUNG UNFERTIG 29 28 * 29 * 30 * Basically a copy of WekaClusterTraining2 with internal classes for the Fastmap and QuadTree implementations 30 31 */ 31 32 public class WekaLocalTraining2 extends WekaBaseTraining2 implements ITrainingStrategy { 32 33 33 34 private final TraindatasetCluster classifier = new TraindatasetCluster(); 34 private final QuadTree q = null; 35 private final Fastmap f = null; 36 37 /*Stopping rule for tree reqursion (Math.sqrt(Instances)*/ 38 public static double ALPHA = 3; 35 36 // we do not need to keep them around 37 //private final QuadTree q = null; 38 //private final Fastmap f = null; 39 40 // these values are set later when we have all the information we need 41 /*Stopping rule for tree recursion (Math.sqrt(Instances)*/ 42 public static double ALPHA = 0; 39 43 /*Stopping rule for clustering*/ 40 44 public static double DELTA = 0.5; 41 /*size of the complete set (used for density )*/45 /*size of the complete set (used for density function)*/ 42 46 public static int SIZE = 0; 47 48 public static int MIN_INST = 10; 43 49 44 50 // cluster … … 69 75 private static final long serialVersionUID = 1L; 70 76 71 private EM clusterer = null;72 73 77 private HashMap<Integer, Classifier> cclassifier = new HashMap<Integer, Classifier>(); 74 78 private HashMap<Integer, Instances> ctraindata = new HashMap<Integer, Instances>(); 75 76 79 77 80 … … 100 103 } 101 104 102 105 /** 106 * Because Fastmap saves only the image not the values of the attributes 107 * we can not use it to classify single instances to values 108 * 109 * TODO: mehr erklärung 110 * TODO: class lavel filter raus 111 * 112 * Finde die am nächsten liegende Instanz zur übergebenen 113 * dann bestimme den cluster der instanz und führe dann den 114 * classifier des clusters aus 115 */ 103 116 @Override 104 117 public double classifyInstance(Instance instance) { 118 105 119 double ret = 0; 106 120 try { … … 116 130 Instance clusterInstance = createInstance(traindata, instance); 117 131 118 // 1. classify testdata instance to a cluster number 119 int cnum = clusterer.clusterInstance(clusterInstance); 120 121 // 2. classify testata instance to the classifier 122 ret = cclassifier.get(cnum).classifyInstance(classInstance); 132 // get distance of this instance to every other instance 133 // if the distance is minimal apply the classifier of the current cluster 134 int cnumber; 135 int min_cluster = -1; 136 double min_distance = 99999999; 137 EuclideanDistance d; 138 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 139 while ( clusternumber.hasNext() ) { 140 cnumber = clusternumber.next(); 141 142 d = new EuclideanDistance(ctraindata.get(cnumber)); 143 for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 144 if(d.distance(clusterInstance, ctraindata.get(cnumber).get(i)) <= min_distance) { 145 min_cluster = cnumber; 146 min_distance = d.distance(clusterInstance, ctraindata.get(cnumber).get(i)); 147 } 148 } 149 } 150 151 // here we have the cluster where an instance has the minimum distance between itself the 152 // instance we want to classify 153 if(min_cluster == -1) { 154 // this is an error condition 155 throw new RuntimeException("min_cluster not found"); 156 } 157 158 // classify the passed instance with the cluster we found 159 ret = cclassifier.get(min_cluster).classifyInstance(classInstance); 123 160 124 161 }catch( Exception e ) { … … 128 165 return ret; 129 166 } 130 131 132 167 133 168 @Override … … 138 173 139 174 // 2. remove class attribute for clustering 140 Remove filter = new Remove();141 filter.setAttributeIndices("" + (train.classIndex() + 1));142 filter.setInputFormat(train);143 train = Filter.useFilter(train, filter);144 145 146 // 3. calculate distance matrix 175 //Remove filter = new Remove(); 176 //filter.setAttributeIndices("" + (train.classIndex() + 1)); 177 //filter.setInputFormat(train); 178 //train = Filter.useFilter(train, filter); 179 180 181 // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 147 182 EuclideanDistance d = new EuclideanDistance(train); 148 183 double[][] dist = new double[train.size()][train.size()]; … … 153 188 } 154 189 155 // 4. run fastmap for 2 dimensions on distance matrix190 // 4. run fastmap for 2 dimensions on the distance matrix 156 191 Fastmap f = new Fastmap(2, dist); 157 192 f.calculate(2); … … 163 198 // die max und min brauchen wir für die größenangaben der sektoren 164 199 double[] big = {0,0}; 165 double[] small = { -99999,-99999};166 167 // set quadtree payload values 200 double[] small = {9999999,99999999}; 201 202 // set quadtree payload values and get max and min x and y values for size 168 203 for(int i=0; i<X.length; i++){ 169 204 if(X[i][0] >= big[0]) { … … 188 223 SIZE = train.size(); 189 224 190 // split recursively 225 Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ SIZE + " size, Alpha: "+ ALPHA+ "")); 226 227 // set the size and then split the tree recursively at the median value for x, y 191 228 q.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 192 229 q.recursiveSplit(q); 193 230 194 // generate list of nodes sorted by density 231 // generate list of nodes sorted by density (childs only) 195 232 ArrayList<QuadTree> l = new ArrayList<QuadTree>(q.getList(q)); 196 233 197 // grid clustering recursive (tree pruning)234 // recursive grid clustering (tree pruning), the values are stored in cluster! 198 235 q.gridClustering(l); 199 236 237 // after grid clustering we need to remove the clusters with < 2 * ALPHA instances 200 238 201 239 // hier müssten wir sowas haben wie welche instanz in welchem cluster ist 202 // oder wir iterieren durch die cluster und sammeln uns die insta znen daraus240 // oder wir iterieren durch die cluster und sammeln uns die instanzen daraus 203 241 for(int i=0; i < cluster.size(); i++) { 204 242 ArrayList<QuadTreePayload<Instance>> current = cluster.get(i); 205 for(int j=0; j < current.size(); j++ ) { 206 243 244 // i is the clusternumber 245 // we only allow clusters with Instances > ALPHA 246 if(current.size() > ALPHA) { 247 for(int j=0; j < current.size(); j++ ) { 248 if(!ctraindata.containsKey(i)) { 249 ctraindata.put(i, new Instances(traindata)); 250 ctraindata.get(i).delete(); 251 } 252 ctraindata.get(i).add(current.get(j).getInst()); 253 } 207 254 } 208 255 } 209 210 Instances ctrain = new Instances(train); 211 212 // get traindata per cluster 213 int cnumber; 214 for ( int j=0; j < ctrain.numInstances(); j++ ) { 215 // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 216 //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 217 218 cnumber = clusterer.clusterInstance(ctrain.get(j)); 219 // add training data to list of instances for this cluster number 220 if ( !ctraindata.containsKey(cnumber) ) { 221 ctraindata.put(cnumber, new Instances(traindata)); 222 ctraindata.get(cnumber).delete(); 223 } 224 ctraindata.get(cnumber).add(traindata.get(j)); 225 } 226 227 // train one classifier per cluster, we get the clusternumber from the traindata 256 257 // train one classifier per cluster, we get the clusternumber from the traindata 258 int cnumber; 228 259 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 229 260 while ( clusternumber.hasNext() ) { … … 231 262 cclassifier.put(cnumber,setupClassifier()); 232 263 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 233 234 264 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 265 //Console.traceln(Level.INFO, String.format("" + ctraindata.get(cnumber).size() + " instances in cluster "+cnumber)); 235 266 } 236 267 } … … 254 285 } 255 286 256 public T get inst() {287 public T getInst() { 257 288 return this.inst; 258 289 } … … 261 292 /** 262 293 * Fastmap implementation 294 * 295 * TODO: only one place to pass dimension! 263 296 * 264 297 * Faloutsos, C., & Lin, K. I. (1995). … … 289 322 } 290 323 291 292 /*recursive function ALT*/293 private double dist2(int x, int y, int k) {294 // basisfall295 if(k == 0) {296 return Math.pow(this.O[x][y], 2);297 }298 299 double dist_rec = Math.pow(this.dist(x, y, k-1), 2);300 double dist_norm = Math.pow(Math.abs(this.X[x][k] - this.X[y][k]), 2);301 302 return Math.sqrt(Math.abs(dist_rec - dist_norm));303 //return Math.abs(dist_rec - dist_norm);304 }305 306 324 /** 307 325 * The distance function for eculidean distance 308 326 * 309 * Acts according to equation 4 of the fastmap paper327 * Acts according to equation 4 of the fastmap paper 310 328 * 311 329 * @param x x index of x image (if k==0 x object) … … 316 334 private double dist(int x, int y, int k) { 317 335 318 // objectabstand ist basis 319 320 // das hier wäre ein abstand zwischen 2 weka instanzen, z.B. euclidischer abstand zwischen den beiden vektoren 336 // basis is object distance, we get this from our distance matrix 337 // alternatively we could provide a distance function that takes 2 vectors 321 338 double tmp = this.O[x][y] * this.O[x][y]; 322 323 339 324 340 // decrease by projections … … 333 349 334 350 /** 335 * Find the object f urthest from the given index351 * Find the object farthest from the given index 336 352 * This method is a helper Method for findDistandObjects 337 353 * 338 354 * @param index of the object 339 * @return index of the f urthest object from the given index340 */ 341 private int findF urthest(int index) {355 * @return index of the farthest object from the given index 356 */ 357 private int findFarthest(int index) { 342 358 double furthest = -1000000; 343 359 int ret = 0; … … 353 369 } 354 370 355 356 371 /** 357 372 * Finds the pivot objects … … 366 381 int obj = r.nextInt(this.O.length); 367 382 368 // 2. find furthest object from randomly chooen object 369 int idx1 = this.findFurthest(obj); 370 371 // 3. find furthest object from previously furthest object 372 int idx2 = this.findFurthest(idx1); 373 374 int[] ret = {idx1, idx2}; 375 return ret; 376 } 377 378 379 /** 380 * Gives image of object (projection on the line between px, py) 381 * 382 * @param index of the object to project 383 * @param px pivot 1 384 * @param py pivot 2 385 * @return projection 386 */ 387 private double project(int index, int px, int py) { 388 389 double dix = this.dist(index, px, this.col); 390 double diy = this.dist(index, py, this.col); 391 double dxy = this.dist(px, py, this.col); 392 393 return (dix + dxy - diy) / 2 * Math.sqrt(dxy); 394 } 395 396 397 /*recursive function ALT, geht auch sequentiell*/ 398 public void calculate2(int k) { 399 400 // 1) basisfall 401 if(k <= 0) { 402 return; 403 } 404 405 // 2) choose pivot objects 406 int[] pivots = this.findDistantObjects(); 407 408 // 3) record ids of pivot objects 409 this.PA[0][this.col] = pivots[0]; 410 this.PA[1][this.col] = pivots[1]; 411 412 System.out.println("found pivots with index: " + pivots[0] + ","+ pivots[1]); 413 414 // 4) inter object distances are zero (this.X is initialized with 0 so we just return) 415 if(this.dist(pivots[0], pivots[1], this.col) == 0) { 416 return; 417 } 418 419 double dxy = this.dist(pivots[0], pivots[1], this.col); 420 421 if(dxy == 0) { 422 return; 423 } 424 425 // 5) project the objects on the line between the pivots 426 for(int i=0; i < this.O.length; i++) { 427 428 double dix = this.dist(i, pivots[0], this.col); 429 double diy = this.dist(i, pivots[1], this.col); 430 431 this.X[i][this.col] = (dix + dxy - diy) / 2 * Math.sqrt(dxy); 432 433 //this.X[i][this.col] = this.project(i, pivots[0], pivots[1]); 434 } 435 436 this.col += 1; 437 438 // 6) recurse 439 this.calculate2(k-1); 440 } 441 442 443 // test funktion, reproduziert ergebnisse aus dem technical report von fastmap 444 public int[] findDistantObjects2() { 445 int[] ret = {0,0}; 446 if(this.col == 0) { 447 ret = new int[] {0,3}; 448 } 449 if(this.col == 1) { 450 ret = new int[] {4,1}; 451 } 452 if(this.col == 2) { 453 ret = new int[] {2,4}; 454 } 455 456 return ret; 457 } 458 459 383 // 2. find farthest object from randomly chosen object 384 int idx1 = this.findFarthest(obj); 385 386 // 3. find farthest object from previously farthest object 387 int idx2 = this.findFarthest(idx1); 388 389 return new int[] {idx1, idx2}; 390 } 391 460 392 /** 461 393 * Calculates the new k-vector values … … 488 420 double tmp = (dix + dxy - diy) / 2 * Math.sqrt(dxy); 489 421 490 this.X[i][this.col] = tmp; // / 10000; 491 //this.X[i][this.col] = this.project(i, pivots[0], pivots[1]); 422 this.X[i][this.col] = tmp; 492 423 } 493 424 … … 495 426 } 496 427 } 497 498 428 499 429 /** … … 506 436 } 507 437 508 438 /** 439 * QuadTree implementation 440 * 441 * QuadTree gets a list of instances and then recursively split them into 4 childs 442 * For this it uses the median of the 2 values x,y 443 */ 509 444 public class QuadTree { 510 445 … … 521 456 private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 522 457 523 458 // level only used for debugging 524 459 public int level = 0; 525 460 … … 533 468 // evtl. statt ArrayList eigene QuadTreePayloadlist 534 469 private ArrayList<QuadTreePayload<Instance>> payload; 535 536 470 537 471 public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { … … 551 485 } 552 486 553 554 487 /** 555 488 * Returns the payload, used for clustering … … 567 500 * density = number of instances / global size (all instances) 568 501 * 569 * @return 502 * @return density 570 503 */ 571 504 public double getDensity() { … … 586 519 587 520 /** 588 * Todo: dry, median ist immer dasselbe521 * Todo: DRY, median ist immer dasselbe 589 522 * 590 523 * @return median for x … … 663 596 664 597 /** 665 * Calculate median values of payload for x, y and split into sectors598 * Calculate median values of payload for x, y and split into 4 sectors 666 599 * 667 600 * @return Array of QuadTree nodes (4 childs) … … 673 606 double medy = this.getMedianForY(); 674 607 675 608 // Payload lists for each child 676 609 ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 677 610 ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); … … 710 643 // if we assign one child a payload equal to our own (see problem above) 711 644 // we throw an exceptions which stops the recursion on this node 645 // second error is minimum number of instances 646 //Console.traceln(Level.INFO, String.format("NW: "+ nw.size() + " SW: " + sw.size() + " NE: " + ne.size() + " SE: " + se.size())); 712 647 if(nw.equals(this.payload)) { 713 648 throw new Exception("payload equal"); 714 649 } 715 650 if(sw.equals(this.payload)) { 716 throw new Exception(" ayload equal");651 throw new Exception("payload equal"); 717 652 } 718 653 if(ne.equals(this.payload)) { … … 737 672 this.child_se = new QuadTree(this, se); 738 673 this.child_se.setSize(new double[] {medx, this.x[1]}, new double[] {this.y[0], medy}); 739 this.child_se.level = this.level + 1; 674 this.child_se.level = this.level + 1; 740 675 741 676 this.payload = null; … … 844 779 } 845 780 846 847 848 781 /** 849 782 * Perform Pruning and clustering of the quadtree … … 851 784 * 1) get list of leaf quadrants 852 785 * 2) sort by their density 853 * 3) merge similar densities to new leaf quadrant 854 * @param q QuadTree 786 * 3) set stop_rule to 0.5 * highest Density in the list 787 * 4) merge all nodes with a density > stop_rule to the new cluster and remove all from list 788 * 5) repeat 789 * 790 * @param q List of QuadTree (children only) 855 791 */ 856 792 public void gridClustering(ArrayList<QuadTree> list) { … … 885 821 //System.out.println("removing "+biggest.getDensity() + " from list"); 886 822 887 // while items in list823 // check the items for their density 888 824 for(int i=list.size()-1; i >= 0; i--) { 889 825 current = list.get(i); 890 826 891 827 // 2. find neighbours with correct density 892 // if density > 0.5 * DELTA and is_neighbour add to cluster828 // if density > stop_rule and is_neighbour add to cluster and remove from list 893 829 if(current.getDensity() > stop_rule && !current.equals(biggest) && current.isNeighbour(biggest)) { 894 830 //System.out.println("adding " + current.getDensity() + " to cluster"); … … 923 859 924 860 /** 925 * 926 * @param q 927 * @return 861 * Helper Method to get a sortet list (by density) for all 862 * children 863 * 864 * @param q QuadTree 865 * @return Sorted ArrayList of quadtrees 928 866 */ 929 867 public ArrayList<QuadTree> getList(QuadTree q) {
Note: See TracChangeset
for help on using the changeset viewer.