source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java @ 140

Last change on this file since 140 was 140, checked in by atrautsch, 8 years ago

More cleanup and comments

File size: 39.1 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.Comparator;
21import java.util.HashMap;
22import java.util.Iterator;
23import java.util.LinkedHashMap;
24import java.util.LinkedList;
25import java.util.List;
26import java.util.Map;
27import java.util.Map.Entry;
28import java.util.logging.Level;
29
30import java.util.Random;
31
32import org.apache.commons.collections4.list.SetUniqueList;
33import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
34import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
35
36import de.ugoe.cs.util.console.Console;
37import weka.attributeSelection.SignificanceAttributeEval;
38import weka.classifiers.AbstractClassifier;
39import weka.classifiers.Classifier;
40import weka.core.Attribute;
41import weka.core.DenseInstance;
42import weka.core.Instance;
43import weka.core.Instances;
44
45/**
46 * Implements Heterogenous Defect Prediction after Nam et al. 2015.
47 *
48 * We extend WekaBaseTraining because we have to Wrap the Classifier to use MetricMatching.
49 * This also means we can use any Weka Classifier not just LogisticRegression.
50 *
51 * Config: <setwisetestdataawaretrainer name="MetricMatchingTraining" param="Logistic weka.classifiers.functions.Logistic" threshold="0.05" method="spearman"/>
52 * Instead of spearman metchod it also takes ks, percentile.
53 * Instead of Logistic every other weka classifier can be chosen.
54 *
55 * Future work:
56 * implement chisquare test in addition to significance for attribute selection
57 * http://commons.apache.org/proper/commons-math/apidocs/org/apache/commons/math3/stat/inference/ChiSquareTest.html
58 * use chiSquareTestDataSetsComparison
59 */
60public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy {
61
62    private MetricMatch mm = null;
63    private Classifier classifier = null;
64   
65    private String method;
66    private float threshold;
67   
68    /**
69     * We wrap the classifier here because of classifyInstance with our MetricMatchingClassfier
70     * @return
71     */
72    @Override
73    public Classifier getClassifier() {
74        return this.classifier;
75    }
76
77    /**
78     * Set similarity measure method.
79     */
80    @Override
81    public void setMethod(String method) {
82        this.method = method;
83    }
84
85    /**
86     * Set threshold for similarity measure.
87     */
88    @Override
89    public void setThreshold(String threshold) {
90        this.threshold = Float.parseFloat(threshold);
91    }
92
93        /**
94         * We need the test data instances to do a metric matching, so in this special case we get this data
95         * before evaluation.
96         */
97        @Override
98        public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
99            // reset these for each run
100            this.mm = null;
101            this.classifier = null;
102           
103                double score = 0; // matching score to select the best matching training data from the set
104                int num = 0;
105                int biggest_num = 0;
106                MetricMatch tmp;
107                for (Instances traindata : traindataSet) {
108                        num++;
109
110                        tmp = new MetricMatch(traindata, testdata);
111
112                        // metric selection may create error, continue to next training set
113                        try {
114                                tmp.attributeSelection();
115                                tmp.matchAttributes(this.method, this.threshold);
116                        }catch(Exception e) {
117                                e.printStackTrace();
118                                throw new RuntimeException(e);
119                        }
120                       
121                        // we only select the training data from our set with the most matching attributes
122                        if (tmp.getScore() > score && tmp.attributes.size() > 0) {
123                                score = tmp.getScore();
124                                this.mm = tmp;
125                                biggest_num = num;
126                        }
127                }
128               
129                // if we have found a matching instance we use it, log information about the match for additional eval later
130                Instances ilist = null;
131                if (this.mm != null) {
132                ilist = this.mm.getMatchedTrain();
133                Console.traceln(Level.INFO, "[MATCH FOUND] match: ["+biggest_num +"], score: [" + score + "], instances: [" + ilist.size() + "], attributes: [" + this.mm.attributes.size() + "], ilist attrs: ["+ilist.numAttributes()+"]");
134                for(Map.Entry<Integer, Integer> attmatch : this.mm.attributes.entrySet()) {
135                    Console.traceln(Level.INFO, "[MATCHED ATTRIBUTE] source attribute: [" + this.mm.train.attribute(attmatch.getKey()).name() + "], target attribute: [" + this.mm.test.attribute(attmatch.getValue()).name() + "]");
136                }
137            }else {
138                Console.traceln(Level.INFO, "[NO MATCH FOUND]");
139            }
140               
141                // if we have a match we build the MetricMatchingClassifier, if not we fall back to FixClass Classifier
142                try {
143                        if(this.mm != null) {
144                            this.classifier = new MetricMatchingClassifier();
145                            this.classifier.buildClassifier(ilist);
146                            ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
147                        }else {
148                            this.classifier = new FixClass();
149                            this.classifier.buildClassifier(ilist);  // this is null, but the FixClass Classifier does not use it anyway
150                        }
151                }catch(Exception e) {
152                        e.printStackTrace();
153                        throw new RuntimeException(e);
154                }
155        }
156
157       
158        /**
159         * Encapsulates the classifier configured with WekaBase within but use metric matching.
160         * This allows us to use any Weka classifier with Heterogenous Defect Prediction.
161         */
162        public class MetricMatchingClassifier extends AbstractClassifier {
163
164                private static final long serialVersionUID = -1342172153473770935L;
165                private MetricMatch mm;
166                private Classifier classifier;
167               
168                @Override
169                public void buildClassifier(Instances traindata) throws Exception {
170                        this.classifier = setupClassifier();
171                        this.classifier.buildClassifier(traindata);
172                }
173               
174                /**
175                 * Sets the MetricMatch instance so that we can use matched test data later.
176                 * @param mm
177                 */
178                public void setMetricMatching(MetricMatch mm) {
179                        this.mm = mm;
180                }
181               
182                /**
183                 * Here we can not do the metric matching because we only get one instance.
184                 * Therefore we need a MetricMatch instance beforehand to use here.
185                 */
186                public double classifyInstance(Instance testdata) {
187                    // get a copy of testdata Instance with only the matched attributes
188                        Instance ntest = this.mm.getMatchedTestInstance(testdata);
189
190                        double ret = 0.0;
191                        try {
192                                ret = this.classifier.classifyInstance(ntest);
193                        }catch(Exception e) {
194                                e.printStackTrace();
195                                throw new RuntimeException(e);
196                        }
197                       
198                        return ret;
199                }
200        }
201       
202        /**
203         * Encapsulates one MetricMatching process.
204         * One source (train) matches against one target (test).
205         */
206    public class MetricMatch {
207            Instances train;
208                Instances test;
209               
210                // used to sum up the matching values of all attributes
211                protected double p_sum = 0;
212               
213                // attribute matching, train -> test
214                HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>();
215               
216                // used for similarity tests
217                protected ArrayList<double[]> train_values;
218                protected ArrayList<double[]> test_values;
219
220               
221                public MetricMatch(Instances train, Instances test) {
222                        // this is expensive but we need to keep the original data intact
223                    this.train = this.deepCopy(train);
224                        this.test = test; // we do not need a copy here because we do not drop attributes before the matching and after the matching we create a new Instances with only the matched attributes
225                       
226                        // convert metrics of testdata and traindata to later use in similarity tests
227                        this.train_values = new ArrayList<double[]>();
228                        for (int i = 0; i < this.train.numAttributes(); i++) {
229                            if(this.train.classIndex() != i) {
230                                this.train_values.add(this.train.attributeToDoubleArray(i));
231                            }
232                        }
233                       
234                        this.test_values = new ArrayList<double[]>();
235                        for (int i=0; i < this.test.numAttributes(); i++) {
236                            if(this.test.classIndex() != i) {
237                                this.test_values.add(this.test.attributeToDoubleArray(i));
238                            }
239                        }
240                }
241                 
242                /**
243                 * We have a lot of matching possibilities.
244                 * Here we try to determine the best one.
245                 *
246                 * @return double matching score
247                 */
248            public double getScore() {
249                int as = this.attributes.size();  // # of attributes that were matched
250               
251                // we use thresholding ranking approach for numInstances to influence the matching score
252                int instances = this.train.numInstances();
253                int inst_rank = 0;
254                if(instances > 100) {
255                    inst_rank = 1;
256                }
257                if(instances > 500) {
258                inst_rank = 2;
259            }
260           
261                return this.p_sum + as + inst_rank;
262            }
263                 
264                public HashMap<Integer, Integer> getAttributes() {
265                    return this.attributes;
266                }
267                 
268                public int getNumInstances() {
269                    return this.train_values.get(0).length;
270                }
271
272               
273                /**
274                 * The test instance must be of the same dataset as the train data, otherwise WekaEvaluation will die.
275                 * This means we have to force the dataset of this.train (after matching) and only
276                 * set the values for the attributes we matched but with the index of the traindata attributes we matched.
277                 *
278                 * @param test
279                 * @return
280                 */
281                public Instance getMatchedTestInstance(Instance test) {
282            Instance ni = new DenseInstance(this.attributes.size()+1);
283           
284            Instances inst = this.getMatchedTrain();
285           
286            ni.setDataset(inst);
287           
288            // assign only the matched attributes to new indexes
289            double val;
290            int k = 0;
291            for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
292                // get value from matched attribute
293                val = test.value(attmatch.getValue());
294               
295                // set it to new index, the order of the attributes is the same
296                ni.setValue(k, val);
297                k++;
298            }
299            ni.setClassValue(test.value(test.classAttribute()));
300
301            return ni;
302                }
303
304               
305        /**
306         * returns a new instances array with the metric matched training data
307         *
308         * @return instances
309         */
310                public Instances getMatchedTrain() {
311                    return this.getMatchedInstances("train", this.train);
312                }
313                 
314        /**
315                 * returns a new instances array with the metric matched test data
316                 *
317                 * @return instances
318                 */
319                public Instances getMatchedTest() {
320                    return this.getMatchedInstances("test", this.test);
321                }
322               
323                /**
324                 * We could drop unmatched attributes from our instances datasets.
325                 * Alas, that would not be nice for the following postprocessing jobs and would not work at all for evaluation.
326                 * We keep this as a warning for future generations.
327                 *
328                 * @param name
329                 * @param data
330                 */
331                @SuppressWarnings("unused")
332        private void dropUnmatched(String name, Instances data) {
333                    for(int i = 0; i < data.numAttributes(); i++) {
334                        if(data.classIndex() == i) {
335                            continue;
336                        }
337                       
338                        if(name.equals("train") && !this.attributes.containsKey(i)) {
339                            data.deleteAttributeAt(i);
340                        }
341                       
342                        if(name.equals("test") && !this.attributes.containsValue(i)) {
343                            data.deleteAttributeAt(i);
344                        }
345                    }
346                }
347
348        /**
349         * Deep Copy (well, reasonably deep, not sure about header information of attributes) Weka Instances.
350         *
351         * @param data Instances
352         * @return copy of Instances passed
353         */
354                private Instances deepCopy(Instances data) {
355                    Instances newInst = new Instances(data);
356                   
357                    newInst.clear();
358                   
359            for (int i=0; i < data.size(); i++) {
360                Instance ni = new DenseInstance(data.numAttributes());
361                for(int j = 0; j < data.numAttributes(); j++) {
362                    ni.setValue(newInst.attribute(j), data.instance(i).value(data.attribute(j)));
363                }
364                newInst.add(ni);
365            }
366           
367            return newInst;
368                }
369
370        /**
371         * Returns a deep copy of passed Instances data for Train or Test data.
372         * It only keeps attributes that have been matched.
373         *
374         * @param name
375         * @param data
376         * @return matched Instances
377         */
378                private Instances getMatchedInstances(String name, Instances data) {
379                    ArrayList<Attribute> attrs = new ArrayList<Attribute>();
380           
381                    // bug attr is a string, really!
382                    ArrayList<String> bug = new ArrayList<String>();
383            bug.add("0");
384            bug.add("1");
385           
386            // add our matched attributes and last the bug
387                    for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
388                        attrs.add(new Attribute(String.valueOf(attmatch.getValue())));
389                    }
390                    attrs.add(new Attribute("bug", bug));
391                   
392                    // create new instances object of the same size (at least for instances)
393                    Instances newInst = new Instances(name, attrs, data.size());
394                   
395                    // set last as class
396                    newInst.setClassIndex(newInst.numAttributes()-1);
397                   
398                    // copy data for matched attributes, this depends if we return train or test data
399            for (int i=0; i < data.size(); i++) {
400                Instance ni = new DenseInstance(this.attributes.size()+1);
401               
402                int j = 0; // new indices!
403                for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
404 
405                    // test attribute match
406                    int value = attmatch.getValue();
407                   
408                    // train attribute match
409                    if(name.equals("train")) {
410                        value = attmatch.getKey();
411                    }
412                   
413                    ni.setValue(newInst.attribute(j), data.instance(i).value(value));
414                    j++;
415                }
416                ni.setValue(ni.numAttributes()-1, data.instance(i).value(data.classAttribute()));
417                newInst.add(ni);
418            }
419
420            return newInst;
421                }
422                 
423                /**
424                 * performs the attribute selection
425                 * we perform attribute significance tests and drop attributes
426                 *
427                 * attribute selection is only performed on the source dataset
428                 * we retain the top 15% attributes (if 15% is a float we just use the integer part)
429                 */
430                public void attributeSelection() throws Exception {
431                   
432                    // it is a wrapper, we may decide to implement ChiSquare or other means of selecting attributes
433                        this.attributeSelectionBySignificance(this.train);
434                }
435               
436                private void attributeSelectionBySignificance(Instances which) throws Exception {
437                        // Uses: http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html
438                        SignificanceAttributeEval et = new SignificanceAttributeEval();
439                        et.buildEvaluator(which);
440                       
441                        // evaluate all training attributes
442                        HashMap<String,Double> saeval = new HashMap<String,Double>();
443                        for(int i=0; i < which.numAttributes(); i++) {
444                                if(which.classIndex() != i) {
445                                        saeval.put(which.attribute(i).name(), et.evaluateAttribute(i));
446                                }
447                        }
448                       
449                        // sort by significance
450                        HashMap<String, Double> sorted = (HashMap<String, Double>) sortByValues(saeval);
451                       
452                        // Keep the best 15%
453                        double last = ((double)saeval.size() / 100.0) * 15.0;
454                        int drop_first = saeval.size() - (int)last;
455                       
456                        // drop attributes above last
457                        Iterator<Entry<String, Double>> it = sorted.entrySet().iterator();
458                    while (drop_first > 0) {
459                        Map.Entry<String, Double> pair = (Map.Entry<String, Double>)it.next();
460                        if(which.attribute((String)pair.getKey()).index() != which.classIndex()) {
461                                which.deleteAttributeAt(which.attribute((String)pair.getKey()).index());
462                        }
463                        drop_first-=1;
464                    }
465                }
466               
467                /**
468                 * Helper method to sort a hashmap by its values.
469                 *
470                 * @param map
471                 * @return sorted map
472                 */
473                private HashMap<String, Double> sortByValues(HashMap<String, Double> map) {
474               List<Map.Entry<String, Double>> list = new LinkedList<Map.Entry<String, Double>>(map.entrySet());
475
476               Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
477                    public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
478                       return (o1.getValue()).compareTo( o2.getValue() );
479                    }
480               });
481
482               HashMap<String, Double> sortedHashMap = new LinkedHashMap<String, Double>();
483               for(Map.Entry<String, Double> item : list) {
484                   sortedHashMap.put(item.getKey(), item.getValue());
485               }
486               return sortedHashMap;
487                }
488               
489               
490        /**
491         * Executes the similarity matching between train and test data.
492         *
493         * After this function is finished we have this.attributes with the correct matching between train and test data attributes.
494         *
495         * @param type
496         * @param cutoff
497         */
498                public void matchAttributes(String type, double cutoff) {
499                   
500                    MWBMatchingAlgorithm mwbm = new MWBMatchingAlgorithm(this.train.numAttributes(), this.test.numAttributes());
501                   
502                    if (type.equals("spearman")) {
503                        this.spearmansRankCorrelation(cutoff, mwbm);
504                    }else if(type.equals("ks")) {
505                        this.kolmogorovSmirnovTest(cutoff, mwbm);
506                    }else if(type.equals("percentile")) {
507                        this.percentiles(cutoff, mwbm);
508                    }else {
509                        throw new RuntimeException("unknown matching method");
510                    }
511                   
512                    // resulting maximal match gets assigned to this.attributes
513            int[] result = mwbm.getMatching();
514            for( int i = 0; i < result.length; i++) {
515               
516                // -1 means that it is not in the set of maximal matching
517                if( i != -1 && result[i] != -1) {
518                    this.p_sum += mwbm.weights[i][result[i]];  // we add the weight of the returned matching for scoring the complete match later
519                    this.attributes.put(i, result[i]);
520                }
521            }
522        }
523       
524       
525                /**
526                 * Calculates the Percentiles of the source and target metrics.
527                 *
528                 * @param cutoff
529                 */
530                public void percentiles(double cutoff, MWBMatchingAlgorithm mwbm) {
531                    for( int i = 0; i < this.train.numAttributes(); i++ ) {
532                for( int j = 0; j < this.test.numAttributes(); j++ ) {
533                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexes in weka
534                    // and the result of the mwbm computation
535                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
536                   
537                    // class attributes are not relevant
538                    if (this.test.classIndex() == j) {
539                        continue;
540                    }
541                    if (this.train.classIndex() == i) {
542                        continue;
543                    }
544
545                    // get percentiles
546                    double train[] = this.train_values.get(i);
547                    double test[] = this.test_values.get(j);
548                   
549                    Arrays.sort(train);
550                    Arrays.sort(test);
551                   
552                    // percentiles
553                    double train_p;
554                    double test_p;
555                    double score = 0.0;
556                    for( int p=1; p <= 9; p++ ) {
557                        train_p = train[(int)Math.ceil(train.length * (p/100))];
558                        test_p = test[(int)Math.ceil(test.length * (p/100))];
559                   
560                        if( train_p > test_p ) {
561                            score += test_p / train_p;
562                        }else {
563                            score += train_p / test_p;
564                        }
565                    }
566                   
567                    if( score > cutoff ) {
568                        mwbm.setWeight(i, j, score);
569                    }
570                }
571            }
572                }
573                 
574                /**
575                 * Calculate Spearmans rank correlation coefficient as matching score.
576                 * The number of instances for the source and target needs to be the same so we randomly sample from the bigger one.
577                 *
578                 * @param cutoff
579                 * @param mwbmatching
580                 */
581                public void spearmansRankCorrelation(double cutoff, MWBMatchingAlgorithm mwbm) {
582                    double p = 0;
583
584                        SpearmansCorrelation t = new SpearmansCorrelation();
585
586                        // size has to be the same so we randomly sample the number of the smaller sample from the big sample
587                        if (this.train.size() > this.test.size()) {
588                            this.sample(this.train, this.test, this.train_values);
589                        }else if (this.test.size() > this.train.size()) {
590                            this.sample(this.test, this.train, this.test_values);
591                        }
592                       
593            // try out possible attribute combinations
594            for (int i=0; i < this.train.numAttributes(); i++) {
595                for (int j=0; j < this.test.numAttributes(); j++) {
596                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka
597                    // and the result of the mwbm computation
598                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
599                   
600                    // class attributes are not relevant
601                    if (this.test.classIndex() == j) {
602                        continue;
603                    }
604                    if (this.train.classIndex() == i) {
605                        continue;
606                    }
607                   
608                                        p = t.correlation(this.train_values.get(i), this.test_values.get(j));
609                                        if (p > cutoff) {
610                        mwbm.setWeight(i, j, p);
611                                        }
612                                }
613                    }
614        }
615
616                /**
617                 * Helper method to sample instances for the Spearman rank correlation coefficient method.
618                 *
619                 * @param bigger
620                 * @param smaller
621                 * @param values
622                 */
623        private void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
624            // we want to at keep the indices we select the same
625            int indices_to_draw = smaller.size();
626            ArrayList<Integer> indices = new ArrayList<Integer>();
627            Random rand = new Random();
628            while (indices_to_draw > 0) {
629               
630                int index = rand.nextInt(bigger.size()-1);
631               
632                if (!indices.contains(index)) {
633                    indices.add(index);
634                    indices_to_draw--;
635                }
636            }
637           
638            // now reduce our values to the indices we choose above for every attribute
639            for (int att=0; att < bigger.numAttributes()-1; att++) {
640               
641                // get double for the att
642                double[] vals = values.get(att);
643                double[] new_vals = new double[indices.size()];
644               
645                int i = 0;
646                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) {
647                    new_vals[i] = vals[it.next()];
648                    i++;
649                }
650               
651                values.set(att, new_vals);
652            }
653                }
654               
655               
656                /**
657                 * We run the kolmogorov-smirnov test on the data from our test an traindata
658                 * if the p value is above the cutoff we include it in the results
659                 * p value tends to be 0 when the distributions of the data are significantly different
660                 * but we want them to be the same
661                 *
662                 * @param cutoff
663                 * @return p-val
664                 */
665                public void kolmogorovSmirnovTest(double cutoff, MWBMatchingAlgorithm mwbm) {
666                        double p = 0;
667           
668                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
669                        for (int i=0; i < this.train.numAttributes(); i++) {
670                                for ( int j=0; j < this.test.numAttributes(); j++) {
671                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka
672                    // and the result of the mwbm computation
673                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
674                   
675                    // class attributes are not relevant
676                    if (this.test.classIndex() == j) {
677                        continue;
678                    }
679                    if (this.train.classIndex() == i) {
680                        continue;
681                    }
682                   
683                    // this may invoke exactP on small sample sizes which will not terminate in all cases
684                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j), false);
685
686                    // this uses approximateP everytime
687                                        p = t.approximateP(t.kolmogorovSmirnovStatistic(this.train_values.get(i), this.test_values.get(j)), this.train_values.get(i).length, this.test_values.get(j).length);
688                                        if (p > cutoff) {
689                        mwbm.setWeight(i, j, p);
690                                        }
691                                }
692                        }
693            }
694    }
695
696    /*
697     * Copyright (c) 2007, Massachusetts Institute of Technology
698     * Copyright (c) 2005-2006, Regents of the University of California
699     * All rights reserved.
700     *
701     * Redistribution and use in source and binary forms, with or without
702     * modification, are permitted provided that the following conditions
703     * are met:
704     *
705     * * Redistributions of source code must retain the above copyright
706     *   notice, this list of conditions and the following disclaimer.
707     *
708     * * Redistributions in binary form must reproduce the above copyright
709     *   notice, this list of conditions and the following disclaimer in
710     *   the documentation and/or other materials provided with the
711     *   distribution. 
712     *
713     * * Neither the name of the University of California, Berkeley nor
714     *   the names of its contributors may be used to endorse or promote
715     *   products derived from this software without specific prior
716     *   written permission.
717     *
718     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
719     * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
720     * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
721     * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
722     * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
723     * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
724     * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
725     * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
726     * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
727     * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
728     * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
729     * OF THE POSSIBILITY OF SUCH DAMAGE.
730     */
731
732
733
734    /**
735     * An engine for finding the maximum-weight matching in a complete
736     * bipartite graph.  Suppose we have two sets <i>S</i> and <i>T</i>,
737     * both of size <i>n</i>.  For each <i>i</i> in <i>S</i> and <i>j</i>
738     * in <i>T</i>, we have a weight <i>w<sub>ij</sub></i>.  A perfect
739     * matching <i>X</i> is a subset of <i>S</i> x <i>T</i> such that each
740     * <i>i</i> in <i>S</i> occurs in exactly one element of <i>X</i>, and
741     * each <i>j</i> in <i>T</i> occurs in exactly one element of
742     * <i>X</i>.  Thus, <i>X</i> can be thought of as a one-to-one
743     * function from <i>S</i> to <i>T</i>.  The weight of <i>X</i> is the
744     * sum, over (<i>i</i>, <i>j</i>) in <i>X</i>, of
745     * <i>w<sub>ij</sub></i>.  A BipartiteMatcher takes the number
746     * <i>n</i> and the weights <i>w<sub>ij</sub></i>, and finds a perfect
747     * matching of maximum weight.
748     *
749     * It uses the Hungarian algorithm of Kuhn (1955), as improved and
750     * presented by E. L. Lawler in his book <cite>Combinatorial
751     * Optimization: Networks and Matroids</cite> (Holt, Rinehart and
752     * Winston, 1976, p. 205-206).  The running time is
753     * O(<i>n</i><sup>3</sup>).  The weights can be any finite real
754     * numbers; Lawler's algorithm assumes positive weights, so if
755     * necessary we add a constant <i>c</i> to all the weights before
756     * running the algorithm.  This increases the weight of every perfect
757     * matching by <i>nc</i>, which doesn't change which perfect matchings
758     * have maximum weight.
759     *
760     * If a weight is set to Double.NEGATIVE_INFINITY, then the algorithm will
761     * behave as if that edge were not in the graph.  If all the edges incident on
762     * a given node have weight Double.NEGATIVE_INFINITY, then the final result
763     * will not be a perfect matching, and an exception will be thrown. 
764     */
765     class MWBMatchingAlgorithm {
766        /**
767         * Creates a BipartiteMatcher without specifying the graph size.  Calling
768         * any other method before calling reset will yield an
769         * IllegalStateException.
770         */
771       
772         /**
773         * Tolerance for comparisons to zero, to account for
774         * floating-point imprecision.  We consider a positive number to
775         * be essentially zero if it is strictly less than TOL.
776         */
777        private static final double TOL = 1e-10;
778        //Number of left side nodes
779        int n;
780
781        //Number of right side nodes
782        int m;
783
784        double[][] weights;
785        double minWeight;
786        double maxWeight;
787
788        // If (i, j) is in the mapping, then sMatches[i] = j and tMatches[j] = i. 
789        // If i is unmatched, then sMatches[i] = -1 (and likewise for tMatches).
790        int[] sMatches;
791        int[] tMatches;
792
793        static final int NO_LABEL = -1;
794        static final int EMPTY_LABEL = -2;
795
796        int[] sLabels;
797        int[] tLabels;
798
799        double[] u;
800        double[] v;
801       
802        double[] pi;
803
804        List<Integer> eligibleS = new ArrayList<Integer>();
805        List<Integer> eligibleT = new ArrayList<Integer>();
806       
807       
808        public MWBMatchingAlgorithm() {
809        n = -1;
810        m = -1;
811        }
812
813        /**
814         * Creates a BipartiteMatcher and prepares it to run on an n x m graph. 
815         * All the weights are initially set to 1. 
816         */
817        public MWBMatchingAlgorithm(int n, int m) {
818        reset(n, m);
819        }
820
821        /**
822         * Resets the BipartiteMatcher to run on an n x m graph.  The weights are
823         * all reset to 1.
824         */
825        private void reset(int n, int m) {
826            if (n < 0 || m < 0) {
827                throw new IllegalArgumentException("Negative num nodes: " + n + " or " + m);
828            }
829            this.n = n;
830            this.m = m;
831
832            weights = new double[n][m];
833            for (int i = 0; i < n; i++) {
834                for (int j = 0; j < m; j++) {
835                weights[i][j] = 1;
836                }
837            }
838            minWeight = 1;
839            maxWeight = Double.NEGATIVE_INFINITY;
840
841            sMatches = new int[n];
842            tMatches = new int[m];
843            sLabels = new int[n];
844            tLabels = new int[m];
845            u = new double[n];
846            v = new double[m];
847            pi = new double[m];
848           
849        }
850        /**
851         * Sets the weight w<sub>ij</sub> to the given value w.
852         *
853         * @throws IllegalArgumentException if i or j is outside the range [0, n).
854         */
855        public void setWeight(int i, int j, double w) {
856        if (n == -1 || m == -1) {
857            throw new IllegalStateException("Graph size not specified.");
858        }
859        if ((i < 0) || (i >= n)) {
860            throw new IllegalArgumentException("i-value out of range: " + i);
861        }
862        if ((j < 0) || (j >= m)) {
863            throw new IllegalArgumentException("j-value out of range: " + j);
864        }
865        if (Double.isNaN(w)) {
866            throw new IllegalArgumentException("Illegal weight: " + w);
867        }
868
869        weights[i][j] = w;
870        if ((w > Double.NEGATIVE_INFINITY) && (w < minWeight)) {
871            minWeight = w;
872        }
873        if (w > maxWeight) {
874            maxWeight = w;
875        }
876        }
877
878        /**
879         * Returns a maximum-weight perfect matching relative to the weights
880         * specified with setWeight.  The matching is represented as an array arr
881         * of length n, where arr[i] = j if (i,j) is in the matching.
882         */
883        public int[] getMatching() {
884        if (n == -1 || m == -1 ) {
885            throw new IllegalStateException("Graph size not specified.");
886        }
887        if (n == 0) {
888            return new int[0];
889        }
890        ensurePositiveWeights();
891
892        // Step 0: Initialization
893        eligibleS.clear();
894        eligibleT.clear();
895        for (Integer i = 0; i < n; i++) {
896            sMatches[i] = -1;
897
898            u[i] = maxWeight; // ambiguous on p. 205 of Lawler, but see p. 202
899
900            // this is really first run of Step 1.0
901            sLabels[i] = EMPTY_LABEL;
902            eligibleS.add(i);
903        }
904
905        for (int j = 0; j < m; j++) {
906            tMatches[j] = -1;
907
908            v[j] = 0;
909            pi[j] = Double.POSITIVE_INFINITY;
910
911            // this is really first run of Step 1.0
912            tLabels[j] = NO_LABEL;
913        }
914       
915        while (true) {
916            // Augment the matching until we can't augment any more given the
917            // current settings of the dual variables. 
918            while (true) {
919            // Steps 1.1-1.4: Find an augmenting path
920            int lastNode = findAugmentingPath();
921            if (lastNode == -1) {
922                break; // no augmenting path
923            }
924                   
925            // Step 2: Augmentation
926            flipPath(lastNode);
927            for (int i = 0; i < n; i++)
928                sLabels[i] = NO_LABEL;
929           
930            for (int j = 0; j < m; j++) {
931                pi[j] = Double.POSITIVE_INFINITY;
932                tLabels[j] = NO_LABEL;
933            }
934           
935           
936           
937            // This is Step 1.0
938            eligibleS.clear();
939            for (int i = 0; i < n; i++) {
940                if (sMatches[i] == -1) {
941                sLabels[i] = EMPTY_LABEL;
942                eligibleS.add(new Integer(i));
943                }
944            }
945
946           
947            eligibleT.clear();
948            }
949
950            // Step 3: Change the dual variables
951
952            // delta1 = min_i u[i]
953            double delta1 = Double.POSITIVE_INFINITY;
954            for (int i = 0; i < n; i++) {
955            if (u[i] < delta1) {
956                delta1 = u[i];
957            }
958            }
959
960            // delta2 = min_{j : pi[j] > 0} pi[j]
961            double delta2 = Double.POSITIVE_INFINITY;
962            for (int j = 0; j < m; j++) {
963            if ((pi[j] >= TOL) && (pi[j] < delta2)) {
964                delta2 = pi[j];
965            }
966            }
967
968            if (delta1 < delta2) {
969            // In order to make another pi[j] equal 0, we'd need to
970            // make some u[i] negative. 
971            break; // we have a maximum-weight matching
972            }
973               
974            changeDualVars(delta2);
975        }
976
977        int[] matching = new int[n];
978        for (int i = 0; i < n; i++) {
979            matching[i] = sMatches[i];
980        }
981        return matching;
982        }
983
984        /**
985         * Tries to find an augmenting path containing only edges (i,j) for which
986         * u[i] + v[j] = weights[i][j].  If it succeeds, returns the index of the
987         * last node in the path.  Otherwise, returns -1.  In any case, updates
988         * the labels and pi values.
989         */
990        int findAugmentingPath() {
991        while ((!eligibleS.isEmpty()) || (!eligibleT.isEmpty())) {
992            if (!eligibleS.isEmpty()) {
993            int i = ((Integer) eligibleS.get(eligibleS.size() - 1)).
994                intValue();
995            eligibleS.remove(eligibleS.size() - 1);
996            for (int j = 0; j < m; j++) {
997                // If pi[j] has already been decreased essentially
998                // to zero, then j is already labeled, and we
999                // can't decrease pi[j] any more.  Omitting the
1000                // pi[j] >= TOL check could lead us to relabel j
1001                // unnecessarily, since the diff we compute on the
1002                // next line may end up being less than pi[j] due
1003                // to floating point imprecision.
1004                if ((tMatches[j] != i) && (pi[j] >= TOL)) {
1005                double diff = u[i] + v[j] - weights[i][j];
1006                if (diff < pi[j]) {
1007                    tLabels[j] = i;
1008                    pi[j] = diff;
1009                    if (pi[j] < TOL) {
1010                    eligibleT.add(new Integer(j));
1011                    }
1012                }
1013                }
1014            }
1015            } else {
1016            int j = ((Integer) eligibleT.get(eligibleT.size() - 1)).
1017                intValue();
1018            eligibleT.remove(eligibleT.size() - 1);
1019            if (tMatches[j] == -1) {
1020                return j; // we've found an augmenting path
1021            }
1022
1023            int i = tMatches[j];
1024            sLabels[i] = j;
1025            eligibleS.add(new Integer(i)); // ok to add twice
1026            }
1027        }
1028
1029        return -1;
1030        }
1031
1032        /**
1033         * Given an augmenting path ending at lastNode, "flips" the path.  This
1034         * means that an edge on the path is in the matching after the flip if
1035         * and only if it was not in the matching before the flip.  An augmenting
1036         * path connects two unmatched nodes, so the result is still a matching.
1037         */
1038        void flipPath(int lastNode) {
1039            while (lastNode != EMPTY_LABEL) {
1040                int parent = tLabels[lastNode];
1041   
1042                // Add (parent, lastNode) to matching.  We don't need to
1043                // explicitly remove any edges from the matching because:
1044                //  * We know at this point that there is no i such that
1045                //    sMatches[i] = lastNode. 
1046                //  * Although there might be some j such that tMatches[j] =
1047                //    parent, that j must be sLabels[parent], and will change
1048                //    tMatches[j] in the next time through this loop. 
1049                sMatches[parent] = lastNode;
1050                tMatches[lastNode] = parent;
1051                           
1052                lastNode = sLabels[parent];
1053            }
1054        }
1055
1056        void changeDualVars(double delta) {
1057            for (int i = 0; i < n; i++) {
1058                if (sLabels[i] != NO_LABEL) {
1059                u[i] -= delta;
1060                }
1061            }
1062               
1063            for (int j = 0; j < m; j++) {
1064                if (pi[j] < TOL) {
1065                v[j] += delta;
1066                } else if (tLabels[j] != NO_LABEL) {
1067                pi[j] -= delta;
1068                if (pi[j] < TOL) {
1069                    eligibleT.add(new Integer(j));
1070                }
1071                }
1072            }
1073        }
1074
1075        /**
1076         * Ensures that all weights are either Double.NEGATIVE_INFINITY,
1077         * or strictly greater than zero.
1078         */
1079        private void ensurePositiveWeights() {
1080            // minWeight is the minimum non-infinite weight
1081            if (minWeight < TOL) {
1082                for (int i = 0; i < n; i++) {
1083                for (int j = 0; j < m; j++) {
1084                    weights[i][j] = weights[i][j] - minWeight + 1;
1085                }
1086                }
1087   
1088                maxWeight = maxWeight - minWeight + 1;
1089                minWeight = 1;
1090            }
1091        }
1092
1093        @SuppressWarnings("unused")
1094        private void printWeights() {
1095            for (int i = 0; i < n; i++) {
1096                for (int j = 0; j < m; j++) {
1097                System.out.print(weights[i][j] + " ");
1098                }
1099                System.out.println("");
1100            }
1101        }
1102    }
1103}
Note: See TracBrowser for help on using the repository browser.