source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java

Last change on this file was 141, checked in by sherbold, 8 years ago
  • code formatting
File size: 42.1 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.Comparator;
21import java.util.HashMap;
22import java.util.Iterator;
23import java.util.LinkedHashMap;
24import java.util.LinkedList;
25import java.util.List;
26import java.util.Map;
27import java.util.Map.Entry;
28import java.util.logging.Level;
29
30import java.util.Random;
31
32import org.apache.commons.collections4.list.SetUniqueList;
33import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
34import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
35
36import de.ugoe.cs.util.console.Console;
37import weka.attributeSelection.SignificanceAttributeEval;
38import weka.classifiers.AbstractClassifier;
39import weka.classifiers.Classifier;
40import weka.core.Attribute;
41import weka.core.DenseInstance;
42import weka.core.Instance;
43import weka.core.Instances;
44
45/**
46 * Implements Heterogeneous Defect Prediction after Nam et al. 2015.
47 *
48 * We extend WekaBaseTraining because we have to wrap the classifier to use MetricMatching. This
49 * also means we can use any Weka classifier, not just Logistic Regression.
50 *
51 * Config: <setwisetestdataawaretrainer name="MetricMatchingTraining" param=
52 * "Logistic weka.classifiers.functions.Logistic" threshold="0.05" method="spearman"/> Instead of
53 * the spearman method, "ks" and "percentile" are also accepted. Instead of Logistic, any other
54 * Weka classifier can be chosen.
55 *
56 * Future work: implement chisquare test in addition to significance for attribute selection
57 * http://commons.apache.org/proper/commons-math/apidocs/org/apache/commons/math3/stat/inference/
58 * ChiSquareTest.html use chiSquareTestDataSetsComparison
59 */
60public class MetricMatchingTraining extends WekaBaseTraining
61    implements ISetWiseTestdataAwareTrainingStrategy
62{
63
64    private MetricMatch mm = null;
65    private Classifier classifier = null;
66
67    private String method;
68    private float threshold;
69
70    /**
71     * We wrap the classifier here because of classifyInstance with our MetricMatchingClassfier
72     *
73     * @return
74     */
75    @Override
76    public Classifier getClassifier() {
77        return this.classifier;
78    }
79
80    /**
81     * Set similarity measure method.
82     */
83    @Override
84    public void setMethod(String method) {
85        this.method = method;
86    }
87
88    /**
89     * Set threshold for similarity measure.
90     */
91    @Override
92    public void setThreshold(String threshold) {
93        this.threshold = Float.parseFloat(threshold);
94    }
95
96    /**
97     * We need the test data instances to do a metric matching, so in this special case we get this
98     * data before evaluation.
99     */
100    @Override
101    public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
102        // reset these for each run
103        this.mm = null;
104        this.classifier = null;
105
106        double score = 0; // matching score to select the best matching training data from the set
107        int num = 0;
108        int biggest_num = 0;
109        MetricMatch tmp;
110        for (Instances traindata : traindataSet) {
111            num++;
112
113            tmp = new MetricMatch(traindata, testdata);
114
115            // metric selection may create error, continue to next training set
116            try {
117                tmp.attributeSelection();
118                tmp.matchAttributes(this.method, this.threshold);
119            }
120            catch (Exception e) {
121                e.printStackTrace();
122                throw new RuntimeException(e);
123            }
124
125            // we only select the training data from our set with the most matching attributes
126            if (tmp.getScore() > score && tmp.attributes.size() > 0) {
127                score = tmp.getScore();
128                this.mm = tmp;
129                biggest_num = num;
130            }
131        }
132
133        // if we have found a matching instance we use it, log information about the match for
134        // additional eval later
135        Instances ilist = null;
136        if (this.mm != null) {
137            ilist = this.mm.getMatchedTrain();
138            Console.traceln(Level.INFO, "[MATCH FOUND] match: [" + biggest_num + "], score: [" +
139                score + "], instances: [" + ilist.size() + "], attributes: [" +
140                this.mm.attributes.size() + "], ilist attrs: [" + ilist.numAttributes() + "]");
141            for (Map.Entry<Integer, Integer> attmatch : this.mm.attributes.entrySet()) {
142                Console.traceln(Level.INFO, "[MATCHED ATTRIBUTE] source attribute: [" +
143                    this.mm.train.attribute(attmatch.getKey()).name() + "], target attribute: [" +
144                    this.mm.test.attribute(attmatch.getValue()).name() + "]");
145            }
146        }
147        else {
148            Console.traceln(Level.INFO, "[NO MATCH FOUND]");
149        }
150
151        // if we have a match we build the MetricMatchingClassifier, if not we fall back to FixClass
152        // Classifier
153        try {
154            if (this.mm != null) {
155                this.classifier = new MetricMatchingClassifier();
156                this.classifier.buildClassifier(ilist);
157                ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
158            }
159            else {
160                this.classifier = new FixClass();
161                this.classifier.buildClassifier(ilist); // this is null, but the FixClass Classifier
162                                                        // does not use it anyway
163            }
164        }
165        catch (Exception e) {
166            e.printStackTrace();
167            throw new RuntimeException(e);
168        }
169    }
170
171    /**
172     * Encapsulates the classifier configured with WekaBase within but use metric matching. This
173     * allows us to use any Weka classifier with Heterogenous Defect Prediction.
174     */
175    public class MetricMatchingClassifier extends AbstractClassifier {
176
177        private static final long serialVersionUID = -1342172153473770935L;
178        private MetricMatch mm;
179        private Classifier classifier;
180
181        @Override
182        public void buildClassifier(Instances traindata) throws Exception {
183            this.classifier = setupClassifier();
184            this.classifier.buildClassifier(traindata);
185        }
186
187        /**
188         * Sets the MetricMatch instance so that we can use matched test data later.
189         *
190         * @param mm
191         */
192        public void setMetricMatching(MetricMatch mm) {
193            this.mm = mm;
194        }
195
196        /**
197         * Here we can not do the metric matching because we only get one instance. Therefore we
198         * need a MetricMatch instance beforehand to use here.
199         */
200        public double classifyInstance(Instance testdata) {
201            // get a copy of testdata Instance with only the matched attributes
202            Instance ntest = this.mm.getMatchedTestInstance(testdata);
203
204            double ret = 0.0;
205            try {
206                ret = this.classifier.classifyInstance(ntest);
207            }
208            catch (Exception e) {
209                e.printStackTrace();
210                throw new RuntimeException(e);
211            }
212
213            return ret;
214        }
215    }
216
217    /**
218     * Encapsulates one MetricMatching process. One source (train) matches against one target
219     * (test).
220     */
221    public class MetricMatch {
222        Instances train;
223        Instances test;
224
225        // used to sum up the matching values of all attributes
226        protected double p_sum = 0;
227
228        // attribute matching, train -> test
229        HashMap<Integer, Integer> attributes = new HashMap<Integer, Integer>();
230
231        // used for similarity tests
232        protected ArrayList<double[]> train_values;
233        protected ArrayList<double[]> test_values;
234
235        public MetricMatch(Instances train, Instances test) {
236            // this is expensive but we need to keep the original data intact
237            this.train = this.deepCopy(train);
238            this.test = test; // we do not need a copy here because we do not drop attributes before
239                              // the matching and after the matching we create a new Instances with
240                              // only the matched attributes
241
242            // convert metrics of testdata and traindata to later use in similarity tests
243            this.train_values = new ArrayList<double[]>();
244            for (int i = 0; i < this.train.numAttributes(); i++) {
245                if (this.train.classIndex() != i) {
246                    this.train_values.add(this.train.attributeToDoubleArray(i));
247                }
248            }
249
250            this.test_values = new ArrayList<double[]>();
251            for (int i = 0; i < this.test.numAttributes(); i++) {
252                if (this.test.classIndex() != i) {
253                    this.test_values.add(this.test.attributeToDoubleArray(i));
254                }
255            }
256        }
257
258        /**
259         * We have a lot of matching possibilities. Here we try to determine the best one.
260         *
261         * @return double matching score
262         */
263        public double getScore() {
264            int as = this.attributes.size(); // # of attributes that were matched
265
266            // we use thresholding ranking approach for numInstances to influence the matching score
267            int instances = this.train.numInstances();
268            int inst_rank = 0;
269            if (instances > 100) {
270                inst_rank = 1;
271            }
272            if (instances > 500) {
273                inst_rank = 2;
274            }
275
276            return this.p_sum + as + inst_rank;
277        }
278
279        public HashMap<Integer, Integer> getAttributes() {
280            return this.attributes;
281        }
282
283        public int getNumInstances() {
284            return this.train_values.get(0).length;
285        }
286
287        /**
288         * The test instance must be of the same dataset as the train data, otherwise WekaEvaluation
289         * will die. This means we have to force the dataset of this.train (after matching) and only
290         * set the values for the attributes we matched but with the index of the traindata
291         * attributes we matched.
292         *
293         * @param test
294         * @return
295         */
296        public Instance getMatchedTestInstance(Instance test) {
297            Instance ni = new DenseInstance(this.attributes.size() + 1);
298
299            Instances inst = this.getMatchedTrain();
300
301            ni.setDataset(inst);
302
303            // assign only the matched attributes to new indexes
304            double val;
305            int k = 0;
306            for (Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
307                // get value from matched attribute
308                val = test.value(attmatch.getValue());
309
310                // set it to new index, the order of the attributes is the same
311                ni.setValue(k, val);
312                k++;
313            }
314            ni.setClassValue(test.value(test.classAttribute()));
315
316            return ni;
317        }
318
319        /**
320         * returns a new instances array with the metric matched training data
321         *
322         * @return instances
323         */
324        public Instances getMatchedTrain() {
325            return this.getMatchedInstances("train", this.train);
326        }
327
328        /**
329         * returns a new instances array with the metric matched test data
330         *
331         * @return instances
332         */
333        public Instances getMatchedTest() {
334            return this.getMatchedInstances("test", this.test);
335        }
336
337        /**
338         * We could drop unmatched attributes from our instances datasets. Alas, that would not be
339         * nice for the following postprocessing jobs and would not work at all for evaluation. We
340         * keep this as a warning for future generations.
341         *
342         * @param name
343         * @param data
344         */
345        @SuppressWarnings("unused")
346        private void dropUnmatched(String name, Instances data) {
347            for (int i = 0; i < data.numAttributes(); i++) {
348                if (data.classIndex() == i) {
349                    continue;
350                }
351
352                if (name.equals("train") && !this.attributes.containsKey(i)) {
353                    data.deleteAttributeAt(i);
354                }
355
356                if (name.equals("test") && !this.attributes.containsValue(i)) {
357                    data.deleteAttributeAt(i);
358                }
359            }
360        }
361
362        /**
363         * Deep Copy (well, reasonably deep, not sure about header information of attributes) Weka
364         * Instances.
365         *
366         * @param data
367         *            Instances
368         * @return copy of Instances passed
369         */
370        private Instances deepCopy(Instances data) {
371            Instances newInst = new Instances(data);
372
373            newInst.clear();
374
375            for (int i = 0; i < data.size(); i++) {
376                Instance ni = new DenseInstance(data.numAttributes());
377                for (int j = 0; j < data.numAttributes(); j++) {
378                    ni.setValue(newInst.attribute(j), data.instance(i).value(data.attribute(j)));
379                }
380                newInst.add(ni);
381            }
382
383            return newInst;
384        }
385
386        /**
387         * Returns a deep copy of passed Instances data for Train or Test data. It only keeps
388         * attributes that have been matched.
389         *
390         * @param name
391         * @param data
392         * @return matched Instances
393         */
394        private Instances getMatchedInstances(String name, Instances data) {
395            ArrayList<Attribute> attrs = new ArrayList<Attribute>();
396
397            // bug attr is a string, really!
398            ArrayList<String> bug = new ArrayList<String>();
399            bug.add("0");
400            bug.add("1");
401
402            // add our matched attributes and last the bug
403            for (Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
404                attrs.add(new Attribute(String.valueOf(attmatch.getValue())));
405            }
406            attrs.add(new Attribute("bug", bug));
407
408            // create new instances object of the same size (at least for instances)
409            Instances newInst = new Instances(name, attrs, data.size());
410
411            // set last as class
412            newInst.setClassIndex(newInst.numAttributes() - 1);
413
414            // copy data for matched attributes, this depends if we return train or test data
415            for (int i = 0; i < data.size(); i++) {
416                Instance ni = new DenseInstance(this.attributes.size() + 1);
417
418                int j = 0; // new indices!
419                for (Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
420
421                    // test attribute match
422                    int value = attmatch.getValue();
423
424                    // train attribute match
425                    if (name.equals("train")) {
426                        value = attmatch.getKey();
427                    }
428
429                    ni.setValue(newInst.attribute(j), data.instance(i).value(value));
430                    j++;
431                }
432                ni.setValue(ni.numAttributes() - 1, data.instance(i).value(data.classAttribute()));
433                newInst.add(ni);
434            }
435
436            return newInst;
437        }
438
439        /**
440         * performs the attribute selection we perform attribute significance tests and drop
441         * attributes
442         *
443         * attribute selection is only performed on the source dataset we retain the top 15%
444         * attributes (if 15% is a float we just use the integer part)
445         */
446        public void attributeSelection() throws Exception {
447
448            // it is a wrapper, we may decide to implement ChiSquare or other means of selecting
449            // attributes
450            this.attributeSelectionBySignificance(this.train);
451        }
452
453        private void attributeSelectionBySignificance(Instances which) throws Exception {
454            // Uses:
455            // http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html
456            SignificanceAttributeEval et = new SignificanceAttributeEval();
457            et.buildEvaluator(which);
458
459            // evaluate all training attributes
460            HashMap<String, Double> saeval = new HashMap<String, Double>();
461            for (int i = 0; i < which.numAttributes(); i++) {
462                if (which.classIndex() != i) {
463                    saeval.put(which.attribute(i).name(), et.evaluateAttribute(i));
464                }
465            }
466
467            // sort by significance
468            HashMap<String, Double> sorted = (HashMap<String, Double>) sortByValues(saeval);
469
470            // Keep the best 15%
471            double last = ((double) saeval.size() / 100.0) * 15.0;
472            int drop_first = saeval.size() - (int) last;
473
474            // drop attributes above last
475            Iterator<Entry<String, Double>> it = sorted.entrySet().iterator();
476            while (drop_first > 0) {
477                Map.Entry<String, Double> pair = (Map.Entry<String, Double>) it.next();
478                if (which.attribute((String) pair.getKey()).index() != which.classIndex()) {
479                    which.deleteAttributeAt(which.attribute((String) pair.getKey()).index());
480                }
481                drop_first -= 1;
482            }
483        }
484
485        /**
486         * Helper method to sort a hashmap by its values.
487         *
488         * @param map
489         * @return sorted map
490         */
491        private HashMap<String, Double> sortByValues(HashMap<String, Double> map) {
492            List<Map.Entry<String, Double>> list =
493                new LinkedList<Map.Entry<String, Double>>(map.entrySet());
494
495            Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
496                public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
497                    return (o1.getValue()).compareTo(o2.getValue());
498                }
499            });
500
501            HashMap<String, Double> sortedHashMap = new LinkedHashMap<String, Double>();
502            for (Map.Entry<String, Double> item : list) {
503                sortedHashMap.put(item.getKey(), item.getValue());
504            }
505            return sortedHashMap;
506        }
507
508        /**
509         * Executes the similarity matching between train and test data.
510         *
511         * After this function is finished we have this.attributes with the correct matching between
512         * train and test data attributes.
513         *
514         * @param type
515         * @param cutoff
516         */
517        public void matchAttributes(String type, double cutoff) {
518
519            MWBMatchingAlgorithm mwbm =
520                new MWBMatchingAlgorithm(this.train.numAttributes(), this.test.numAttributes());
521
522            if (type.equals("spearman")) {
523                this.spearmansRankCorrelation(cutoff, mwbm);
524            }
525            else if (type.equals("ks")) {
526                this.kolmogorovSmirnovTest(cutoff, mwbm);
527            }
528            else if (type.equals("percentile")) {
529                this.percentiles(cutoff, mwbm);
530            }
531            else {
532                throw new RuntimeException("unknown matching method");
533            }
534
535            // resulting maximal match gets assigned to this.attributes
536            int[] result = mwbm.getMatching();
537            for (int i = 0; i < result.length; i++) {
538
539                // -1 means that it is not in the set of maximal matching
540                if (i != -1 && result[i] != -1) {
541                    this.p_sum += mwbm.weights[i][result[i]]; // we add the weight of the returned
542                                                              // matching for scoring the complete
543                                                              // match later
544                    this.attributes.put(i, result[i]);
545                }
546            }
547        }
548
549        /**
550         * Calculates the Percentiles of the source and target metrics.
551         *
552         * @param cutoff
553         */
554        public void percentiles(double cutoff, MWBMatchingAlgorithm mwbm) {
555            for (int i = 0; i < this.train.numAttributes(); i++) {
556                for (int j = 0; j < this.test.numAttributes(); j++) {
557                    // negative infinity counts as not present, we do this so we don't have to map
558                    // between attribute indexes in weka
559                    // and the result of the mwbm computation
560                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
561
562                    // class attributes are not relevant
563                    if (this.test.classIndex() == j) {
564                        continue;
565                    }
566                    if (this.train.classIndex() == i) {
567                        continue;
568                    }
569
570                    // get percentiles
571                    double train[] = this.train_values.get(i);
572                    double test[] = this.test_values.get(j);
573
574                    Arrays.sort(train);
575                    Arrays.sort(test);
576
577                    // percentiles
578                    double train_p;
579                    double test_p;
580                    double score = 0.0;
581                    for (int p = 1; p <= 9; p++) {
582                        train_p = train[(int) Math.ceil(train.length * (p / 100))];
583                        test_p = test[(int) Math.ceil(test.length * (p / 100))];
584
585                        if (train_p > test_p) {
586                            score += test_p / train_p;
587                        }
588                        else {
589                            score += train_p / test_p;
590                        }
591                    }
592
593                    if (score > cutoff) {
594                        mwbm.setWeight(i, j, score);
595                    }
596                }
597            }
598        }
599
600        /**
601         * Calculate Spearmans rank correlation coefficient as matching score. The number of
602         * instances for the source and target needs to be the same so we randomly sample from the
603         * bigger one.
604         *
605         * @param cutoff
606         * @param mwbmatching
607         */
608        public void spearmansRankCorrelation(double cutoff, MWBMatchingAlgorithm mwbm) {
609            double p = 0;
610
611            SpearmansCorrelation t = new SpearmansCorrelation();
612
613            // size has to be the same so we randomly sample the number of the smaller sample from
614            // the big sample
615            if (this.train.size() > this.test.size()) {
616                this.sample(this.train, this.test, this.train_values);
617            }
618            else if (this.test.size() > this.train.size()) {
619                this.sample(this.test, this.train, this.test_values);
620            }
621
622            // try out possible attribute combinations
623            for (int i = 0; i < this.train.numAttributes(); i++) {
624                for (int j = 0; j < this.test.numAttributes(); j++) {
625                    // negative infinity counts as not present, we do this so we don't have to map
626                    // between attribute indexs in weka
627                    // and the result of the mwbm computation
628                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
629
630                    // class attributes are not relevant
631                    if (this.test.classIndex() == j) {
632                        continue;
633                    }
634                    if (this.train.classIndex() == i) {
635                        continue;
636                    }
637
638                    p = t.correlation(this.train_values.get(i), this.test_values.get(j));
639                    if (p > cutoff) {
640                        mwbm.setWeight(i, j, p);
641                    }
642                }
643            }
644        }
645
646        /**
647         * Helper method to sample instances for the Spearman rank correlation coefficient method.
648         *
649         * @param bigger
650         * @param smaller
651         * @param values
652         */
653        private void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
654            // we want to at keep the indices we select the same
655            int indices_to_draw = smaller.size();
656            ArrayList<Integer> indices = new ArrayList<Integer>();
657            Random rand = new Random();
658            while (indices_to_draw > 0) {
659
660                int index = rand.nextInt(bigger.size() - 1);
661
662                if (!indices.contains(index)) {
663                    indices.add(index);
664                    indices_to_draw--;
665                }
666            }
667
668            // now reduce our values to the indices we choose above for every attribute
669            for (int att = 0; att < bigger.numAttributes() - 1; att++) {
670
671                // get double for the att
672                double[] vals = values.get(att);
673                double[] new_vals = new double[indices.size()];
674
675                int i = 0;
676                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) {
677                    new_vals[i] = vals[it.next()];
678                    i++;
679                }
680
681                values.set(att, new_vals);
682            }
683        }
684
685        /**
686         * We run the kolmogorov-smirnov test on the data from our test an traindata if the p value
687         * is above the cutoff we include it in the results p value tends to be 0 when the
688         * distributions of the data are significantly different but we want them to be the same
689         *
690         * @param cutoff
691         * @return p-val
692         */
693        public void kolmogorovSmirnovTest(double cutoff, MWBMatchingAlgorithm mwbm) {
694            double p = 0;
695
696            KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
697            for (int i = 0; i < this.train.numAttributes(); i++) {
698                for (int j = 0; j < this.test.numAttributes(); j++) {
699                    // negative infinity counts as not present, we do this so we don't have to map
700                    // between attribute indexs in weka
701                    // and the result of the mwbm computation
702                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
703
704                    // class attributes are not relevant
705                    if (this.test.classIndex() == j) {
706                        continue;
707                    }
708                    if (this.train.classIndex() == i) {
709                        continue;
710                    }
711
712                    // this may invoke exactP on small sample sizes which will not terminate in all
713                    // cases
714                    // p = t.kolmogorovSmirnovTest(this.train_values.get(i),
715                    // this.test_values.get(j), false);
716
717                    // this uses approximateP everytime
718                    p = t.approximateP(
719                                       t.kolmogorovSmirnovStatistic(this.train_values.get(i),
720                                                                    this.test_values.get(j)),
721                                       this.train_values.get(i).length,
722                                       this.test_values.get(j).length);
723                    if (p > cutoff) {
724                        mwbm.setWeight(i, j, p);
725                    }
726                }
727            }
728        }
729    }
730
731    /*
732     * Copyright (c) 2007, Massachusetts Institute of Technology Copyright (c) 2005-2006, Regents of
733     * the University of California All rights reserved.
734     *
735     * Redistribution and use in source and binary forms, with or without modification, are
736     * permitted provided that the following conditions are met:
737     *
738     * * Redistributions of source code must retain the above copyright notice, this list of
739     * conditions and the following disclaimer.
740     *
741     * * Redistributions in binary form must reproduce the above copyright notice, this list of
742     * conditions and the following disclaimer in the documentation and/or other materials provided
743     * with the distribution.
744     *
745     * * Neither the name of the University of California, Berkeley nor the names of its
746     * contributors may be used to endorse or promote products derived from this software without
747     * specific prior written permission.
748     *
749     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS
750     * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
751     * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
752     * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
753     * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
754     * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
755     * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
756     * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
757     * OF THE POSSIBILITY OF SUCH DAMAGE.
758     */
759
760    /**
761     * An engine for finding the maximum-weight matching in a complete bipartite graph. Suppose we
762     * have two sets <i>S</i> and <i>T</i>, both of size <i>n</i>. For each <i>i</i> in <i>S</i> and
763     * <i>j</i> in <i>T</i>, we have a weight <i>w<sub>ij</sub></i>. A perfect matching <i>X</i> is
764     * a subset of <i>S</i> x <i>T</i> such that each <i>i</i> in <i>S</i> occurs in exactly one
765     * element of <i>X</i>, and each <i>j</i> in <i>T</i> occurs in exactly one element of <i>X</i>.
766     * Thus, <i>X</i> can be thought of as a one-to-one function from <i>S</i> to <i>T</i>. The
767     * weight of <i>X</i> is the sum, over (<i>i</i>, <i>j</i>) in <i>X</i>, of <i>w
768     * <sub>ij</sub></i>. A BipartiteMatcher takes the number <i>n</i> and the weights <i>w
769     * <sub>ij</sub></i>, and finds a perfect matching of maximum weight.
770     *
771     * It uses the Hungarian algorithm of Kuhn (1955), as improved and presented by E. L. Lawler in
772     * his book <cite>Combinatorial Optimization: Networks and Matroids</cite> (Holt, Rinehart and
773     * Winston, 1976, p. 205-206). The running time is O(<i>n</i><sup>3</sup>). The weights can be
774     * any finite real numbers; Lawler's algorithm assumes positive weights, so if necessary we add
775     * a constant <i>c</i> to all the weights before running the algorithm. This increases the
776     * weight of every perfect matching by <i>nc</i>, which doesn't change which perfect matchings
777     * have maximum weight.
778     *
779     * If a weight is set to Double.NEGATIVE_INFINITY, then the algorithm will behave as if that
780     * edge were not in the graph. If all the edges incident on a given node have weight
781     * Double.NEGATIVE_INFINITY, then the final result will not be a perfect matching, and an
782     * exception will be thrown.
783     */
784    class MWBMatchingAlgorithm {
785        /**
786         * Creates a BipartiteMatcher without specifying the graph size. Calling any other method
787         * before calling reset will yield an IllegalStateException.
788         */
789
790        /**
791         * Tolerance for comparisons to zero, to account for floating-point imprecision. We consider
792         * a positive number to be essentially zero if it is strictly less than TOL.
793         */
794        private static final double TOL = 1e-10;
795        // Number of left side nodes
796        int n;
797
798        // Number of right side nodes
799        int m;
800
801        double[][] weights;
802        double minWeight;
803        double maxWeight;
804
805        // If (i, j) is in the mapping, then sMatches[i] = j and tMatches[j] = i.
806        // If i is unmatched, then sMatches[i] = -1 (and likewise for tMatches).
807        int[] sMatches;
808        int[] tMatches;
809
810        static final int NO_LABEL = -1;
811        static final int EMPTY_LABEL = -2;
812
813        int[] sLabels;
814        int[] tLabels;
815
816        double[] u;
817        double[] v;
818
819        double[] pi;
820
821        List<Integer> eligibleS = new ArrayList<Integer>();
822        List<Integer> eligibleT = new ArrayList<Integer>();
823
824        public MWBMatchingAlgorithm() {
825            n = -1;
826            m = -1;
827        }
828
829        /**
830         * Creates a BipartiteMatcher and prepares it to run on an n x m graph. All the weights are
831         * initially set to 1.
832         */
833        public MWBMatchingAlgorithm(int n, int m) {
834            reset(n, m);
835        }
836
837        /**
838         * Resets the BipartiteMatcher to run on an n x m graph. The weights are all reset to 1.
839         */
840        private void reset(int n, int m) {
841            if (n < 0 || m < 0) {
842                throw new IllegalArgumentException("Negative num nodes: " + n + " or " + m);
843            }
844            this.n = n;
845            this.m = m;
846
847            weights = new double[n][m];
848            for (int i = 0; i < n; i++) {
849                for (int j = 0; j < m; j++) {
850                    weights[i][j] = 1;
851                }
852            }
853            minWeight = 1;
854            maxWeight = Double.NEGATIVE_INFINITY;
855
856            sMatches = new int[n];
857            tMatches = new int[m];
858            sLabels = new int[n];
859            tLabels = new int[m];
860            u = new double[n];
861            v = new double[m];
862            pi = new double[m];
863
864        }
865
866        /**
867         * Sets the weight w<sub>ij</sub> to the given value w.
868         *
869         * @throws IllegalArgumentException
870         *             if i or j is outside the range [0, n).
871         */
872        public void setWeight(int i, int j, double w) {
873            if (n == -1 || m == -1) {
874                throw new IllegalStateException("Graph size not specified.");
875            }
876            if ((i < 0) || (i >= n)) {
877                throw new IllegalArgumentException("i-value out of range: " + i);
878            }
879            if ((j < 0) || (j >= m)) {
880                throw new IllegalArgumentException("j-value out of range: " + j);
881            }
882            if (Double.isNaN(w)) {
883                throw new IllegalArgumentException("Illegal weight: " + w);
884            }
885
886            weights[i][j] = w;
887            if ((w > Double.NEGATIVE_INFINITY) && (w < minWeight)) {
888                minWeight = w;
889            }
890            if (w > maxWeight) {
891                maxWeight = w;
892            }
893        }
894
895        /**
896         * Returns a maximum-weight perfect matching relative to the weights specified with
897         * setWeight. The matching is represented as an array arr of length n, where arr[i] = j if
898         * (i,j) is in the matching.
899         */
900        public int[] getMatching() {
901            if (n == -1 || m == -1) {
902                throw new IllegalStateException("Graph size not specified.");
903            }
904            if (n == 0) {
905                return new int[0];
906            }
907            ensurePositiveWeights();
908
909            // Step 0: Initialization
910            eligibleS.clear();
911            eligibleT.clear();
912            for (Integer i = 0; i < n; i++) {
913                sMatches[i] = -1;
914
915                u[i] = maxWeight; // ambiguous on p. 205 of Lawler, but see p. 202
916
917                // this is really first run of Step 1.0
918                sLabels[i] = EMPTY_LABEL;
919                eligibleS.add(i);
920            }
921
922            for (int j = 0; j < m; j++) {
923                tMatches[j] = -1;
924
925                v[j] = 0;
926                pi[j] = Double.POSITIVE_INFINITY;
927
928                // this is really first run of Step 1.0
929                tLabels[j] = NO_LABEL;
930            }
931
932            while (true) {
933                // Augment the matching until we can't augment any more given the
934                // current settings of the dual variables.
935                while (true) {
936                    // Steps 1.1-1.4: Find an augmenting path
937                    int lastNode = findAugmentingPath();
938                    if (lastNode == -1) {
939                        break; // no augmenting path
940                    }
941
942                    // Step 2: Augmentation
943                    flipPath(lastNode);
944                    for (int i = 0; i < n; i++)
945                        sLabels[i] = NO_LABEL;
946
947                    for (int j = 0; j < m; j++) {
948                        pi[j] = Double.POSITIVE_INFINITY;
949                        tLabels[j] = NO_LABEL;
950                    }
951
952                    // This is Step 1.0
953                    eligibleS.clear();
954                    for (int i = 0; i < n; i++) {
955                        if (sMatches[i] == -1) {
956                            sLabels[i] = EMPTY_LABEL;
957                            eligibleS.add(new Integer(i));
958                        }
959                    }
960
961                    eligibleT.clear();
962                }
963
964                // Step 3: Change the dual variables
965
966                // delta1 = min_i u[i]
967                double delta1 = Double.POSITIVE_INFINITY;
968                for (int i = 0; i < n; i++) {
969                    if (u[i] < delta1) {
970                        delta1 = u[i];
971                    }
972                }
973
974                // delta2 = min_{j : pi[j] > 0} pi[j]
975                double delta2 = Double.POSITIVE_INFINITY;
976                for (int j = 0; j < m; j++) {
977                    if ((pi[j] >= TOL) && (pi[j] < delta2)) {
978                        delta2 = pi[j];
979                    }
980                }
981
982                if (delta1 < delta2) {
983                    // In order to make another pi[j] equal 0, we'd need to
984                    // make some u[i] negative.
985                    break; // we have a maximum-weight matching
986                }
987
988                changeDualVars(delta2);
989            }
990
991            int[] matching = new int[n];
992            for (int i = 0; i < n; i++) {
993                matching[i] = sMatches[i];
994            }
995            return matching;
996        }
997
998        /**
999         * Tries to find an augmenting path containing only edges (i,j) for which u[i] + v[j] =
1000         * weights[i][j]. If it succeeds, returns the index of the last node in the path. Otherwise,
1001         * returns -1. In any case, updates the labels and pi values.
1002         */
1003        int findAugmentingPath() {
1004            while ((!eligibleS.isEmpty()) || (!eligibleT.isEmpty())) {
1005                if (!eligibleS.isEmpty()) {
1006                    int i = ((Integer) eligibleS.get(eligibleS.size() - 1)).intValue();
1007                    eligibleS.remove(eligibleS.size() - 1);
1008                    for (int j = 0; j < m; j++) {
1009                        // If pi[j] has already been decreased essentially
1010                        // to zero, then j is already labeled, and we
1011                        // can't decrease pi[j] any more. Omitting the
1012                        // pi[j] >= TOL check could lead us to relabel j
1013                        // unnecessarily, since the diff we compute on the
1014                        // next line may end up being less than pi[j] due
1015                        // to floating point imprecision.
1016                        if ((tMatches[j] != i) && (pi[j] >= TOL)) {
1017                            double diff = u[i] + v[j] - weights[i][j];
1018                            if (diff < pi[j]) {
1019                                tLabels[j] = i;
1020                                pi[j] = diff;
1021                                if (pi[j] < TOL) {
1022                                    eligibleT.add(new Integer(j));
1023                                }
1024                            }
1025                        }
1026                    }
1027                }
1028                else {
1029                    int j = ((Integer) eligibleT.get(eligibleT.size() - 1)).intValue();
1030                    eligibleT.remove(eligibleT.size() - 1);
1031                    if (tMatches[j] == -1) {
1032                        return j; // we've found an augmenting path
1033                    }
1034
1035                    int i = tMatches[j];
1036                    sLabels[i] = j;
1037                    eligibleS.add(new Integer(i)); // ok to add twice
1038                }
1039            }
1040
1041            return -1;
1042        }
1043
1044        /**
1045         * Given an augmenting path ending at lastNode, "flips" the path. This means that an edge on
1046         * the path is in the matching after the flip if and only if it was not in the matching
1047         * before the flip. An augmenting path connects two unmatched nodes, so the result is still
1048         * a matching.
1049         */
1050        void flipPath(int lastNode) {
1051            while (lastNode != EMPTY_LABEL) {
1052                int parent = tLabels[lastNode];
1053
1054                // Add (parent, lastNode) to matching. We don't need to
1055                // explicitly remove any edges from the matching because:
1056                // * We know at this point that there is no i such that
1057                // sMatches[i] = lastNode.
1058                // * Although there might be some j such that tMatches[j] =
1059                // parent, that j must be sLabels[parent], and will change
1060                // tMatches[j] in the next time through this loop.
1061                sMatches[parent] = lastNode;
1062                tMatches[lastNode] = parent;
1063
1064                lastNode = sLabels[parent];
1065            }
1066        }
1067
1068        void changeDualVars(double delta) {
1069            for (int i = 0; i < n; i++) {
1070                if (sLabels[i] != NO_LABEL) {
1071                    u[i] -= delta;
1072                }
1073            }
1074
1075            for (int j = 0; j < m; j++) {
1076                if (pi[j] < TOL) {
1077                    v[j] += delta;
1078                }
1079                else if (tLabels[j] != NO_LABEL) {
1080                    pi[j] -= delta;
1081                    if (pi[j] < TOL) {
1082                        eligibleT.add(new Integer(j));
1083                    }
1084                }
1085            }
1086        }
1087
1088        /**
1089         * Ensures that all weights are either Double.NEGATIVE_INFINITY, or strictly greater than
1090         * zero.
1091         */
1092        private void ensurePositiveWeights() {
1093            // minWeight is the minimum non-infinite weight
1094            if (minWeight < TOL) {
1095                for (int i = 0; i < n; i++) {
1096                    for (int j = 0; j < m; j++) {
1097                        weights[i][j] = weights[i][j] - minWeight + 1;
1098                    }
1099                }
1100
1101                maxWeight = maxWeight - minWeight + 1;
1102                minWeight = 1;
1103            }
1104        }
1105
1106        @SuppressWarnings("unused")
1107        private void printWeights() {
1108            for (int i = 0; i < n; i++) {
1109                for (int j = 0; j < m; j++) {
1110                    System.out.print(weights[i][j] + " ");
1111                }
1112                System.out.println("");
1113            }
1114        }
1115    }
1116}
Note: See TracBrowser for help on using the repository browser.