Changeset 137 for trunk/CrossPare/src


Ignore:
Timestamp:
08/17/16 16:10:13 (8 years ago)
Author:
atrautsch
Message:

Metric Matching Training Update

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java

    r86 r137  
    2727import java.util.logging.Level; 
    2828 
    29 import javax.management.RuntimeErrorException; 
    30  
    3129import java.util.Random; 
    3230 
     
    4644import weka.core.Instances; 
    4745 
    48  
     46/** 
     47 * Implements Heterogenous Defect Prediction after Nam et al. 
     48 *  
     49 * TODO: 
     50 * - spacing, coding conventions 
     51 * - we depend on having exactly one class attribute on multiple locations 
     52 * -  
     53 */ 
    4954public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy { 
    5055 
     
    7883 
    7984        /** 
    80          * We need the testdata instances to do a metric matching, so in this special case we get this data 
    81          * before evaluation 
     85         * We need the test data instances to do a metric matching, so in this special case we get this data 
     86         * before evaluation. 
    8287         */ 
    8388        @Override 
     
    8590                this.traindataSet = traindataSet; 
    8691 
    87                 int rank = 0; // we want at least 5 matching attributes 
     92                double score = 0; // custom ranking score to select the best training data from the set 
    8893                int num = 0; 
    8994                int biggest_num = 0; 
     
    9297                for (Instances traindata : this.traindataSet) { 
    9398                        num++; 
     99 
    94100                        tmp = new MetricMatch(traindata, testdata); 
    95101 
     
    97103                        try { 
    98104                                tmp.attributeSelection(); 
     105                                tmp.matchAttributes(this.method, this.threshold); 
    99106                        }catch(Exception e) { 
    100107                                e.printStackTrace(); 
     
    102109                        } 
    103110                         
    104                         if (this.method.equals("spearman")) { 
    105                             tmp.spearmansRankCorrelation(this.threshold); 
    106                         } 
    107                         else if (this.method.equals("kolmogorov")) { 
    108                             tmp.kolmogorovSmirnovTest(this.threshold); 
    109                         } 
    110                         else if( this.method.equals("percentile") ) { 
    111                             tmp.percentiles(this.threshold); 
    112                         } 
    113                         else { 
    114                             throw new RuntimeException("unknown method"); 
    115                         } 
    116  
    117111                        // we only select the training data from our set with the most matching attributes 
    118                         if (tmp.getRank() > rank) { 
    119                                 rank = tmp.getRank(); 
     112                        if (tmp.getScore() > score) { 
     113                                score = tmp.getScore(); 
    120114                                biggest = tmp; 
    121115                                biggest_num = num; 
     
    127121                } 
    128122 
    129                 // we use the best match 
    130                  
     123                // we use the best match according to our matching score 
    131124                this.mm = biggest; 
    132125                Instances ilist = this.mm.getMatchedTrain(); 
    133                 Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributes, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets"); 
    134                  
     126                Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + score + " matching score, " + ilist.size() + " instances, and " + biggest.attributes.size() + " matched attributes out of a possible set of " + traindataSet.size() + " sets"); 
     127                 
     128                for(int i = 0; i < this.mm.attributes.size(); i++) { 
     129                    Console.traceln(Level.INFO, "Matched Attribute: " + this.mm.train.attribute(i).name() + " with " + this.mm.test.attribute((int)this.mm.attributes.get(i)).name()); 
     130                } 
    135131                // replace traindataSEt 
    136132                //traindataSet = new SetUniqueList<Instances>(); 
     
    156152         
    157153        /** 
    158          * encapsulates the classifier configured with WekaBase 
     154         * Encapsulates the classifier configured with WekaBase within but use metric matching. 
     155         * This allows us to use any Weka classifier with Heterogenous Defect Prediction. 
    159156         */ 
    160157        public class MetricMatchingClassifier extends AbstractClassifier { 
     
    197194         */ 
    198195    public class MetricMatch { 
    199                  Instances train; 
    200                  Instances test; 
    201                   
    202                  HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>(); 
    203                   
    204                  ArrayList<double[]> train_values; 
    205                  ArrayList<double[]> test_values; 
    206                   
    207                  // todo: this constructor does not work 
    208                  public MetricMatch() { 
    209                  } 
    210                   
    211                  public MetricMatch(Instances train, Instances test) { 
    212                          this.train = new Instances(train);  // expensive! but we are dropping the attributes 
    213                          this.test = new Instances(test); 
     196            Instances train; 
     197                Instances test; 
     198                 
     199                // used to sum up the matching values of all attributes 
     200                double p_sum = 0; 
     201                 
     202                // attribute matching, train -> test 
     203                HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>(); 
     204                //double[][] weights;  /* weight matrix, needed to find maximum weighted bipartite matching */ 
     205                  
     206                ArrayList<double[]> train_values; 
     207                ArrayList<double[]> test_values; 
     208 
     209                // todo: this constructor does not work 
     210                public MetricMatch() { 
     211                } 
     212                  
     213                public MetricMatch(Instances train, Instances test) { 
     214                    // expensive! but we are dropping the attributes so we have to copy all of the data 
     215                    this.train = new Instances(train); 
     216                        this.test = new Instances(test); 
    214217                          
    215                          // 1. convert metrics of testdata and traindata to later use in test 
    216                          this.train_values = new ArrayList<double[]>(); 
    217                          for (int i = 0; i < this.train.numAttributes()-1; i++) { 
    218                                 this.train_values.add(train.attributeToDoubleArray(i)); 
    219                          } 
     218                        // 1. convert metrics of testdata and traindata to later use in test 
     219                        this.train_values = new ArrayList<double[]>(); 
     220                        for (int i = 0; i < this.train.numAttributes()-1; i++) { 
     221                            this.train_values.add(train.attributeToDoubleArray(i)); 
     222                        } 
    220223                         
    221                          this.test_values = new ArrayList<double[]>(); 
    222                          for (int i=0; i < this.test.numAttributes()-1; i++) { 
    223                                 this.test_values.add(this.test.attributeToDoubleArray(i)); 
    224                          } 
    225                  } 
    226                   
    227                  /** 
    228                   * returns the number of matched attributes 
    229                   * as a way of scoring traindata sets individually 
    230                   *  
    231                   * @return 
    232                   */ 
    233                  public int getRank() { 
    234                          return this.attributes.size(); 
     224                        this.test_values = new ArrayList<double[]>(); 
     225                        for (int i=0; i < this.test.numAttributes()-1; i++) { 
     226                            this.test_values.add(this.test.attributeToDoubleArray(i)); 
     227                        } 
     228                } 
     229                  
     230                /** 
     231                 * We have a lot of matching possibilities. 
     232                 * Here we try to determine the best one. 
     233                 *  
     234                 * @return double matching score 
     235                 */ 
     236            public double getScore() { 
     237                int as = this.attributes.size(); 
     238                 
     239                // we use thresholding ranking approach for numInstances to influence the matching score 
     240                int instances = this.train.numInstances(); 
     241                int inst_rank = 0; 
     242                if(instances > 100) { 
     243                    inst_rank = 1; 
     244                } 
     245                if(instances > 500) { 
     246                inst_rank = 2; 
     247            } 
     248             
     249                return this.p_sum + as + inst_rank; 
     250            } 
     251                  
     252                 public HashMap<Integer, Integer> getAttributes() { 
     253                     return this.attributes; 
    235254                 } 
    236255                  
     
    260279                          
    261280                         ni.setClassValue(test.value(test.classAttribute())); 
    262                           
    263                          //System.out.println(ni); 
     281 
    264282                         return ni; 
    265283                 } 
     
    283301                 } 
    284302                  
     303                 // todo: there must be a better way 
    285304                 // https://weka.wikispaces.com/Programmatic+Use 
    286305                 private Instances getMatchedInstances(String name, Instances data) { 
    287                          // construct our new attributes 
     306                         //Console.traceln(Level.INFO, "Constructing instances from: " + name); 
     307                     // construct our new attributes 
    288308                         Attribute[] attrs = new Attribute[this.attributes.size()+1]; 
    289309                         FastVector fwTrain = new FastVector(this.attributes.size()); 
     
    301321                         newTrain.setClassIndex(newTrain.numAttributes()-1); 
    302322                          
     323                         //Console.traceln(Level.INFO, "data attributes: " + data.numAttributes() + ", this.attributes: "+this.attributes.size()); 
     324                          
    303325                         for (int i=0; i < data.size(); i++) { 
    304326                                 Instance ni = new DenseInstance(this.attributes.size()+1); 
     
    314336                                                 value = (int)values.getKey(); 
    315337                                         } 
     338                                         //Console.traceln(Level.INFO, "setting attribute " + j + " with data from instance: " + i); 
    316339                                         ni.setValue(newTrain.attribute(j), data.instance(i).value(value)); 
    317340                                         j++; 
     
    328351                /** 
    329352                 * performs the attribute selection 
    330                  * we perform attribute significance tests and drop attributes 
     353                 * we perform attribute significance tests and drop attributes  
     354                 *  
     355                 * attribute selection is only performed on the source dataset 
     356                 * we retain the top 15% attributes (if 15% is a float we just use the integer part) 
    331357                 */ 
    332358                public void attributeSelection() throws Exception { 
     
    335361                        //Console.traceln(Level.INFO, "-----"); 
    336362                        //Console.traceln(Level.INFO, "Attribute Selection on Test Attributes ("+this.test.numAttributes()+")"); 
    337                         this.attributeSelection(this.test); 
     363                        //this.attributeSelection(this.test); 
    338364                        //Console.traceln(Level.INFO, "-----"); 
    339365                } 
     
    358384                        HashMap<String, Double> sorted = sortByValues(saeval); 
    359385                         
    360                         // die letzen 15% wollen wir haben 
    361                         float last = ((float)saeval.size() / 100) * 15; 
     386                        // die besten 15% wollen wir haben 
     387                        double last = ((double)saeval.size() / 100.0) * 15.0; 
    362388                        int drop_first = saeval.size() - (int)last; 
    363389                         
     
    380406                        } 
    381407                        drop_first-=1; 
    382                   
    383                      
    384408                    } 
    385 //                  //Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)"); 
     409                    //Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)"); 
    386410                } 
    387411                 
     
    406430                } 
    407431                  
     432                 
     433         
     434                public void matchAttributes(String type, double cutoff) { 
     435                     
     436 
     437                    MWBMatchingAlgorithm mwbm = new MWBMatchingAlgorithm(this.train.numAttributes(), this.test.numAttributes()); 
     438                     
     439                    if (type.equals("spearman")) { 
     440                        this.spearmansRankCorrelation(cutoff, mwbm); 
     441                    }else if(type.equals("ks")) { 
     442                        this.kolmogorovSmirnovTest(cutoff, mwbm); 
     443                    }else if(type.equals("percentile")) { 
     444                        this.percentiles(cutoff, mwbm); 
     445                    }else { 
     446                        throw new RuntimeException("unknown matching method"); 
     447                    } 
     448                     
     449                    // resulting maximal match 
     450            int[] result = mwbm.getMatching(); 
     451            for( int i = 0; i < result.length; i++) { 
     452                 
     453                // -1 means that it is not in the set of maximal matching 
     454                if( i != -1 && result[i] != -1) { 
     455                    //Console.traceln(Level.INFO, "Found maximal bipartite match between: "+ i + " and " + result[i]); 
     456                    this.attributes.put(i, result[i]); 
     457                } 
     458            } 
     459        } 
     460         
     461         
    408462                /** 
    409463                 * Calculates the Percentiles of the source and target metrics. 
     
    411465                 * @param cutoff 
    412466                 */ 
    413                 public void percentiles(double cutoff) { 
    414                     for( int i = 0; i < this.train.numAttributes()-1; i++ ) { 
    415                 for( int j = 0; j < this.test.numAttributes()-1; j++ ) { 
     467                public void percentiles(double cutoff, MWBMatchingAlgorithm mwbm) { 
     468                    for( int i = 0; i < this.train.numAttributes(); i++ ) { 
     469                for( int j = 0; j < this.test.numAttributes(); j++ ) { 
     470                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 
     471                    // and the result of the mwbm computation 
     472                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 
     473                     
    416474                    // class attributes are not relevant  
    417                     if( this.train.classIndex() == i ) { 
     475                    if (this.test.classIndex() == j) { 
    418476                        continue; 
    419477                    } 
    420                     if( this.test.classIndex() == j ) { 
     478                    if (this.train.classIndex() == i) { 
    421479                        continue; 
    422480                    } 
     481 
     482                    // get percentiles 
     483                    double train[] = this.train_values.get(i); 
     484                    double test[] = this.test_values.get(j); 
    423485                     
     486                    Arrays.sort(train); 
     487                    Arrays.sort(test); 
    424488                     
    425                     if( !this.attributes.containsKey(i) ) { 
    426                         // get percentiles 
    427                         double train[] = this.train_values.get(i); 
    428                         double test[] = this.test_values.get(j); 
    429                          
    430                         Arrays.sort(train); 
    431                         Arrays.sort(test); 
    432                          
    433                         // percentiles 
    434                         double train_p; 
    435                         double test_p; 
    436                         double score = 0.0; 
    437                         for( double p=0.1; p < 1; p+=0.1 ) { 
    438                             train_p = train[(int)Math.ceil(train.length * p)]; 
    439                             test_p = test[(int)Math.ceil(test.length * p)]; 
    440                          
    441                             if( train_p > test_p ) { 
    442                                 score += test_p / train_p; 
    443                             }else { 
    444                                 score += train_p / test_p; 
    445                             } 
    446                         } 
    447                          
    448                         if( score > cutoff ) { 
    449                             this.attributes.put(i, j); 
     489                    // percentiles 
     490                    double train_p; 
     491                    double test_p; 
     492                    double score = 0.0; 
     493                    for( int p=1; p <= 9; p++ ) { 
     494                        train_p = train[(int)Math.ceil(train.length * (p/100))]; 
     495                        test_p = test[(int)Math.ceil(test.length * (p/100))]; 
     496                     
     497                        if( train_p > test_p ) { 
     498                            score += test_p / train_p; 
     499                        }else { 
     500                            score += train_p / test_p; 
    450501                        } 
    451502                    } 
     503                     
     504                    if( score > cutoff ) { 
     505                        this.p_sum += score; 
     506                        mwbm.setWeight(i, j, score); 
     507                    } 
    452508                } 
    453509            } 
     
    455511                  
    456512                 /** 
    457                   * calculate Spearmans rank correlation coefficient as matching score 
     513                  * Calculate Spearmans rank correlation coefficient as matching score 
     514                  * The number of instances for the source and target needs to be the same so we randomly sample from the bigger one. 
    458515                  *  
    459516                  * @param cutoff 
     517                  * @param mwbmatching 
    460518                  */ 
    461                  public void spearmansRankCorrelation(double cutoff) { 
    462                          double p = 0;                    
     519                 public void spearmansRankCorrelation(double cutoff, MWBMatchingAlgorithm mwbm) { 
     520                         double p = 0; 
     521 
    463522                         SpearmansCorrelation t = new SpearmansCorrelation(); 
    464523 
     
    469528                             this.sample(this.test, this.train, this.test_values); 
    470529                         } 
    471                           
    472                          // try out possible attribute combinations 
    473             for (int i=0; i < this.train.numAttributes()-1; i++) { 
    474                 for (int j=0; j < this.test.numAttributes()-1; j++) { 
     530                         
     531            // try out possible attribute combinations 
     532            for (int i=0; i < this.train.numAttributes(); i++) { 
     533 
     534                for (int j=0; j < this.test.numAttributes(); j++) { 
     535                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 
     536                    // and the result of the mwbm computation 
     537                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 
     538                     
    475539                    // class attributes are not relevant  
     540                    if (this.test.classIndex() == j) { 
     541                        continue; 
     542                    } 
    476543                    if (this.train.classIndex() == i) { 
    477544                        continue; 
    478545                    } 
    479                     if (this.test.classIndex() == j) { 
    480                         continue; 
    481                     } 
    482546                     
    483                      
    484                                         if (!this.attributes.containsKey(i)) { 
    485                                                 p = t.correlation(this.train_values.get(i), this.test_values.get(j)); 
    486                                                 if (p > cutoff) { 
    487                                                         this.attributes.put(i, j); 
    488                                                 } 
     547                                        p = t.correlation(this.train_values.get(i), this.test_values.get(j)); 
     548                                        if (p > cutoff) { 
     549                                            this.p_sum += p; 
     550                        mwbm.setWeight(i, j, p); 
     551                                            //Console.traceln(Level.INFO, "Found match: p-val: " + p); 
    489552                                        } 
    490553                                } 
    491554                    } 
    492         } 
    493  
    494                  
     555 
     556            //Console.traceln(Level.INFO, "Found " + this.attributes.size() + " matching attributes"); 
     557        } 
     558 
     559                  
    495560        public void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) { 
    496561            // we want to at keep the indices we select the same 
     
    535600                 * @return p-val 
    536601                 */ 
    537                 public void kolmogorovSmirnovTest(double cutoff) { 
     602                public void kolmogorovSmirnovTest(double cutoff, MWBMatchingAlgorithm mwbm) { 
    538603                        double p = 0; 
    539                          
     604             
    540605                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest(); 
    541606 
    542                         // todo: this should be symmetrical we don't have to compare i to j and then j to i  
    543                         // todo: this relies on the last attribute being the class,  
    544607                        //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")"); 
    545                         for (int i=0; i < this.train.numAttributes()-1; i++) { 
    546                                 for ( int j=0; j < this.test.numAttributes()-1; j++) { 
    547                                         //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 
    548                                         //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 
     608                        for (int i=0; i < this.train.numAttributes(); i++) { 
     609                                for ( int j=0; j < this.test.numAttributes(); j++) { 
     610                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka 
     611                    // and the result of the mwbm computation 
     612                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY); 
     613                     
    549614                    // class attributes are not relevant  
    550                     if ( this.train.classIndex() == i ) { 
     615                    if (this.test.classIndex() == j) { 
    551616                        continue; 
    552617                    } 
    553                     if ( this.test.classIndex() == j ) { 
     618                    if (this.train.classIndex() == i) { 
    554619                        continue; 
    555620                    } 
    556                                         // PRoblem: exactP is forced for small sample sizes and it never finishes 
    557                                         if (!this.attributes.containsKey(i)) { 
    558                                                  
    559                                                 // todo: output the values and complain on the math.commons mailinglist 
    560                                                 p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 
    561                                                 if (p > cutoff) { 
    562                                                         this.attributes.put(i, j); 
    563                                                 } 
     621                     
     622                    // this may invoke exactP on small sample sizes which will not terminate in all cases 
     623                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j), false); 
     624                                        p = t.approximateP(t.kolmogorovSmirnovStatistic(this.train_values.get(i), this.test_values.get(j)), this.train_values.get(i).length, this.test_values.get(j).length); 
     625                                        if (p > cutoff) { 
     626                        this.p_sum += p; 
     627                        mwbm.setWeight(i, j, p); 
    564628                                        } 
    565629                                } 
    566630                        } 
    567  
    568                         //Console.traceln(Level.INFO, "Found " + this.attmatch.size() + " matching attributes"); 
    569                 } 
    570          } 
     631                        //Console.traceln(Level.INFO, "Found " + this.attributes.size() + " matching attributes"); 
     632            } 
     633    } 
     634 
     635    /* 
     636     * Copyright (c) 2007, Massachusetts Institute of Technology 
     637     * Copyright (c) 2005-2006, Regents of the University of California 
     638     * All rights reserved. 
     639     *  
     640     * Redistribution and use in source and binary forms, with or without 
     641     * modification, are permitted provided that the following conditions 
     642     * are met: 
     643     * 
     644     * * Redistributions of source code must retain the above copyright 
     645     *   notice, this list of conditions and the following disclaimer. 
     646     * 
     647     * * Redistributions in binary form must reproduce the above copyright 
     648     *   notice, this list of conditions and the following disclaimer in 
     649     *   the documentation and/or other materials provided with the 
     650     *   distribution.   
     651     * 
     652     * * Neither the name of the University of California, Berkeley nor 
     653     *   the names of its contributors may be used to endorse or promote 
     654     *   products derived from this software without specific prior  
     655     *   written permission. 
     656     * 
     657     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
     658     * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
     659     * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 
     660     * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 
     661     * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 
     662     * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
     663     * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
     664     * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
     665     * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 
     666     * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
     667     * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 
     668     * OF THE POSSIBILITY OF SUCH DAMAGE. 
     669     */ 
     670 
     671 
     672 
     673    /** 
     674     * An engine for finding the maximum-weight matching in a complete 
     675     * bipartite graph.  Suppose we have two sets <i>S</i> and <i>T</i>, 
     676     * both of size <i>n</i>.  For each <i>i</i> in <i>S</i> and <i>j</i> 
     677     * in <i>T</i>, we have a weight <i>w<sub>ij</sub></i>.  A perfect 
     678     * matching <i>X</i> is a subset of <i>S</i> x <i>T</i> such that each 
     679     * <i>i</i> in <i>S</i> occurs in exactly one element of <i>X</i>, and 
     680     * each <i>j</i> in <i>T</i> occurs in exactly one element of 
     681     * <i>X</i>.  Thus, <i>X</i> can be thought of as a one-to-one 
     682     * function from <i>S</i> to <i>T</i>.  The weight of <i>X</i> is the 
     683     * sum, over (<i>i</i>, <i>j</i>) in <i>X</i>, of 
     684     * <i>w<sub>ij</sub></i>.  A BipartiteMatcher takes the number 
     685     * <i>n</i> and the weights <i>w<sub>ij</sub></i>, and finds a perfect 
     686     * matching of maximum weight. 
     687     * 
     688     * It uses the Hungarian algorithm of Kuhn (1955), as improved and 
     689     * presented by E. L. Lawler in his book <cite>Combinatorial 
     690     * Optimization: Networks and Matroids</cite> (Holt, Rinehart and 
     691     * Winston, 1976, p. 205-206).  The running time is 
     692     * O(<i>n</i><sup>3</sup>).  The weights can be any finite real 
     693     * numbers; Lawler's algorithm assumes positive weights, so if 
     694     * necessary we add a constant <i>c</i> to all the weights before 
     695     * running the algorithm.  This increases the weight of every perfect 
     696     * matching by <i>nc</i>, which doesn't change which perfect matchings 
     697     * have maximum weight. 
     698     * 
     699     * If a weight is set to Double.NEGATIVE_INFINITY, then the algorithm will  
     700     * behave as if that edge were not in the graph.  If all the edges incident on  
     701     * a given node have weight Double.NEGATIVE_INFINITY, then the final result  
     702     * will not be a perfect matching, and an exception will be thrown.   
     703     */ 
     704     class MWBMatchingAlgorithm { 
     705        /** 
     706         * Creates a BipartiteMatcher without specifying the graph size.  Calling  
     707         * any other method before calling reset will yield an  
     708         * IllegalStateException. 
     709         */ 
     710         
     711         /** 
     712         * Tolerance for comparisons to zero, to account for 
     713         * floating-point imprecision.  We consider a positive number to 
     714         * be essentially zero if it is strictly less than TOL. 
     715         */ 
     716        private static final double TOL = 1e-10; 
     717        //Number of left side nodes 
     718        int n; 
     719 
     720        //Number of right side nodes 
     721        int m; 
     722 
     723        double[][] weights; 
     724        double minWeight; 
     725        double maxWeight; 
     726 
     727        // If (i, j) is in the mapping, then sMatches[i] = j and tMatches[j] = i.   
     728        // If i is unmatched, then sMatches[i] = -1 (and likewise for tMatches).  
     729        int[] sMatches; 
     730        int[] tMatches; 
     731 
     732        static final int NO_LABEL = -1; 
     733        static final int EMPTY_LABEL = -2; 
     734 
     735        int[] sLabels; 
     736        int[] tLabels; 
     737 
     738        double[] u; 
     739        double[] v; 
     740         
     741        double[] pi; 
     742 
     743        List<Integer> eligibleS = new ArrayList<Integer>(); 
     744        List<Integer> eligibleT = new ArrayList<Integer>();  
     745         
     746         
     747        public MWBMatchingAlgorithm() { 
     748        n = -1; 
     749        m = -1; 
     750        } 
     751 
     752        /** 
     753         * Creates a BipartiteMatcher and prepares it to run on an n x m graph.   
     754         * All the weights are initially set to 1.   
     755         */ 
     756        public MWBMatchingAlgorithm(int n, int m) { 
     757        reset(n, m); 
     758        } 
     759 
     760        /** 
     761         * Resets the BipartiteMatcher to run on an n x m graph.  The weights are  
     762         * all reset to 1. 
     763         */ 
     764        private void reset(int n, int m) { 
     765            if (n < 0 || m < 0) { 
     766                throw new IllegalArgumentException("Negative num nodes: " + n + " or " + m); 
     767            } 
     768            this.n = n; 
     769            this.m = m; 
     770 
     771            weights = new double[n][m]; 
     772            for (int i = 0; i < n; i++) { 
     773                for (int j = 0; j < m; j++) { 
     774                weights[i][j] = 1; 
     775                } 
     776            } 
     777            minWeight = 1; 
     778            maxWeight = Double.NEGATIVE_INFINITY; 
     779 
     780            sMatches = new int[n]; 
     781            tMatches = new int[m]; 
     782            sLabels = new int[n]; 
     783            tLabels = new int[m]; 
     784            u = new double[n]; 
     785            v = new double[m]; 
     786            pi = new double[m]; 
     787             
     788        } 
     789        /** 
     790         * Sets the weight w<sub>ij</sub> to the given value w.  
     791         * 
     792         * @throws IllegalArgumentException if i or j is outside the range [0, n). 
     793         */ 
     794        public void setWeight(int i, int j, double w) { 
     795        if (n == -1 || m == -1) { 
     796            throw new IllegalStateException("Graph size not specified."); 
     797        } 
     798        if ((i < 0) || (i >= n)) { 
     799            throw new IllegalArgumentException("i-value out of range: " + i); 
     800        } 
     801        if ((j < 0) || (j >= m)) { 
     802            throw new IllegalArgumentException("j-value out of range: " + j); 
     803        } 
     804        if (Double.isNaN(w)) { 
     805            throw new IllegalArgumentException("Illegal weight: " + w); 
     806        } 
     807 
     808        weights[i][j] = w; 
     809        if ((w > Double.NEGATIVE_INFINITY) && (w < minWeight)) { 
     810            minWeight = w; 
     811        } 
     812        if (w > maxWeight) { 
     813            maxWeight = w; 
     814        } 
     815        } 
     816 
     817        /** 
     818         * Returns a maximum-weight perfect matching relative to the weights  
     819         * specified with setWeight.  The matching is represented as an array arr  
     820         * of length n, where arr[i] = j if (i,j) is in the matching. 
     821         */ 
     822        public int[] getMatching() { 
     823        if (n == -1 || m == -1 ) { 
     824            throw new IllegalStateException("Graph size not specified."); 
     825        } 
     826        if (n == 0) { 
     827            return new int[0]; 
     828        } 
     829        ensurePositiveWeights(); 
     830 
     831        // Step 0: Initialization 
     832        eligibleS.clear(); 
     833        eligibleT.clear(); 
     834        for (Integer i = 0; i < n; i++) { 
     835            sMatches[i] = -1; 
     836 
     837            u[i] = maxWeight; // ambiguous on p. 205 of Lawler, but see p. 202 
     838 
     839            // this is really first run of Step 1.0 
     840            sLabels[i] = EMPTY_LABEL;  
     841            eligibleS.add(i); 
     842        } 
     843 
     844        for (int j = 0; j < m; j++) { 
     845            tMatches[j] = -1; 
     846 
     847            v[j] = 0; 
     848            pi[j] = Double.POSITIVE_INFINITY; 
     849 
     850            // this is really first run of Step 1.0 
     851            tLabels[j] = NO_LABEL; 
     852        } 
     853         
     854        while (true) { 
     855            // Augment the matching until we can't augment any more given the  
     856            // current settings of the dual variables.   
     857            while (true) { 
     858            // Steps 1.1-1.4: Find an augmenting path 
     859            int lastNode = findAugmentingPath(); 
     860            if (lastNode == -1) { 
     861                break; // no augmenting path 
     862            } 
     863                     
     864            // Step 2: Augmentation 
     865            flipPath(lastNode); 
     866            for (int i = 0; i < n; i++) 
     867                sLabels[i] = NO_LABEL; 
     868             
     869            for (int j = 0; j < m; j++) { 
     870                pi[j] = Double.POSITIVE_INFINITY; 
     871                tLabels[j] = NO_LABEL; 
     872            } 
     873             
     874             
     875             
     876            // This is Step 1.0 
     877            eligibleS.clear(); 
     878            for (int i = 0; i < n; i++) { 
     879                if (sMatches[i] == -1) { 
     880                sLabels[i] = EMPTY_LABEL; 
     881                eligibleS.add(new Integer(i)); 
     882                } 
     883            } 
     884 
     885             
     886            eligibleT.clear(); 
     887            } 
     888 
     889            // Step 3: Change the dual variables 
     890 
     891            // delta1 = min_i u[i] 
     892            double delta1 = Double.POSITIVE_INFINITY; 
     893            for (int i = 0; i < n; i++) { 
     894            if (u[i] < delta1) { 
     895                delta1 = u[i]; 
     896            } 
     897            } 
     898 
     899            // delta2 = min_{j : pi[j] > 0} pi[j] 
     900            double delta2 = Double.POSITIVE_INFINITY; 
     901            for (int j = 0; j < m; j++) { 
     902            if ((pi[j] >= TOL) && (pi[j] < delta2)) { 
     903                delta2 = pi[j]; 
     904            } 
     905            } 
     906 
     907            if (delta1 < delta2) { 
     908            // In order to make another pi[j] equal 0, we'd need to  
     909            // make some u[i] negative.   
     910            break; // we have a maximum-weight matching 
     911            } 
     912                 
     913            changeDualVars(delta2); 
     914        } 
     915 
     916        int[] matching = new int[n]; 
     917        for (int i = 0; i < n; i++) { 
     918            matching[i] = sMatches[i]; 
     919        } 
     920        return matching; 
     921        } 
     922 
     923        /** 
     924         * Tries to find an augmenting path containing only edges (i,j) for which  
     925         * u[i] + v[j] = weights[i][j].  If it succeeds, returns the index of the  
     926         * last node in the path.  Otherwise, returns -1.  In any case, updates  
     927         * the labels and pi values. 
     928         */ 
     929        int findAugmentingPath() { 
     930        while ((!eligibleS.isEmpty()) || (!eligibleT.isEmpty())) { 
     931            if (!eligibleS.isEmpty()) { 
     932            int i = ((Integer) eligibleS.get(eligibleS.size() - 1)). 
     933                intValue(); 
     934            eligibleS.remove(eligibleS.size() - 1); 
     935            for (int j = 0; j < m; j++) { 
     936                // If pi[j] has already been decreased essentially 
     937                // to zero, then j is already labeled, and we 
     938                // can't decrease pi[j] any more.  Omitting the  
     939                // pi[j] >= TOL check could lead us to relabel j 
     940                // unnecessarily, since the diff we compute on the 
     941                // next line may end up being less than pi[j] due 
     942                // to floating point imprecision. 
     943                if ((tMatches[j] != i) && (pi[j] >= TOL)) { 
     944                double diff = u[i] + v[j] - weights[i][j]; 
     945                if (diff < pi[j]) { 
     946                    tLabels[j] = i; 
     947                    pi[j] = diff; 
     948                    if (pi[j] < TOL) { 
     949                    eligibleT.add(new Integer(j)); 
     950                    } 
     951                } 
     952                } 
     953            } 
     954            } else { 
     955            int j = ((Integer) eligibleT.get(eligibleT.size() - 1)). 
     956                intValue(); 
     957            eligibleT.remove(eligibleT.size() - 1); 
     958            if (tMatches[j] == -1) { 
     959                return j; // we've found an augmenting path 
     960            }  
     961 
     962            int i = tMatches[j]; 
     963            sLabels[i] = j; 
     964            eligibleS.add(new Integer(i)); // ok to add twice 
     965            } 
     966        } 
     967 
     968        return -1; 
     969        } 
     970 
     971        /** 
     972         * Given an augmenting path ending at lastNode, "flips" the path.  This  
     973         * means that an edge on the path is in the matching after the flip if  
     974         * and only if it was not in the matching before the flip.  An augmenting  
     975         * path connects two unmatched nodes, so the result is still a matching.  
     976         */  
     977        void flipPath(int lastNode) { 
     978            while (lastNode != EMPTY_LABEL) { 
     979                int parent = tLabels[lastNode]; 
     980     
     981                // Add (parent, lastNode) to matching.  We don't need to  
     982                // explicitly remove any edges from the matching because:  
     983                //  * We know at this point that there is no i such that  
     984                //    sMatches[i] = lastNode.   
     985                //  * Although there might be some j such that tMatches[j] = 
     986                //    parent, that j must be sLabels[parent], and will change  
     987                //    tMatches[j] in the next time through this loop.   
     988                sMatches[parent] = lastNode; 
     989                tMatches[lastNode] = parent; 
     990                             
     991                lastNode = sLabels[parent]; 
     992            } 
     993        } 
     994 
     995        void changeDualVars(double delta) { 
     996            for (int i = 0; i < n; i++) { 
     997                if (sLabels[i] != NO_LABEL) { 
     998                u[i] -= delta; 
     999                } 
     1000            } 
     1001                 
     1002            for (int j = 0; j < m; j++) { 
     1003                if (pi[j] < TOL) { 
     1004                v[j] += delta; 
     1005                } else if (tLabels[j] != NO_LABEL) { 
     1006                pi[j] -= delta; 
     1007                if (pi[j] < TOL) { 
     1008                    eligibleT.add(new Integer(j)); 
     1009                } 
     1010                } 
     1011            } 
     1012        } 
     1013 
     1014        /** 
     1015         * Ensures that all weights are either Double.NEGATIVE_INFINITY,  
     1016         * or strictly greater than zero. 
     1017         */ 
     1018        private void ensurePositiveWeights() { 
     1019            // minWeight is the minimum non-infinite weight 
     1020            if (minWeight < TOL) { 
     1021                for (int i = 0; i < n; i++) { 
     1022                for (int j = 0; j < m; j++) { 
     1023                    weights[i][j] = weights[i][j] - minWeight + 1; 
     1024                } 
     1025                } 
     1026     
     1027                maxWeight = maxWeight - minWeight + 1; 
     1028                minWeight = 1; 
     1029            } 
     1030        } 
     1031 
     1032        @SuppressWarnings("unused") 
     1033        private void printWeights() { 
     1034            for (int i = 0; i < n; i++) { 
     1035                for (int j = 0; j < m; j++) { 
     1036                System.out.print(weights[i][j] + " "); 
     1037                } 
     1038                System.out.println(""); 
     1039            } 
     1040        } 
     1041    } 
    5711042} 
Note: See TracChangeset for help on using the changeset viewer.