- Timestamp:
- 09/05/14 15:52:48 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalTraining2.java
r16 r17 3 3 import java.io.PrintStream; 4 4 import java.util.ArrayList; 5 import java.util.Collections;6 import java.util.Comparator;7 5 import java.util.HashMap; 8 6 import java.util.HashSet; … … 14 12 import org.apache.commons.io.output.NullOutputStream; 15 13 14 import de.ugoe.cs.cpdp.training.QuadTree; 16 15 import de.ugoe.cs.util.console.Console; 17 16 import weka.classifiers.AbstractClassifier; … … 25 24 26 25 /** 27 * ACHTUNG UNFERTIG 26 * Trainer with reimplementation of WHERE clustering algorithm from: 27 * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman, 28 * Forrest Shull, Burak Turhan, Thomas Zimmermann, 29 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," 30 * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 28 31 * 29 32 * With WekaLocalTraining2 we do the following: … … 33 36 * 3) We cluster the QuadTree nodes together if they have similar density (50%) 34 37 * 4) We save the clusters and their training data 35 * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded 38 * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster 36 39 * 6) We train a Weka classifier for each cluster with the clusters training data 37 * 7) We recalculate Fastmap distances for a single instance and then try to find a cluster containing the coords of the instance.40 * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance. 38 41 * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster. 39 * 8) We classif iy the Instance with the classifier and traindata from the Cluster we found in 7.42 * 8) We classify the Instance with the classifier and traindata from the Cluster we found in 7. 40 43 */ 41 44 public class WekaLocalTraining2 extends WekaBaseTraining2 implements ITrainingStrategy { 42 45 43 46 private final TraindatasetCluster classifier = new TraindatasetCluster(); 44 45 // these values are set later when we have all the information we need (size)46 /*Stopping rule for tree recursion (Math.sqrt(Instances)*/47 public static double ALPHA = 0;48 /*size of the complete set (used for density function)*/49 public static int SIZE = 0;50 /*Stopping rule for clustering*/51 public static double DELTA = 0.5;52 53 // we need these references later in the testing54 private static QuadTree TREE;55 private static Fastmap FMAP;56 private static EuclideanDistance DIST;57 private static Instances TRAIN;58 59 // cluster payloads60 private static ArrayList<ArrayList<QuadTreePayload<Instance>>> cluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();61 62 // cluster sizes (index is cluster number, arraylist is list of boxes (x0,y0,x1,y1)63 private static HashMap<Integer, ArrayList<Double[][]>> CSIZE = new HashMap<Integer, ArrayList<Double[][]>>();64 47 65 48 @Override … … 67 50 return classifier; 68 51 } 69 70 52 71 53 @Override … … 86 68 87 69 private static final long serialVersionUID = 1L; 88 70 71 /* classifier per cluster */ 89 72 private HashMap<Integer, Classifier> cclassifier = new HashMap<Integer, Classifier>(); 73 74 /* instances per cluster */ 90 75 private HashMap<Integer, Instances> ctraindata = new HashMap<Integer, Instances>(); 76 77 /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/ 78 private HashMap<Integer, Instance> cpivots = new HashMap<Integer, Instance>(); 79 80 /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/ 81 private int[][] cpivotindices = new int[2][2]; 82 83 /* holds the sizes of the cluster multiple "boxes" per cluster */ 84 private HashMap<Integer, ArrayList<Double[][]>> csize; 85 86 private boolean show_biggest = true; 87 88 private int CFOUND = 0; 89 private int CNOTFOUND = 0; 91 90 92 91 … … 117 116 /** 118 117 * Because Fastmap saves only the image not the values of the attributes it used 119 * we can not use it or the QuadTree to classify single instances to clusters. 120 * 121 * To classify a single instance we measure the distance to all instances we have clustered and 122 * use the cluster where the distance is minimal. 123 * 124 * TODO: class attribute filter raus 125 * TODO: werden auf die übergebene Instance ebenfalls die preprocessors angewendet? müsste eigentlich 118 * we can not use the old data directly to classify single instances to clusters. 119 * 120 * To classify a single instance we do a new fastmap computation with only the instance and 121 * the old pivot elements. 122 * 123 * After that we find the cluster with our fastmap result for x and y. 126 124 */ 127 125 @Override … … 130 128 double ret = 0; 131 129 try { 130 // classinstance gets passed to classifier 132 131 Instances traindata = ctraindata.get(0); 133 132 Instance classInstance = createInstance(traindata, instance); 133 134 // this one keeps the class attribute 135 Instances traindata2 = ctraindata.get(1); 134 136 135 137 // remove class attribute before clustering … … 138 140 filter.setInputFormat(traindata); 139 141 traindata = Filter.useFilter(traindata, filter); 140 141 142 Instance clusterInstance = createInstance(traindata, instance); 142 143 143 // build temp dist matrix (2 Pivot per dimension + 1 instance we want to classify) 144 Fastmap FMAP = new Fastmap(2); 145 EuclideanDistance dist = new EuclideanDistance(traindata); 146 147 148 // we set our pivot indices [x=0,y=1][dimension] 149 int[][] npivotindices = new int[2][2]; 150 npivotindices[0][0] = 1; 151 npivotindices[1][0] = 2; 152 npivotindices[0][1] = 3; 153 npivotindices[1][1] = 4; 154 155 // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 156 // the instance we want to classify comes first after that the pivot elements in the order defined above 144 157 double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1]; 145 146 // vector of instances of pivots + 1 (for the instance we want to classify) 147 int[] tmp = new int[FMAP.PA.length+1]; 148 149 Instance tmpi; 150 Instance tmpj; 151 for(int i=0; i < tmp.length; i++) { 152 for(int j=0; j < tmp.length; j++) { 153 if(i==0) { 154 tmpi = instance; 155 }else{ 156 tmpi = TRAIN.get(i); 158 distmat[0][0] = 0; 159 distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0])); 160 distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0])); 161 distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1])); 162 distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1])); 163 164 distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance); 165 distmat[1][1] = 0; 166 distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 167 distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 168 distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 169 170 distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance); 171 distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 172 distmat[2][2] = 0; 173 distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 174 distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 175 176 distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance); 177 distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 178 distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 179 distmat[3][3] = 0; 180 distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 181 182 distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance); 183 distmat[4][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 184 distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 185 distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 186 distmat[4][4] = 0; 187 188 189 /* debug output: show biggest distance found within the new distance matrix 190 double biggest = 0; 191 for(int i=0; i < distmat.length; i++) { 192 for(int j=0; j < distmat[0].length; j++) { 193 if(biggest < distmat[i][j]) { 194 biggest = distmat[i][j]; 157 195 } 158 159 if(j == 0) {160 tmpj = instance;161 }else {162 tmpj = TRAIN.get(j);163 }164 165 distmat[i][j] = DIST.distance(tmpi, tmpj);166 196 } 167 197 } 168 169 // this is the projection vector for our instance 170 double[] proj = FMAP.addInstance(distmat); 198 if(this.show_biggest) { 199 Console.traceln(Level.INFO, String.format(""+clusterInstance)); 200 Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 201 this.show_biggest = false; 202 } 203 */ 204 205 FMAP.setDistmat(distmat); 206 FMAP.setPivots(npivotindices); 207 FMAP.calculate(); 208 double[][] x = FMAP.getX(); 209 double[] proj = x[0]; 210 211 // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix 212 /* 213 Console.traceln(Level.INFO, "distmat:"); 214 for(int i=0; i<distmat.length; i++){ 215 for(int j=0; j<distmat[0].length; j++){ 216 Console.trace(Level.INFO, String.format("%20s", distmat[i][j])); 217 } 218 Console.traceln(Level.INFO, ""); 219 } 220 221 Console.traceln(Level.INFO, "vector:"); 222 for(int i=0; i < proj.length; i++) { 223 Console.trace(Level.INFO, String.format("%20s", proj[i])); 224 } 225 Console.traceln(Level.INFO, ""); 226 227 Console.traceln(Level.INFO, "resultmat:"); 228 for(int i=0; i<x.length; i++){ 229 for(int j=0; j<x[0].length; j++){ 230 Console.trace(Level.INFO, String.format("%20s", x[i][j])); 231 } 232 Console.traceln(Level.INFO, ""); 233 } 234 */ 235 236 // TODO: can we be in more cluster than one? 237 // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the 238 // cluster in which we are 171 239 int cnumber; 172 240 int found_cnumber = -1; 173 Iterator<Integer> clusternumber = CSIZE.keySet().iterator();174 while ( clusternumber.hasNext() ) {241 Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 242 while ( clusternumber.hasNext() && found_cnumber == -1) { 175 243 cnumber = clusternumber.next(); 176 244 177 // jetzt iterieren wir über die boxen und hoffen wir finden was (cluster könnte auch entfernt worden sein) 178 for ( int box=0; box < CSIZE.get(cnumber).size(); box++ ) { 179 Double[][] current = CSIZE.get(cnumber).get(box); 180 if(proj[0] <= current[0][0] && proj[0] >= current[0][1] && // x 181 proj[1] <= current[1][0] && proj[1] >= current[1][1]) { // y 245 // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed) 246 // or we are too far away from any cluster 247 for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) { 248 Double[][] current = this.csize.get(cnumber).get(box); 249 250 if(proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x 251 proj[1] >= current[1][0] && proj[1] <= current[1][1]) { // y 182 252 found_cnumber = cnumber; 183 253 } … … 185 255 } 186 256 187 // wenn wir keinen cluster finden, liegen wir außerhalb des bereichs 188 // kann das vorkommen mit fastmap? 189 190 // ja das kann vorkommen wir suchen also weiterhin den nächsten 191 // müssten mal durchzählen wie oft das vorkommt 257 // we want to count how often we are really inside a cluster 192 258 if ( found_cnumber == -1 ) { 193 //Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 194 //throw new RuntimeException("no cluster for test instance found!"); 195 } 196 197 // jetzt kann es vorkommen das der cluster gelöscht wurde (weil zuwenig instanzen), jetzt müssen wir den 198 // finden der am nächsten dran ist 259 CNOTFOUND += 1; 260 }else { 261 CFOUND += 1; 262 } 263 264 // now it can happen that we dont find a cluster because we deleted it previously (too few instances) 265 // or we get bigger distance measures from weka so that we are completely outside of our clusters. 266 // in these cases we just find the nearest cluster to our instance and use it for classification. 267 // to do that we use the EuclideanDistance again to compare our distance to all other Instances 268 // then we take the cluster of the closest weka instance 269 dist = new EuclideanDistance(traindata2); 199 270 if( !this.ctraindata.containsKey(found_cnumber) ) { 200 271 double min_distance = 99999999; … … 203 274 cnumber = clusternumber.next(); 204 275 for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 205 if( DIST.distance(clusterInstance, ctraindata.get(cnumber).get(i)) <= min_distance) {276 if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) { 206 277 found_cnumber = cnumber; 207 min_distance = DIST.distance(clusterInstance, ctraindata.get(cnumber).get(i));278 min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i)); 208 279 } 209 280 } … … 213 284 // here we have the cluster where an instance has the minimum distance between itself the 214 285 // instance we want to classify 286 // if we still have not found a cluster we exit because something is really wrong 215 287 if( found_cnumber == -1 ) { 216 // this is an error condition217 288 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!")); 218 throw new RuntimeException(" min_cluster not found");219 } 220 221 // classify the passed instance with the cluster we found 289 throw new RuntimeException("cluster not found with full search"); 290 } 291 292 // classify the passed instance with the cluster we found and its training data 222 293 ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 223 294 … … 232 303 public void buildClassifier(Instances traindata) throws Exception { 233 304 305 //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND)); 306 this.show_biggest = true; 307 308 234 309 // 1. copy traindata 235 310 Instances train = new Instances(traindata); 311 Instances train2 = new Instances(traindata); // this one keeps the class attribute 236 312 237 313 // 2. remove class attribute for clustering 238 //Remove filter = new Remove(); 239 //filter.setAttributeIndices("" + (train.classIndex() + 1)); 240 //filter.setInputFormat(train); 241 //train = Filter.useFilter(train, filter); 242 243 TRAIN = train; 314 Remove filter = new Remove(); 315 filter.setAttributeIndices("" + (train.classIndex() + 1)); 316 filter.setInputFormat(train); 317 train = Filter.useFilter(train, filter); 318 244 319 // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 245 DIST = new EuclideanDistance(train); 246 double[][] dist = new double[train.size()][train.size()]; 247 for(int i=0; i < train.size(); i++) { 248 for(int j=0; j < train.size(); j++) { 249 dist[i][j] = DIST.distance(train.get(i), train.get(j)); 250 } 251 } 320 double biggest = 0; 321 EuclideanDistance dist = new EuclideanDistance(train); 322 double[][] distmat = new double[train.size()][train.size()]; 323 for( int i=0; i < train.size(); i++ ) { 324 for( int j=0; j < train.size(); j++ ) { 325 distmat[i][j] = dist.distance(train.get(i), train.get(j)); 326 if( distmat[i][j] > biggest ) { 327 biggest = distmat[i][j]; 328 } 329 } 330 } 331 //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 252 332 253 333 // 4. run fastmap for 2 dimensions on the distance matrix 254 FMAP = new Fastmap(2, dist); 334 Fastmap FMAP = new Fastmap(2); 335 FMAP.setDistmat(distmat); 255 336 FMAP.calculate(); 337 338 cpivotindices = FMAP.getPivots(); 339 256 340 double[][] X = FMAP.getX(); 257 341 … … 264 348 265 349 // set quadtree payload values and get max and min x and y values for size 266 for( int i=0; i<X.length; i++){350 for( int i=0; i<X.length; i++ ){ 267 351 if(X[i][0] >= big[0]) { 268 352 big[0] = X[i][0]; … … 277 361 small[1] = X[i][1]; 278 362 } 279 QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train .get(i));363 QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 280 364 qtp.add(tmp); 281 365 } 282 366 367 Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 368 283 369 // 5. generate quadtree 284 TREE = new QuadTree(null, qtp); 285 ALPHA = Math.sqrt(train.size()); 286 SIZE = train.size(); 287 288 //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ SIZE + " size, Alpha: "+ ALPHA+ "")); 370 QuadTree TREE = new QuadTree(null, qtp); 371 QuadTree.size = train.size(); 372 QuadTree.alpha = Math.sqrt(train.size()); 373 QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 374 QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 375 376 //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ "")); 289 377 290 378 // set the size and then split the tree recursively at the median value for x, y 291 379 TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 380 381 // recursive split und grid clustering eher static 292 382 TREE.recursiveSplit(TREE); 293 383 … … 295 385 ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 296 386 297 // recursive grid clustering (tree pruning), the values are stored in c luster387 // recursive grid clustering (tree pruning), the values are stored in ccluster 298 388 TREE.gridClustering(l); 299 389 300 390 // wir iterieren durch die cluster und sammeln uns die instanzen daraus 301 for(int i=0; i < cluster.size(); i++) { 302 ArrayList<QuadTreePayload<Instance>> current = cluster.get(i); 391 //ctraindata.clear(); 392 for( int i=0; i < QuadTree.ccluster.size(); i++ ) { 393 ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 303 394 304 395 // i is the clusternumber 305 396 // we only allow clusters with Instances > ALPHA, other clusters are not considered! 306 if(current.size() > ALPHA) { 307 for(int j=0; j < current.size(); j++ ) { 308 if(!ctraindata.containsKey(i)) { 309 ctraindata.put(i, new Instances(traindata)); 397 //if(current.size() > QuadTree.alpha) { 398 if( current.size() > 4 ) { 399 for( int j=0; j < current.size(); j++ ) { 400 if( !ctraindata.containsKey(i) ) { 401 ctraindata.put(i, new Instances(train2)); 310 402 ctraindata.get(i).delete(); 311 403 } 312 404 ctraindata.get(i).add(current.get(j).getInst()); 313 405 } 406 }else{ 407 Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances")); 314 408 } 315 316 317 409 } 318 410 411 // here we keep things we need later on 412 // QuadTree sizes for later use 413 this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 414 415 // pivot elements 416 //this.cpivots.clear(); 417 for( int i=0; i < FMAP.PA[0].length; i++ ) { 418 this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy()); 419 } 420 for( int j=0; j < FMAP.PA[0].length; j++ ) { 421 this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy()); 422 } 423 424 425 /* debug output 426 int pnumber; 427 Iterator<Integer> pivotnumber = cpivots.keySet().iterator(); 428 while ( pivotnumber.hasNext() ) { 429 pnumber = pivotnumber.next(); 430 Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber))); 431 } 432 */ 433 319 434 // train one classifier per cluster, we get the clusternumber from the traindata 320 435 int cnumber; 321 436 Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 437 //cclassifier.clear(); 322 438 while ( clusternumber.hasNext() ) { 323 cnumber = clusternumber.next(); 324 cclassifier.put(cnumber,setupClassifier()); 439 cnumber = clusternumber.next(); 440 cclassifier.put(cnumber,setupClassifier()); // das hier ist der eigentliche trainer 325 441 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 326 442 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 327 443 //Console.traceln(Level.INFO, String.format("" + ctraindata.get(cnumber).size() + " instances in cluster "+cnumber)); 328 444 } 445 446 //Console.traceln(Level.INFO, String.format("num clusters: "+cclassifier.size())); 329 447 } 330 448 } … … 332 450 333 451 /** 334 * hier stecken die Fastmap koordinaten drin 335 * sowie als Payload jeweils 1 weka instanz 452 * Payload for the QuadTree. 453 * x and y are the calculated Fastmap values. 454 * T is a weka instance. 336 455 */ 337 456 public class QuadTreePayload<T> { … … 377 496 private int target_dims = 0; 378 497 379 /*3 x k tmp projections array, we need this for later projections*/ 380 double[][] tmpX; 381 382 /**/ 383 public Fastmap(int k, double[][] O) { 384 this.tmpX = new double[2*k+1][k]; 498 // if we already have the pivot elements 499 private boolean pivot_set = false; 500 501 502 public Fastmap(int k) { 503 this.target_dims = k; 504 } 505 506 /** 507 * Sets the distance matrix 508 * and params that depend on this 509 * @param O 510 */ 511 public void setDistmat(double[][] O) { 385 512 this.O = O; 386 513 int N = O.length; 387 388 this.target_dims = k; 389 390 this.X = new double[N][k]; 391 this.PA = new int[2][k]; 514 this.X = new double[N][this.target_dims]; 515 this.PA = new int[2][this.target_dims]; 516 } 517 518 /** 519 * Set pivot elements, we need that to classify instances 520 * after the calculation is complete (because we then want to reuse 521 * only the pivot elements). 522 * 523 * @param pi 524 */ 525 public void setPivots(int[][] pi) { 526 this.pivot_set = true; 527 this.PA = pi; 528 } 529 530 /** 531 * Return the pivot elements that were chosen during the calculation 532 * 533 * @return 534 */ 535 public int[][] getPivots() { 536 return this.PA; 392 537 } 393 538 … … 405 550 406 551 // basis is object distance, we get this from our distance matrix 407 // alternatively we could provide a distance function that takes 2 vectors408 552 double tmp = this.O[x][y] * this.O[x][y]; 409 553 410 554 // decrease by projections 411 for(int i=0; i < k; i++) { 412 //double tmp2 = Math.abs(this.X[x][i] - this.X[y][i]); 413 double tmp2 = (this.X[x][i] - this.X[y][i]); 555 for( int i=0; i < k; i++ ) { 556 double tmp2 = (this.X[x][i] - this.X[y][i]); 414 557 tmp -= tmp2 * tmp2; 415 558 } 416 559 417 560 return Math.abs(tmp); 418 }419 420 /**421 * Distance calculation used for adding an Instance after initialization is complete422 *423 * @param x x index of x image (if k==0 x object)424 * @param y y index of y image (if k==0 y object)425 * @param kdimensionality426 * @param distmat temp distmatrix for the instance to be added427 * @return distance between x, y428 */429 public double tmpDist(int x, int y, int k, double[][] distmat) {430 double tmp = distmat[x][y] * distmat[x][y];431 432 // decrease by projections433 for(int i=0; i < k; i++) {434 double tmp2 = (this.tmpX[x][i] - this.tmpX[y][i]);435 tmp -= tmp2 * tmp2;436 }437 438 //return Math.abs(tmp);439 return tmp;440 }441 442 /**443 * Projects an instance after initialization is complete444 *445 * This uses the previously saved pivot elements.446 *447 * @param distmat distance matrix of the instance and pivot elements (3x3 matrix)448 * @return vector of the projection values (k-vector)449 */450 public double[] addInstance(double[][] distmat) {451 452 for(int k=0; k < this.target_dims; k++) {453 454 double dxy = this.dist(this.PA[0][k], this.PA[1][k], k);455 456 for(int i=0; i < distmat.length; i++) {457 458 double dix = this.tmpDist(i, 2*k+1, k, distmat);459 double diy = this.tmpDist(i, 2*k+2, k, distmat);460 461 // projektion speichern462 this.tmpX[i][k] = (dix + dxy - diy) / (2 * Math.sqrt(dxy));463 }464 }465 466 double[] ret = new double[this.target_dims];467 for(int k=0; k < this.target_dims; k++) {468 ret[k] = this.tmpX[0][k];469 }470 return ret;471 561 } 472 562 … … 482 572 int ret = 0; 483 573 484 for( int i=0; i < O.length; i++) {574 for( int i=0; i < O.length; i++ ) { 485 575 double dist = this.dist(i, index, this.col); 486 if( i != index && dist > furthest) {576 if( i != index && dist > furthest ) { 487 577 furthest = dist; 488 578 ret = i; … … 514 604 515 605 /** 516 * Calculates the new k-vector values 606 * Calculates the new k-vector values (projections) 607 * 608 * This is basically algorithm 2 of the fastmap paper. 609 * We just added the possibility to pre-set the pivot elements because 610 * we need to classify single instances after the computation is already done. 517 611 * 518 612 * @param dims dimensionality … … 520 614 public void calculate() { 521 615 522 for(int k=0; k <this.target_dims; k++) { 523 616 for( int k=0; k < this.target_dims; k++ ) { 524 617 // 2) choose pivot objects 525 int[] pivots = this.findDistantObjects(); 526 527 // 3) record ids of pivot objects 528 this.PA[0][this.col] = pivots[0]; 529 this.PA[1][this.col] = pivots[1]; 618 if ( !this.pivot_set ) { 619 int[] pivots = this.findDistantObjects(); 620 621 // 3) record ids of pivot objects 622 this.PA[0][this.col] = pivots[0]; 623 this.PA[1][this.col] = pivots[1]; 624 } 530 625 531 626 // 4) inter object distances are zero (this.X is initialized with 0 so we just continue) 532 if( this.dist(pivots[0], pivots[1], this.col) == 0) {627 if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) { 533 628 continue; 534 629 } 535 630 536 631 // 5) project the objects on the line between the pivots 537 double dxy = this.dist( pivots[0], pivots[1], this.col);538 for( int i=0; i < this.O.length; i++) {632 double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 633 for( int i=0; i < this.O.length; i++ ) { 539 634 540 double dix = this.dist(i, pivots[0], this.col);541 double diy = this.dist(i, pivots[1], this.col);542 635 double dix = this.dist(i, this.PA[0][this.col], this.col); 636 double diy = this.dist(i, this.PA[1][this.col], this.col); 637 543 638 double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 544 639 640 // save the projection 545 641 this.X[i][this.col] = tmp; 546 642 } … … 551 647 552 648 /** 553 * returns the result matrix 649 * returns the result matrix of the projections 650 * 554 651 * @return calculated result 555 652 */ … … 558 655 } 559 656 } 560 561 562 /**563 * QuadTree implementation564 *565 * QuadTree gets a list of instances and then recursively split them into 4 childs566 * For this it uses the median of the 2 values x,y567 */568 public class QuadTree {569 570 // 1 parent or null571 private QuadTree parent = null;572 573 // 4 childs, 1 per quadrant574 private QuadTree child_nw;575 private QuadTree child_ne;576 private QuadTree child_se;577 private QuadTree child_sw;578 579 // list (only helps with generate list of childs!)580 private ArrayList<QuadTree> l = new ArrayList<QuadTree>();581 582 // level only used for debugging583 public int level = 0;584 585 // size of the quadrant586 private double[] x;587 private double[] y;588 589 public boolean verbose = false;590 591 // payload, mal sehen ob das geht mit dem generic592 // evtl. statt ArrayList eigene QuadTreePayloadlist593 private ArrayList<QuadTreePayload<Instance>> payload;594 595 public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) {596 this.parent = parent;597 this.payload = payload;598 }599 600 601 public String toString() {602 String n = "";603 if(this.parent == null) {604 n += "rootnode ";605 }606 String level = new String(new char[this.level]).replace("\0", "-");607 n += level + " instances: " + this.getNumbers();608 return n;609 }610 611 /**612 * Returns the payload, used for clustering613 * in the clustering list we only have children with paylod614 *615 * @return payload616 */617 public ArrayList<QuadTreePayload<Instance>> getPayload() {618 return this.payload;619 }620 621 /**622 * Calculate the density of this quadrant623 *624 * density = number of instances / global size (all instances)625 *626 * @return density627 */628 public double getDensity() {629 double dens = 0;630 dens = (double)this.getNumbers() / SIZE;631 return dens;632 }633 634 public void setSize(double[] x, double[] y){635 this.x = x;636 this.y = y;637 }638 639 public double[][] getSize() {640 return new double[][] {this.x, this.y};641 }642 643 public Double[][] getSizeDouble() {644 Double[] tmpX = new Double[2];645 Double[] tmpY = new Double[2];646 647 tmpX[0] = this.x[0];648 tmpX[1] = this.x[1];649 650 tmpY[0] = this.y[0];651 tmpY[1] = this.y[1];652 653 return new Double[][] {tmpX, tmpY};654 }655 656 /**657 * TODO: DRY, median ist immer dasselbe658 *659 * @return median for x660 */661 private double getMedianForX() {662 double med_x =0 ;663 664 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() {665 @Override666 public int compare(QuadTreePayload<Instance> x1, QuadTreePayload<Instance> x2) {667 return Double.compare(x1.x, x2.x);668 }669 });670 671 if(this.payload.size() % 2 == 0) {672 int mid = this.payload.size() / 2;673 med_x = (this.payload.get(mid).x + this.payload.get(mid+1).x) / 2;674 }else {675 int mid = this.payload.size() / 2;676 med_x = this.payload.get(mid).x;677 }678 679 if(this.verbose) {680 System.out.println("sorted:");681 for(int i = 0; i < this.payload.size(); i++) {682 System.out.print(""+this.payload.get(i).x+",");683 }684 System.out.println("median x: " + med_x);685 }686 return med_x;687 }688 689 private double getMedianForY() {690 double med_y =0 ;691 692 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() {693 @Override694 public int compare(QuadTreePayload<Instance> y1, QuadTreePayload<Instance> y2) {695 return Double.compare(y1.y, y2.y);696 }697 });698 699 if(this.payload.size() % 2 == 0) {700 int mid = this.payload.size() / 2;701 med_y = (this.payload.get(mid).y + this.payload.get(mid+1).y) / 2;702 }else {703 int mid = this.payload.size() / 2;704 med_y = this.payload.get(mid).y;705 }706 707 if(this.verbose) {708 System.out.println("sorted:");709 for(int i = 0; i < this.payload.size(); i++) {710 System.out.print(""+this.payload.get(i).y+",");711 }712 System.out.println("median y: " + med_y);713 }714 return med_y;715 }716 717 /**718 * Reurns the number of instances in the payload719 *720 * @return int number of instances721 */722 public int getNumbers() {723 int number = 0;724 if(this.payload != null) {725 number = this.payload.size();726 }727 return number;728 }729 730 /**731 * Calculate median values of payload for x, y and split into 4 sectors732 *733 * @return Array of QuadTree nodes (4 childs)734 * @throws Exception if we would run into an recursive loop735 */736 public QuadTree[] split() throws Exception {737 738 double medx = this.getMedianForX();739 double medy = this.getMedianForY();740 741 // Payload lists for each child742 ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>();743 ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>();744 ArrayList<QuadTreePayload<Instance>> ne = new ArrayList<QuadTreePayload<Instance>>();745 ArrayList<QuadTreePayload<Instance>> se = new ArrayList<QuadTreePayload<Instance>>();746 747 // sort the payloads to new payloads748 // here we have the problem that payloads with the same values are sorted749 // into the same slots and it could happen that medx and medy = size_x[1] and size_y[1]750 // in that case we would have an endless loop751 for(int i=0; i < this.payload.size(); i++) {752 753 QuadTreePayload<Instance> item = this.payload.get(i);754 755 // north west756 if(item.x <= medx && item.y >= medy) {757 nw.add(item);758 }759 760 // south west761 else if(item.x <= medx && item.y <= medy) {762 sw.add(item);763 }764 765 // north east766 else if(item.x >= medx && item.y >= medy) {767 ne.add(item);768 }769 770 // south east771 else if(item.x >= medx && item.y <= medy) {772 se.add(item);773 }774 }775 776 // if we assign one child a payload equal to our own (see problem above)777 // we throw an exceptions which stops the recursion on this node778 // second error is minimum number of instances779 //Console.traceln(Level.INFO, String.format("NW: "+ nw.size() + " SW: " + sw.size() + " NE: " + ne.size() + " SE: " + se.size()));780 if(nw.equals(this.payload)) {781 throw new Exception("payload equal");782 }783 if(sw.equals(this.payload)) {784 throw new Exception("payload equal");785 }786 if(ne.equals(this.payload)) {787 throw new Exception("payload equal");788 }789 if(se.equals(this.payload)) {790 throw new Exception("payload equal");791 }792 793 this.child_nw = new QuadTree(this, nw);794 this.child_nw.setSize(new double[] {this.x[0], medx}, new double[] {medy, this.y[1]});795 this.child_nw.level = this.level + 1;796 797 this.child_sw = new QuadTree(this, sw);798 this.child_sw.setSize(new double[] {this.x[0], medx}, new double[] {this.y[0], medy});799 this.child_sw.level = this.level + 1;800 801 this.child_ne = new QuadTree(this, ne);802 this.child_ne.setSize(new double[] {medx, this.x[1]}, new double[] {medy, this.y[1]});803 this.child_ne.level = this.level + 1;804 805 this.child_se = new QuadTree(this, se);806 this.child_se.setSize(new double[] {medx, this.x[1]}, new double[] {this.y[0], medy});807 this.child_se.level = this.level + 1;808 809 this.payload = null;810 return new QuadTree[] {this.child_nw, this.child_ne, this.child_se, this.child_sw};811 }812 813 /**814 * TODO: evt. auslagern, eigentlich auch eher ne statische methode815 *816 * @param q817 */818 public void recursiveSplit(QuadTree q) {819 if(this.verbose) {820 System.out.println("splitting: "+ q);821 }822 if(q.getNumbers() < ALPHA) {823 return;824 }else{825 // exception wird geworfen wenn es zur endlosrekursion kommen würde (siehe text bei split())826 try {827 QuadTree[] childs = q.split();828 this.recursiveSplit(childs[0]);829 this.recursiveSplit(childs[1]);830 this.recursiveSplit(childs[2]);831 this.recursiveSplit(childs[3]);832 }catch(Exception e) {833 return;834 }835 }836 }837 838 /**839 * returns an list of childs sorted by density840 *841 * @param q QuadTree842 * @return list of QuadTrees843 */844 private void generateList(QuadTree q) {845 846 // entweder es gibtes 4 childs oder keins847 if(q.child_ne == null) {848 this.l.add(q);849 //return;850 }851 852 if(q.child_ne != null) {853 this.generateList(q.child_ne);854 }855 if(q.child_nw != null) {856 this.generateList(q.child_nw);857 }858 if(q.child_se != null) {859 this.generateList(q.child_se);860 }861 if(q.child_sw != null) {862 this.generateList(q.child_sw);863 }864 }865 866 /**867 * Checks if passed QuadTree is neighbouring to us868 *869 * @param q QuadTree870 * @return true if passed QuadTree is a neighbour871 */872 public boolean isNeighbour(QuadTree q) {873 boolean is_neighbour = false;874 875 double[][] our_size = this.getSize();876 double[][] new_size = q.getSize();877 878 // X is i=0, Y is i=1879 for(int i =0; i < 2; i++) {880 // check X and Y (0,1)881 // we are smaller than q882 // -------------- q883 // ------- we884 if(our_size[i][0] >= new_size[i][0] && our_size[i][1] <= new_size[i][1]) {885 is_neighbour = true;886 }887 // we overlap with q at some point888 //a) ---------------q889 // ----------- we890 //b) --------- q891 // --------- we892 if((our_size[i][0] >= new_size[i][0] && our_size[i][0] <= new_size[i][1]) ||893 (our_size[i][1] >= new_size[i][0] && our_size[i][1] <= new_size[i][1])) {894 is_neighbour = true;895 }896 // we are larger than q897 // ---- q898 // ---------- we899 if(our_size[i][1] >= new_size[i][1] && our_size[i][0] <= new_size[i][0]) {900 is_neighbour = true;901 }902 }903 904 if(is_neighbour && this.verbose) {905 System.out.println(this + " neighbour of: " + q);906 }907 908 return is_neighbour;909 }910 911 912 /**913 * todo914 */915 public boolean isInside(double x, double y) {916 boolean is_inside_x = false;917 boolean is_inside_y = false;918 double[][] our_size = this.getSize();919 920 921 if(our_size[0][0] <= x && our_size[0][1] >= x) {922 is_inside_x = true;923 }924 925 if(our_size[1][0] <= y && our_size[1][1] >= y) {926 is_inside_y = true;927 }928 929 930 if(is_inside_y && is_inside_x && this.verbose) {931 System.out.println(this + " contains: " + x + ", "+ y);932 }933 934 return is_inside_x && is_inside_y;935 }936 937 938 /**939 * Perform Pruning and clustering of the quadtree940 *941 * 1) get list of leaf quadrants942 * 2) sort by their density943 * 3) set stop_rule to 0.5 * highest Density in the list944 * 4) merge all nodes with a density > stop_rule to the new cluster and remove all from list945 * 5) repeat946 *947 * @param q List of QuadTree (children only)948 */949 public void gridClustering(ArrayList<QuadTree> list) {950 951 //System.out.println("listsize: " + list.size());952 953 // basisfall954 if(list.size() == 0) {955 return;956 }957 958 double stop_rule;959 QuadTree biggest;960 QuadTree current;961 962 // current clusterlist963 ArrayList<QuadTreePayload<Instance>> current_cluster;964 965 // remove list966 ArrayList<Integer> remove = new ArrayList<Integer>();967 968 // 1. find biggest969 biggest = list.get(list.size()-1);970 stop_rule = biggest.getDensity() * 0.5;971 972 current_cluster = new ArrayList<QuadTreePayload<Instance>>();973 current_cluster.addAll(biggest.getPayload());974 //System.out.println("adding "+biggest.getDensity() + " to cluster");975 976 // remove the biggest because we are starting with it977 remove.add(list.size()-1);978 //System.out.println("removing "+biggest.getDensity() + " from list");979 980 ArrayList<Double[][]> tmpSize = new ArrayList<Double[][]>();981 tmpSize.add(biggest.getSizeDouble());982 983 // check the items for their density984 for(int i=list.size()-1; i >= 0; i--) {985 current = list.get(i);986 987 // 2. find neighbours with correct density988 // if density > stop_rule and is_neighbour add to cluster and remove from list989 if(current.getDensity() > stop_rule && !current.equals(biggest) && current.isNeighbour(biggest)) {990 //System.out.println("adding " + current.getDensity() + " to cluster");991 //System.out.println("removing "+current.getDensity() + " from list");992 current_cluster.addAll(current.getPayload());993 994 // wir können hier nicht removen weil wir sonst den index verschieben995 remove.add(i);996 997 // außerdem brauchen wir die größe998 tmpSize.add(current.getSizeDouble());999 }1000 }1001 1002 // 3. remove from list1003 for(Integer item: remove) {1004 list.remove((int)item);1005 }1006 1007 // 4. add to cluster1008 cluster.add(current_cluster);1009 1010 // 5. add size of our current (biggest)1011 // we need that to classify test instances to a cluster1012 Integer cnumber = new Integer(cluster.size()-1);1013 if(CSIZE.containsKey(cnumber) == false) {1014 CSIZE.put(cnumber, tmpSize);1015 }1016 1017 // recurse1018 //System.out.println("restlist " + list.size());1019 this.gridClustering(list);1020 }1021 1022 public void printInfo() {1023 System.out.println("we have " + cluster.size() + " clusters");1024 1025 for(int i=0; i < cluster.size(); i++) {1026 System.out.println("cluster: "+i+ " size: "+ cluster.get(i).size());1027 }1028 }1029 1030 /**1031 * Helper Method to get a sortet list (by density) for all1032 * children1033 *1034 * @param q QuadTree1035 * @return Sorted ArrayList of quadtrees1036 */1037 public ArrayList<QuadTree> getList(QuadTree q) {1038 this.generateList(q);1039 1040 Collections.sort(this.l, new Comparator<QuadTree>() {1041 @Override1042 public int compare(QuadTree x1, QuadTree x2) {1043 return Double.compare(x1.getDensity(), x2.getDensity());1044 }1045 });1046 1047 return this.l;1048 }1049 }1050 657 }
Note: See TracChangeset
for help on using the changeset viewer.