- Timestamp:
- 08/17/16 16:10:13 (8 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java
r86 r137 27 27 import java.util.logging.Level; 28 28 29 import javax.management.RuntimeErrorException;30 31 29 import java.util.Random; 32 30 … … 46 44 import weka.core.Instances; 47 45 48 46 /** 47 * Implements Heterogenous Defect Prediction after Nam et al. 48 * 49 * TODO: 50 * - spacing, coding conventions 51 * - we depend on having exactly one class attribute on multiple locations 52 * - 53 */ 49 54 public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy { 50 55 … … 78 83 79 84 /** 80 * We need the test data instances to do a metric matching, so in this special case we get this data81 * before evaluation 85 * We need the test data instances to do a metric matching, so in this special case we get this data 86 * before evaluation. 82 87 */ 83 88 @Override … … 85 90 this.traindataSet = traindataSet; 86 91 87 int rank = 0; // we want at least 5 matching attributes92 double score = 0; // custom ranking score to select the best training data from the set 88 93 int num = 0; 89 94 int biggest_num = 0; … … 92 97 for (Instances traindata : this.traindataSet) { 93 98 num++; 99 94 100 tmp = new MetricMatch(traindata, testdata); 95 101 … … 97 103 try { 98 104 tmp.attributeSelection(); 105 tmp.matchAttributes(this.method, this.threshold); 99 106 }catch(Exception e) { 100 107 e.printStackTrace(); … … 102 109 } 103 110 104 if (this.method.equals("spearman")) {105 tmp.spearmansRankCorrelation(this.threshold);106 }107 else if (this.method.equals("kolmogorov")) {108 tmp.kolmogorovSmirnovTest(this.threshold);109 }110 else if( this.method.equals("percentile") ) {111 tmp.percentiles(this.threshold);112 }113 else {114 throw new RuntimeException("unknown method");115 }116 117 111 // we only select the training data from our set with the most matching attributes 118 if (tmp.get Rank() > rank) {119 rank = tmp.getRank();112 if (tmp.getScore() > score) { 113 score = tmp.getScore(); 120 114 biggest = tmp; 121 115 biggest_num = num; … … 127 121 } 128 122 129 // we use the best match 130 123 // we use the best match according to our matching score 131 124 this.mm = biggest; 132 125 Instances ilist = this.mm.getMatchedTrain(); 133 Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributes, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets"); 134 126 Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + score + " matching score, " + ilist.size() + " instances, and " + biggest.attributes.size() + " matched attributes out of a possible set of " + traindataSet.size() + " sets"); 127 128 for(int i = 0; i < this.mm.attributes.size(); i++) { 129 Console.traceln(Level.INFO, "Matched Attribute: " + this.mm.train.attribute(i).name() + " with " + this.mm.test.attribute((int)this.mm.attributes.get(i)).name()); 130 } 135 131 // replace traindataSEt 136 132 //traindataSet = new SetUniqueList<Instances>(); … … 156 152 157 153 /** 158 * encapsulates the classifier configured with WekaBase 154 * Encapsulates the classifier configured with WekaBase within but use metric matching. 155 * This allows us to use any Weka classifier with Heterogenous Defect Prediction. 159 156 */ 160 157 public class MetricMatchingClassifier extends AbstractClassifier { … … 197 194 */ 198 195 public class MetricMatch { 199 Instances train; 200 Instances test; 201 202 HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>(); 203 204 ArrayList<double[]> train_values; 205 ArrayList<double[]> test_values; 206 207 // todo: this constructor does not work 208 public MetricMatch() { 209 } 210 211 public MetricMatch(Instances train, Instances test) { 212 this.train = new Instances(train); // expensive! but we are dropping the attributes 213 this.test = new Instances(test); 196 Instances train; 197 Instances test; 198 199 // used to sum up the matching values of all attributes 200 double p_sum = 0; 201 202 // attribute matching, train -> test 203 HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>(); 204 //double[][] weights; /* weight matrix, needed to find maximum weighted bipartite matching */ 205 206 ArrayList<double[]> train_values; 207 ArrayList<double[]> test_values; 208 209 // todo: this constructor does not work 210 public MetricMatch() { 211 } 212 213 public MetricMatch(Instances train, Instances test) { 214 // expensive! but we are dropping the attributes so we have to copy all of the data 215 this.train = new Instances(train); 216 this.test = new Instances(test); 214 217 215 216 217 218 219 218 // 1. convert metrics of testdata and traindata to later use in test 219 this.train_values = new ArrayList<double[]>(); 220 for (int i = 0; i < this.train.numAttributes()-1; i++) { 221 this.train_values.add(train.attributeToDoubleArray(i)); 222 } 220 223 221 this.test_values = new ArrayList<double[]>(); 222 for (int i=0; i < this.test.numAttributes()-1; i++) { 223 this.test_values.add(this.test.attributeToDoubleArray(i)); 224 } 225 } 226 227 /** 228 * returns the number of matched attributes 229 * as a way of scoring traindata sets individually 230 * 231 * @return 232 */ 233 public int getRank() { 234 return this.attributes.size(); 224 this.test_values = new ArrayList<double[]>(); 225 for (int i=0; i < this.test.numAttributes()-1; i++) { 226 this.test_values.add(this.test.attributeToDoubleArray(i)); 227 } 228 } 229 230 /** 231 * We have a lot of matching possibilities. 232 * Here we try to determine the best one. 233 * 234 * @return double matching score 235 */ 236 public double getScore() { 237 int as = this.attributes.size(); 238 239 // we use thresholding ranking approach for numInstances to influence the matching score 240 int instances = this.train.numInstances(); 241 int inst_rank = 0; 242 if(instances > 100) { 243 inst_rank = 1; 244 } 245 if(instances > 500) { 246 inst_rank = 2; 247 } 248 249 return this.p_sum + as + inst_rank; 250 } 251 252 public HashMap<Integer, Integer> getAttributes() { 253 return this.attributes; 235 254 } 236 255 … … 260 279 261 280 ni.setClassValue(test.value(test.classAttribute())); 262 263 //System.out.println(ni); 281 264 282 return ni; 265 283 } … … 283 301 } 284 302 303 // todo: there must be a better way 285 304 // https://weka.wikispaces.com/Programmatic+Use 286 305 private Instances getMatchedInstances(String name, Instances data) { 287 // construct our new attributes 306 //Console.traceln(Level.INFO, "Constructing instances from: " + name); 307 // construct our new attributes 288 308 Attribute[] attrs = new Attribute[this.attributes.size()+1]; 289 309 FastVector fwTrain = new FastVector(this.attributes.size()); … … 301 321 newTrain.setClassIndex(newTrain.numAttributes()-1); 302 322 323 //Console.traceln(Level.INFO, "data attributes: " + data.numAttributes() + ", this.attributes: "+this.attributes.size()); 324 303 325 for (int i=0; i < data.size(); i++) { 304 326 Instance ni = new DenseInstance(this.attributes.size()+1); … … 314 336 value = (int)values.getKey(); 315 337 } 338 //Console.traceln(Level.INFO, "setting attribute " + j + " with data from instance: " + i); 316 339 ni.setValue(newTrain.attribute(j), data.instance(i).value(value)); 317 340 j++; … … 328 351 /** 329 352 * performs the attribute selection 330 * we perform attribute significance tests and drop attributes 353 * we perform attribute significance tests and drop attributes 354 * 355 * attribute selection is only performed on the source dataset 356 * we retain the top 15% attributes (if 15% is a float we just use the integer part) 331 357 */ 332 358 public void attributeSelection() throws Exception { … … 335 361 //Console.traceln(Level.INFO, "-----"); 336 362 //Console.traceln(Level.INFO, "Attribute Selection on Test Attributes ("+this.test.numAttributes()+")"); 337 this.attributeSelection(this.test);363 //this.attributeSelection(this.test); 338 364 //Console.traceln(Level.INFO, "-----"); 339 365 } … … 358 384 HashMap<String, Double> sorted = sortByValues(saeval); 359 385 360 // die letzen 15% wollen wir haben361 float last = ((float)saeval.size() / 100) * 15;386 // die besten 15% wollen wir haben 387 double last = ((double)saeval.size() / 100.0) * 15.0; 362 388 int drop_first = saeval.size() - (int)last; 363 389 … … 380 406 } 381 407 drop_first-=1; 382 383 384 408 } 385 ////Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)");409 //Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)"); 386 410 } 387 411 … … 406 430 } 407 431 432 433 434 public void matchAttributes(String type, double cutoff) { 435 436 437 MWBMatchingAlgorithm mwbm = new MWBMatchingAlgorithm(this.train.numAttributes(), this.test.numAttributes()); 438 439 if (type.equals("spearman")) { 440 this.spearmansRankCorrelation(cutoff, mwbm); 441 }else if(type.equals("ks")) { 442 this.kolmogorovSmirnovTest(cutoff, mwbm); 443 }else if(type.equals("percentile")) { 444 this.percentiles(cutoff, mwbm); 445 }else { 446 throw new RuntimeException("unknown matching method"); 447 } 448 449 // resulting maximal match 450 int[] result = mwbm.getMatching(); 451 for( int i = 0; i < result.length; i++) { 452 453 // -1 means that it is not in the set of maximal matching 454 if( i != -1 && result[i] != -1) { 455 //Console.traceln(Level.INFO, "Found maximal bipartite match between: "+ i + " and " + result[i]); 456 this.attributes.put(i, result[i]); 457 } 458 } 459 } 460 461 408 462 /** 409 463 * Calculates the Percentiles of the source and target metrics. … … 411 465 * @param cutoff 412 466 */ 413 public void percentiles(double cutoff) { 414 for( int i = 0; i < this.train.numAttributes()-1; i++ ) { 415 for( int j = 0; j < this.test.numAttributes()-1; j++ ) { 467 public void percentiles(double cutoff, MWBMatchingAlgorithm mwbm) { 468 for( int i = 0; i < this.train.numAttributes(); i++ ) { 469 for( int j = 0; j < this.test.numAttributes(); j++ ) { 470 // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 471 // and the result of the mwbm computation 472 mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 473 416 474 // class attributes are not relevant 417 if ( this.train.classIndex() == i) {475 if (this.test.classIndex() == j) { 418 476 continue; 419 477 } 420 if ( this.test.classIndex() == j) {478 if (this.train.classIndex() == i) { 421 479 continue; 422 480 } 481 482 // get percentiles 483 double train[] = this.train_values.get(i); 484 double test[] = this.test_values.get(j); 423 485 486 Arrays.sort(train); 487 Arrays.sort(test); 424 488 425 if( !this.attributes.containsKey(i) ) { 426 // get percentiles 427 double train[] = this.train_values.get(i); 428 double test[] = this.test_values.get(j); 429 430 Arrays.sort(train); 431 Arrays.sort(test); 432 433 // percentiles 434 double train_p; 435 double test_p; 436 double score = 0.0; 437 for( double p=0.1; p < 1; p+=0.1 ) { 438 train_p = train[(int)Math.ceil(train.length * p)]; 439 test_p = test[(int)Math.ceil(test.length * p)]; 440 441 if( train_p > test_p ) { 442 score += test_p / train_p; 443 }else { 444 score += train_p / test_p; 445 } 446 } 447 448 if( score > cutoff ) { 449 this.attributes.put(i, j); 489 // percentiles 490 double train_p; 491 double test_p; 492 double score = 0.0; 493 for( int p=1; p <= 9; p++ ) { 494 train_p = train[(int)Math.ceil(train.length * (p/100))]; 495 test_p = test[(int)Math.ceil(test.length * (p/100))]; 496 497 if( train_p > test_p ) { 498 score += test_p / train_p; 499 }else { 500 score += train_p / test_p; 450 501 } 451 502 } 503 504 if( score > cutoff ) { 505 this.p_sum += score; 506 mwbm.setWeight(i, j, score); 507 } 452 508 } 453 509 } … … 455 511 456 512 /** 457 * calculate Spearmans rank correlation coefficient as matching score 513 * Calculate Spearmans rank correlation coefficient as matching score 514 * The number of instances for the source and target needs to be the same so we randomly sample from the bigger one. 458 515 * 459 516 * @param cutoff 517 * @param mwbmatching 460 518 */ 461 public void spearmansRankCorrelation(double cutoff) { 462 double p = 0; 519 public void spearmansRankCorrelation(double cutoff, MWBMatchingAlgorithm mwbm) { 520 double p = 0; 521 463 522 SpearmansCorrelation t = new SpearmansCorrelation(); 464 523 … … 469 528 this.sample(this.test, this.train, this.test_values); 470 529 } 471 472 // try out possible attribute combinations 473 for (int i=0; i < this.train.numAttributes()-1; i++) { 474 for (int j=0; j < this.test.numAttributes()-1; j++) { 530 531 // try out possible attribute combinations 532 for (int i=0; i < this.train.numAttributes(); i++) { 533 534 for (int j=0; j < this.test.numAttributes(); j++) { 535 // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 536 // and the result of the mwbm computation 537 mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 538 475 539 // class attributes are not relevant 540 if (this.test.classIndex() == j) { 541 continue; 542 } 476 543 if (this.train.classIndex() == i) { 477 544 continue; 478 545 } 479 if (this.test.classIndex() == j) {480 continue;481 }482 546 483 484 if (!this.attributes.containsKey(i)) { 485 p = t.correlation(this.train_values.get(i), this.test_values.get(j)); 486 if (p > cutoff) { 487 this.attributes.put(i, j); 488 } 547 p = t.correlation(this.train_values.get(i), this.test_values.get(j)); 548 if (p > cutoff) { 549 this.p_sum += p; 550 mwbm.setWeight(i, j, p); 551 //Console.traceln(Level.INFO, "Found match: p-val: " + p); 489 552 } 490 553 } 491 554 } 492 } 493 494 555 556 //Console.traceln(Level.INFO, "Found " + this.attributes.size() + " matching attributes"); 557 } 558 559 495 560 public void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) { 496 561 // we want to at keep the indices we select the same … … 535 600 * @return p-val 536 601 */ 537 public void kolmogorovSmirnovTest(double cutoff ) {602 public void kolmogorovSmirnovTest(double cutoff, MWBMatchingAlgorithm mwbm) { 538 603 double p = 0; 539 604 540 605 KolmogorovSmirnovTest t = new KolmogorovSmirnovTest(); 541 606 542 // todo: this should be symmetrical we don't have to compare i to j and then j to i543 // todo: this relies on the last attribute being the class,544 607 //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")"); 545 for (int i=0; i < this.train.numAttributes()-1; i++) { 546 for ( int j=0; j < this.test.numAttributes()-1; j++) { 547 //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 548 //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 608 for (int i=0; i < this.train.numAttributes(); i++) { 609 for ( int j=0; j < this.test.numAttributes(); j++) { 610 // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 611 // and the result of the mwbm computation 612 mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 613 549 614 // class attributes are not relevant 550 if ( this.train.classIndex() == i) {615 if (this.test.classIndex() == j) { 551 616 continue; 552 617 } 553 if ( this.test.classIndex() == j) {618 if (this.train.classIndex() == i) { 554 619 continue; 555 620 } 556 // PRoblem: exactP is forced for small sample sizes and it never finishes 557 if (!this.attributes.containsKey(i)) { 558 559 // todo: output the values and complain on the math.commons mailinglist 560 p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 561 if (p > cutoff) { 562 this.attributes.put(i, j); 563 } 621 622 // this may invoke exactP on small sample sizes which will not terminate in all cases 623 //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j), false); 624 p = t.approximateP(t.kolmogorovSmirnovStatistic(this.train_values.get(i), this.test_values.get(j)), this.train_values.get(i).length, this.test_values.get(j).length); 625 if (p > cutoff) { 626 this.p_sum += p; 627 mwbm.setWeight(i, j, p); 564 628 } 565 629 } 566 630 } 567 568 //Console.traceln(Level.INFO, "Found " + this.attmatch.size() + " matching attributes"); 569 } 570 } 631 //Console.traceln(Level.INFO, "Found " + this.attributes.size() + " matching attributes"); 632 } 633 } 634 635 /* 636 * Copyright (c) 2007, Massachusetts Institute of Technology 637 * Copyright (c) 2005-2006, Regents of the University of California 638 * All rights reserved. 639 * 640 * Redistribution and use in source and binary forms, with or without 641 * modification, are permitted provided that the following conditions 642 * are met: 643 * 644 * * Redistributions of source code must retain the above copyright 645 * notice, this list of conditions and the following disclaimer. 646 * 647 * * Redistributions in binary form must reproduce the above copyright 648 * notice, this list of conditions and the following disclaimer in 649 * the documentation and/or other materials provided with the 650 * distribution. 651 * 652 * * Neither the name of the University of California, Berkeley nor 653 * the names of its contributors may be used to endorse or promote 654 * products derived from this software without specific prior 655 * written permission. 656 * 657 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 658 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 659 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 660 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 661 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 662 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 663 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 664 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 665 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 666 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 667 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 668 * OF THE POSSIBILITY OF SUCH DAMAGE. 669 */ 670 671 672 673 /** 674 * An engine for finding the maximum-weight matching in a complete 675 * bipartite graph. Suppose we have two sets <i>S</i> and <i>T</i>, 676 * both of size <i>n</i>. For each <i>i</i> in <i>S</i> and <i>j</i> 677 * in <i>T</i>, we have a weight <i>w<sub>ij</sub></i>. A perfect 678 * matching <i>X</i> is a subset of <i>S</i> x <i>T</i> such that each 679 * <i>i</i> in <i>S</i> occurs in exactly one element of <i>X</i>, and 680 * each <i>j</i> in <i>T</i> occurs in exactly one element of 681 * <i>X</i>. Thus, <i>X</i> can be thought of as a one-to-one 682 * function from <i>S</i> to <i>T</i>. The weight of <i>X</i> is the 683 * sum, over (<i>i</i>, <i>j</i>) in <i>X</i>, of 684 * <i>w<sub>ij</sub></i>. A BipartiteMatcher takes the number 685 * <i>n</i> and the weights <i>w<sub>ij</sub></i>, and finds a perfect 686 * matching of maximum weight. 687 * 688 * It uses the Hungarian algorithm of Kuhn (1955), as improved and 689 * presented by E. L. Lawler in his book <cite>Combinatorial 690 * Optimization: Networks and Matroids</cite> (Holt, Rinehart and 691 * Winston, 1976, p. 205-206). The running time is 692 * O(<i>n</i><sup>3</sup>). The weights can be any finite real 693 * numbers; Lawler's algorithm assumes positive weights, so if 694 * necessary we add a constant <i>c</i> to all the weights before 695 * running the algorithm. This increases the weight of every perfect 696 * matching by <i>nc</i>, which doesn't change which perfect matchings 697 * have maximum weight. 698 * 699 * If a weight is set to Double.NEGATIVE_INFINITY, then the algorithm will 700 * behave as if that edge were not in the graph. If all the edges incident on 701 * a given node have weight Double.NEGATIVE_INFINITY, then the final result 702 * will not be a perfect matching, and an exception will be thrown. 703 */ 704 class MWBMatchingAlgorithm { 705 /** 706 * Creates a BipartiteMatcher without specifying the graph size. Calling 707 * any other method before calling reset will yield an 708 * IllegalStateException. 709 */ 710 711 /** 712 * Tolerance for comparisons to zero, to account for 713 * floating-point imprecision. We consider a positive number to 714 * be essentially zero if it is strictly less than TOL. 715 */ 716 private static final double TOL = 1e-10; 717 //Number of left side nodes 718 int n; 719 720 //Number of right side nodes 721 int m; 722 723 double[][] weights; 724 double minWeight; 725 double maxWeight; 726 727 // If (i, j) is in the mapping, then sMatches[i] = j and tMatches[j] = i. 728 // If i is unmatched, then sMatches[i] = -1 (and likewise for tMatches). 729 int[] sMatches; 730 int[] tMatches; 731 732 static final int NO_LABEL = -1; 733 static final int EMPTY_LABEL = -2; 734 735 int[] sLabels; 736 int[] tLabels; 737 738 double[] u; 739 double[] v; 740 741 double[] pi; 742 743 List<Integer> eligibleS = new ArrayList<Integer>(); 744 List<Integer> eligibleT = new ArrayList<Integer>(); 745 746 747 public MWBMatchingAlgorithm() { 748 n = -1; 749 m = -1; 750 } 751 752 /** 753 * Creates a BipartiteMatcher and prepares it to run on an n x m graph. 754 * All the weights are initially set to 1. 755 */ 756 public MWBMatchingAlgorithm(int n, int m) { 757 reset(n, m); 758 } 759 760 /** 761 * Resets the BipartiteMatcher to run on an n x m graph. The weights are 762 * all reset to 1. 763 */ 764 private void reset(int n, int m) { 765 if (n < 0 || m < 0) { 766 throw new IllegalArgumentException("Negative num nodes: " + n + " or " + m); 767 } 768 this.n = n; 769 this.m = m; 770 771 weights = new double[n][m]; 772 for (int i = 0; i < n; i++) { 773 for (int j = 0; j < m; j++) { 774 weights[i][j] = 1; 775 } 776 } 777 minWeight = 1; 778 maxWeight = Double.NEGATIVE_INFINITY; 779 780 sMatches = new int[n]; 781 tMatches = new int[m]; 782 sLabels = new int[n]; 783 tLabels = new int[m]; 784 u = new double[n]; 785 v = new double[m]; 786 pi = new double[m]; 787 788 } 789 /** 790 * Sets the weight w<sub>ij</sub> to the given value w. 791 * 792 * @throws IllegalArgumentException if i or j is outside the range [0, n). 793 */ 794 public void setWeight(int i, int j, double w) { 795 if (n == -1 || m == -1) { 796 throw new IllegalStateException("Graph size not specified."); 797 } 798 if ((i < 0) || (i >= n)) { 799 throw new IllegalArgumentException("i-value out of range: " + i); 800 } 801 if ((j < 0) || (j >= m)) { 802 throw new IllegalArgumentException("j-value out of range: " + j); 803 } 804 if (Double.isNaN(w)) { 805 throw new IllegalArgumentException("Illegal weight: " + w); 806 } 807 808 weights[i][j] = w; 809 if ((w > Double.NEGATIVE_INFINITY) && (w < minWeight)) { 810 minWeight = w; 811 } 812 if (w > maxWeight) { 813 maxWeight = w; 814 } 815 } 816 817 /** 818 * Returns a maximum-weight perfect matching relative to the weights 819 * specified with setWeight. The matching is represented as an array arr 820 * of length n, where arr[i] = j if (i,j) is in the matching. 821 */ 822 public int[] getMatching() { 823 if (n == -1 || m == -1 ) { 824 throw new IllegalStateException("Graph size not specified."); 825 } 826 if (n == 0) { 827 return new int[0]; 828 } 829 ensurePositiveWeights(); 830 831 // Step 0: Initialization 832 eligibleS.clear(); 833 eligibleT.clear(); 834 for (Integer i = 0; i < n; i++) { 835 sMatches[i] = -1; 836 837 u[i] = maxWeight; // ambiguous on p. 205 of Lawler, but see p. 202 838 839 // this is really first run of Step 1.0 840 sLabels[i] = EMPTY_LABEL; 841 eligibleS.add(i); 842 } 843 844 for (int j = 0; j < m; j++) { 845 tMatches[j] = -1; 846 847 v[j] = 0; 848 pi[j] = Double.POSITIVE_INFINITY; 849 850 // this is really first run of Step 1.0 851 tLabels[j] = NO_LABEL; 852 } 853 854 while (true) { 855 // Augment the matching until we can't augment any more given the 856 // current settings of the dual variables. 857 while (true) { 858 // Steps 1.1-1.4: Find an augmenting path 859 int lastNode = findAugmentingPath(); 860 if (lastNode == -1) { 861 break; // no augmenting path 862 } 863 864 // Step 2: Augmentation 865 flipPath(lastNode); 866 for (int i = 0; i < n; i++) 867 sLabels[i] = NO_LABEL; 868 869 for (int j = 0; j < m; j++) { 870 pi[j] = Double.POSITIVE_INFINITY; 871 tLabels[j] = NO_LABEL; 872 } 873 874 875 876 // This is Step 1.0 877 eligibleS.clear(); 878 for (int i = 0; i < n; i++) { 879 if (sMatches[i] == -1) { 880 sLabels[i] = EMPTY_LABEL; 881 eligibleS.add(new Integer(i)); 882 } 883 } 884 885 886 eligibleT.clear(); 887 } 888 889 // Step 3: Change the dual variables 890 891 // delta1 = min_i u[i] 892 double delta1 = Double.POSITIVE_INFINITY; 893 for (int i = 0; i < n; i++) { 894 if (u[i] < delta1) { 895 delta1 = u[i]; 896 } 897 } 898 899 // delta2 = min_{j : pi[j] > 0} pi[j] 900 double delta2 = Double.POSITIVE_INFINITY; 901 for (int j = 0; j < m; j++) { 902 if ((pi[j] >= TOL) && (pi[j] < delta2)) { 903 delta2 = pi[j]; 904 } 905 } 906 907 if (delta1 < delta2) { 908 // In order to make another pi[j] equal 0, we'd need to 909 // make some u[i] negative. 910 break; // we have a maximum-weight matching 911 } 912 913 changeDualVars(delta2); 914 } 915 916 int[] matching = new int[n]; 917 for (int i = 0; i < n; i++) { 918 matching[i] = sMatches[i]; 919 } 920 return matching; 921 } 922 923 /** 924 * Tries to find an augmenting path containing only edges (i,j) for which 925 * u[i] + v[j] = weights[i][j]. If it succeeds, returns the index of the 926 * last node in the path. Otherwise, returns -1. In any case, updates 927 * the labels and pi values. 928 */ 929 int findAugmentingPath() { 930 while ((!eligibleS.isEmpty()) || (!eligibleT.isEmpty())) { 931 if (!eligibleS.isEmpty()) { 932 int i = ((Integer) eligibleS.get(eligibleS.size() - 1)). 933 intValue(); 934 eligibleS.remove(eligibleS.size() - 1); 935 for (int j = 0; j < m; j++) { 936 // If pi[j] has already been decreased essentially 937 // to zero, then j is already labeled, and we 938 // can't decrease pi[j] any more. Omitting the 939 // pi[j] >= TOL check could lead us to relabel j 940 // unnecessarily, since the diff we compute on the 941 // next line may end up being less than pi[j] due 942 // to floating point imprecision. 943 if ((tMatches[j] != i) && (pi[j] >= TOL)) { 944 double diff = u[i] + v[j] - weights[i][j]; 945 if (diff < pi[j]) { 946 tLabels[j] = i; 947 pi[j] = diff; 948 if (pi[j] < TOL) { 949 eligibleT.add(new Integer(j)); 950 } 951 } 952 } 953 } 954 } else { 955 int j = ((Integer) eligibleT.get(eligibleT.size() - 1)). 956 intValue(); 957 eligibleT.remove(eligibleT.size() - 1); 958 if (tMatches[j] == -1) { 959 return j; // we've found an augmenting path 960 } 961 962 int i = tMatches[j]; 963 sLabels[i] = j; 964 eligibleS.add(new Integer(i)); // ok to add twice 965 } 966 } 967 968 return -1; 969 } 970 971 /** 972 * Given an augmenting path ending at lastNode, "flips" the path. This 973 * means that an edge on the path is in the matching after the flip if 974 * and only if it was not in the matching before the flip. An augmenting 975 * path connects two unmatched nodes, so the result is still a matching. 976 */ 977 void flipPath(int lastNode) { 978 while (lastNode != EMPTY_LABEL) { 979 int parent = tLabels[lastNode]; 980 981 // Add (parent, lastNode) to matching. We don't need to 982 // explicitly remove any edges from the matching because: 983 // * We know at this point that there is no i such that 984 // sMatches[i] = lastNode. 985 // * Although there might be some j such that tMatches[j] = 986 // parent, that j must be sLabels[parent], and will change 987 // tMatches[j] in the next time through this loop. 988 sMatches[parent] = lastNode; 989 tMatches[lastNode] = parent; 990 991 lastNode = sLabels[parent]; 992 } 993 } 994 995 void changeDualVars(double delta) { 996 for (int i = 0; i < n; i++) { 997 if (sLabels[i] != NO_LABEL) { 998 u[i] -= delta; 999 } 1000 } 1001 1002 for (int j = 0; j < m; j++) { 1003 if (pi[j] < TOL) { 1004 v[j] += delta; 1005 } else if (tLabels[j] != NO_LABEL) { 1006 pi[j] -= delta; 1007 if (pi[j] < TOL) { 1008 eligibleT.add(new Integer(j)); 1009 } 1010 } 1011 } 1012 } 1013 1014 /** 1015 * Ensures that all weights are either Double.NEGATIVE_INFINITY, 1016 * or strictly greater than zero. 1017 */ 1018 private void ensurePositiveWeights() { 1019 // minWeight is the minimum non-infinite weight 1020 if (minWeight < TOL) { 1021 for (int i = 0; i < n; i++) { 1022 for (int j = 0; j < m; j++) { 1023 weights[i][j] = weights[i][j] - minWeight + 1; 1024 } 1025 } 1026 1027 maxWeight = maxWeight - minWeight + 1; 1028 minWeight = 1; 1029 } 1030 } 1031 1032 @SuppressWarnings("unused") 1033 private void printWeights() { 1034 for (int i = 0; i < n; i++) { 1035 for (int j = 0; j < m; j++) { 1036 System.out.print(weights[i][j] + " "); 1037 } 1038 System.out.println(""); 1039 } 1040 } 1041 } 571 1042 }
Note: See TracChangeset
for help on using the changeset viewer.