source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java @ 139

Last change on this file since 139 was 139, checked in by atrautsch, 8 years ago

Code Cleanup

File size: 38.2 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.Comparator;
21import java.util.HashMap;
22import java.util.Iterator;
23import java.util.LinkedHashMap;
24import java.util.LinkedList;
25import java.util.List;
26import java.util.Map;
27import java.util.Map.Entry;
28import java.util.logging.Level;
29
30import java.util.Random;
31
32import org.apache.commons.collections4.list.SetUniqueList;
33import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
34import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
35
36import de.ugoe.cs.util.console.Console;
37import weka.attributeSelection.SignificanceAttributeEval;
38import weka.classifiers.AbstractClassifier;
39import weka.classifiers.Classifier;
40import weka.core.Attribute;
41import weka.core.DenseInstance;
42import weka.core.Instance;
43import weka.core.Instances;
44
45/**
46 * Implements Heterogenous Defect Prediction after Nam et al. 2015.
47 *
48 * We extend WekaBaseTraining because we have to Wrap the Classifier to use MetricMatching.
49 * This also means we can use any Weka Classifier not just LogisticRegression.
50 *
51 */
52public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy {
53
54    private MetricMatch mm = null;
55    private Classifier classifier = new MetricMatchingClassifier();
56   
57    private String method;
58    private float threshold;
59   
60    /**
61     * We wrap the classifier here because of classifyInstance with our MetricMatchingClassfier
62     * @return
63     */
64    @Override
65    public Classifier getClassifier() {
66        return this.classifier;
67    }
68
69    /**
70     * Set similarity measure method.
71     */
72    @Override
73    public void setMethod(String method) {
74        this.method = method;
75    }
76
77    /**
78     * Set threshold for similarity measure.
79     */
80    @Override
81    public void setThreshold(String threshold) {
82        this.threshold = Float.parseFloat(threshold);
83    }
84
85        /**
86         * We need the test data instances to do a metric matching, so in this special case we get this data
87         * before evaluation.
88         */
89        @Override
90        public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
91
92                double score = 0; // matching score to select the best matching training data from the set
93                int num = 0;
94                int biggest_num = 0;
95                MetricMatch tmp;
96                for (Instances traindata : traindataSet) {
97                        num++;
98
99                        tmp = new MetricMatch(traindata, testdata);
100
101                        // metric selection may create error, continue to next training set
102                        try {
103                                tmp.attributeSelection();
104                                tmp.matchAttributes(this.method, this.threshold);
105                        }catch(Exception e) {
106                                e.printStackTrace();
107                                throw new RuntimeException(e);
108                        }
109                       
110                        // we only select the training data from our set with the most matching attributes
111                        if (tmp.getScore() > score && tmp.attributes.size() > 0) {
112                                score = tmp.getScore();
113                                this.mm = tmp;
114                                biggest_num = num;
115                        }
116                }
117               
118                // if we have found a matching instance we use it
119                Instances ilist = null;
120                if (this.mm != null) {
121                ilist = this.mm.getMatchedTrain();
122                Console.traceln(Level.INFO, "[MATCH FOUND] match: ["+biggest_num +"], score: [" + score + "], instances: [" + ilist.size() + "], attributes: [" + this.mm.attributes.size() + "], ilist attrs: ["+ilist.numAttributes()+"]");
123                for(Map.Entry<Integer, Integer> attmatch : this.mm.attributes.entrySet()) {
124                    Console.traceln(Level.INFO, "[MATCHED ATTRIBUTE] source attribute: [" + this.mm.train.attribute(attmatch.getKey()).name() + "], target attribute: [" + this.mm.test.attribute(attmatch.getValue()).name() + "]");
125                }
126            }else {
127                Console.traceln(Level.INFO, "[NO MATCH FOUND]");
128            }
129               
130                // if we have a match we build the MetricMatchingClassifier, if not we fall back to FixClass Classifier
131                try {
132                        if(this.mm != null) {
133                            this.classifier.buildClassifier(ilist);
134                            ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
135                        }else {
136                            this.classifier = new FixClass();
137                            this.classifier.buildClassifier(ilist);  // this is null, but the FixClass Classifier does not use it anyway
138                        }
139                }catch(Exception e) {
140                        e.printStackTrace();
141                        throw new RuntimeException(e);
142                }
143        }
144
145       
146        /**
147         * Encapsulates the classifier configured with WekaBase within but use metric matching.
148         * This allows us to use any Weka classifier with Heterogenous Defect Prediction.
149         */
150        public class MetricMatchingClassifier extends AbstractClassifier {
151
152                private static final long serialVersionUID = -1342172153473770935L;
153                private MetricMatch mm;
154                private Classifier classifier;
155               
156                @Override
157                public void buildClassifier(Instances traindata) throws Exception {
158                        this.classifier = setupClassifier();
159                        this.classifier.buildClassifier(traindata);
160                }
161               
162                /**
163                 * Sets the MetricMatch instance so that we can use matched test data later.
164                 * @param mm
165                 */
166                public void setMetricMatching(MetricMatch mm) {
167                        this.mm = mm;
168                }
169               
170                /**
171                 * Here we can not do the metric matching because we only get one instance.
172                 * Therefore we need a MetricMatch instance beforehand to use here.
173                 */
174                public double classifyInstance(Instance testdata) {
175                    // get a copy of testdata Instance with only the matched attributes
176                        Instance ntest = this.mm.getMatchedTestInstance(testdata);
177
178                        double ret = 0.0;
179                        try {
180                                ret = this.classifier.classifyInstance(ntest);
181                        }catch(Exception e) {
182                                e.printStackTrace();
183                                throw new RuntimeException(e);
184                        }
185                       
186                        return ret;
187                }
188        }
189       
190        /**
191         * Encapsulates one MetricMatching process.
192         * One source (train) matches against one target (test).
193         */
194    public class MetricMatch {
195            Instances train;
196                Instances test;
197               
198                // used to sum up the matching values of all attributes
199                protected double p_sum = 0;
200               
201                // attribute matching, train -> test
202                HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>();
203               
204                // used for similarity tests
205                protected ArrayList<double[]> train_values;
206                protected ArrayList<double[]> test_values;
207
208               
209                public MetricMatch(Instances train, Instances test) {
210                        // this is expensive but we need to keep the original data intact
211                    this.train = this.deepCopy(train);
212                        this.test = test; // we do not need a copy here because we do not drop attributes before the matching and after the matching we create a new Instances with only the matched attributes
213                       
214                        // convert metrics of testdata and traindata to later use in similarity tests
215                        this.train_values = new ArrayList<double[]>();
216                        for (int i = 0; i < this.train.numAttributes(); i++) {
217                            if(this.train.classIndex() != i) {
218                                this.train_values.add(this.train.attributeToDoubleArray(i));
219                            }
220                        }
221                       
222                        this.test_values = new ArrayList<double[]>();
223                        for (int i=0; i < this.test.numAttributes(); i++) {
224                            if(this.test.classIndex() != i) {
225                                this.test_values.add(this.test.attributeToDoubleArray(i));
226                            }
227                        }
228                }
229                 
230                /**
231                 * We have a lot of matching possibilities.
232                 * Here we try to determine the best one.
233                 *
234                 * @return double matching score
235                 */
236            public double getScore() {
237                int as = this.attributes.size();  // # of attributes that were matched
238               
239                // we use thresholding ranking approach for numInstances to influence the matching score
240                int instances = this.train.numInstances();
241                int inst_rank = 0;
242                if(instances > 100) {
243                    inst_rank = 1;
244                }
245                if(instances > 500) {
246                inst_rank = 2;
247            }
248           
249                return this.p_sum + as + inst_rank;
250            }
251                 
252                public HashMap<Integer, Integer> getAttributes() {
253                    return this.attributes;
254                }
255                 
256                public int getNumInstances() {
257                    return this.train_values.get(0).length;
258                }
259               
260                /**
261                 * This creates a new Instance out of the passed Instance and the previously matched attributes.
262                 * We do this because the evaluation phase requires an original Instance with every attribute.
263                 *
264                 * @param test instance
265                 * @return new instance
266                 */
267                public Instance getMatchedTestInstance(Instance test) {
268                    //create new instance with our matched number of attributes + 1 (the class attribute)
269                        Instances testdata = this.getMatchedTest();
270
271                        Instance ni = new DenseInstance(this.attributes.size()+1);
272                        ni.setDataset(testdata);
273                         
274                for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
275                    ni.setValue(testdata.attribute(attmatch.getKey()), test.value(attmatch.getValue()));
276                }
277               
278                        ni.setClassValue(test.value(test.classAttribute()));
279
280                        return ni;
281        }
282
283        /**
284         * returns a new instances array with the metric matched training data
285         *
286         * @return instances
287         */
288                public Instances getMatchedTrain() {
289                    return this.getMatchedInstances("train", this.train);
290                }
291                 
292        /**
293                 * returns a new instances array with the metric matched test data
294                 *
295                 * @return instances
296                 */
297                public Instances getMatchedTest() {
298                    return this.getMatchedInstances("test", this.test);
299                }
300               
301                /**
302                 * We could drop unmatched attributes from our instances datasets.
303                 * Alas, that would not be nice for the following postprocessing jobs and would not work at all for evaluation.
304                 * We keep this as a warning for future generations.
305                 *
306                 * @param name
307                 * @param data
308                 */
309                @SuppressWarnings("unused")
310        private void dropUnmatched(String name, Instances data) {
311                    for(int i = 0; i < data.numAttributes(); i++) {
312                        if(data.classIndex() == i) {
313                            continue;
314                        }
315                       
316                        if(name.equals("train") && !this.attributes.containsKey(i)) {
317                            data.deleteAttributeAt(i);
318                        }
319                       
320                        if(name.equals("test") && !this.attributes.containsValue(i)) {
321                            data.deleteAttributeAt(i);
322                        }
323                    }
324                }
325
326        /**
327         * Deep Copy (well, reasonably deep, not sure about header information of attributes) Weka Instances.
328         *
329         * @param data Instances
330         * @return copy of Instances passed
331         */
332                private Instances deepCopy(Instances data) {
333                    Instances newInst = new Instances(data);
334                   
335                    newInst.clear();
336                   
337            for (int i=0; i < data.size(); i++) {
338                Instance ni = new DenseInstance(data.numAttributes());
339                for(int j = 0; j < data.numAttributes(); j++) {
340                    ni.setValue(newInst.attribute(j), data.instance(i).value(data.attribute(j)));
341                }
342                newInst.add(ni);
343            }
344           
345            return newInst;
346                }
347
348        /**
349         * Returns a deep copy of passed Instances data for Train or Test data.
350         * It only keeps attributes that have been matched.
351         *
352         * @param name
353         * @param data
354         * @return matched Instances
355         */
356                private Instances getMatchedInstances(String name, Instances data) {
357                    ArrayList<Attribute> attrs = new ArrayList<Attribute>();
358           
359                    // bug attr is a string, really!
360                    ArrayList<String> bug = new ArrayList<String>();
361            bug.add("0");
362            bug.add("1");
363           
364            // add our matched attributes and last the bug
365                    for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
366                        attrs.add(new Attribute(String.valueOf(attmatch.getValue())));
367                    }
368                    attrs.add(new Attribute("bug", bug));
369                   
370                    // create new instances object of the same size (at least for instances)
371                    Instances newInst = new Instances(name, attrs, data.size());
372                   
373                    // set last as class
374                    newInst.setClassIndex(newInst.numAttributes()-1);
375                   
376                    // copy data for matched attributes, this depends if we return train or test data
377            for (int i=0; i < data.size(); i++) {
378                Instance ni = new DenseInstance(this.attributes.size()+1);
379               
380                int j = 0; // new indices!
381                for(Map.Entry<Integer, Integer> attmatch : this.attributes.entrySet()) {
382 
383                    // test attribute match
384                    int value = attmatch.getValue();
385                   
386                    // train attribute match
387                    if(name.equals("train")) {
388                        value = attmatch.getKey();
389                    }
390                   
391                    ni.setValue(newInst.attribute(j), data.instance(i).value(value));
392                    j++;
393                }
394                ni.setValue(ni.numAttributes()-1, data.instance(i).value(data.classAttribute()));
395                newInst.add(ni);
396            }
397
398            return newInst;
399                }
400                 
401                /**
402                 * performs the attribute selection
403                 * we perform attribute significance tests and drop attributes
404                 *
405                 * attribute selection is only performed on the source dataset
406                 * we retain the top 15% attributes (if 15% is a float we just use the integer part)
407                 */
408                public void attributeSelection() throws Exception {
409                   
410                    // it is a wrapper, we may decide to implement ChiSquare or other means of selecting attributes
411                        this.attributeSelectionBySignificance(this.train);
412                }
413               
414                private void attributeSelectionBySignificance(Instances which) throws Exception {
415                        // Uses: http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html
416                        SignificanceAttributeEval et = new SignificanceAttributeEval();
417                        et.buildEvaluator(which);
418                       
419                        // evaluate all training attributes
420                        HashMap<String,Double> saeval = new HashMap<String,Double>();
421                        for(int i=0; i < which.numAttributes(); i++) {
422                                if(which.classIndex() != i) {
423                                        saeval.put(which.attribute(i).name(), et.evaluateAttribute(i));
424                                }
425                        }
426                       
427                        // sort by significance
428                        HashMap<String, Double> sorted = (HashMap<String, Double>) sortByValues(saeval);
429                       
430                        // Keep the best 15%
431                        double last = ((double)saeval.size() / 100.0) * 15.0;
432                        int drop_first = saeval.size() - (int)last;
433                       
434                        // drop attributes above last
435                        Iterator<Entry<String, Double>> it = sorted.entrySet().iterator();
436                    while (drop_first > 0) {
437                        Map.Entry<String, Double> pair = (Map.Entry<String, Double>)it.next();
438                        if(which.attribute((String)pair.getKey()).index() != which.classIndex()) {
439                                which.deleteAttributeAt(which.attribute((String)pair.getKey()).index());
440                        }
441                        drop_first-=1;
442                    }
443                }
444               
445                /**
446                 * Helper method to sort a hashmap by its values.
447                 *
448                 * @param map
449                 * @return sorted map
450                 */
451                private HashMap<String, Double> sortByValues(HashMap<String, Double> map) {
452               List<Map.Entry<String, Double>> list = new LinkedList<Map.Entry<String, Double>>(map.entrySet());
453
454               Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
455                    public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
456                       return (o1.getValue()).compareTo( o2.getValue() );
457                    }
458               });
459
460               HashMap<String, Double> sortedHashMap = new LinkedHashMap<String, Double>();
461               for(Map.Entry<String, Double> item : list) {
462                   sortedHashMap.put(item.getKey(), item.getValue());
463               }
464               return sortedHashMap;
465                }
466               
467               
468        /**
469         * Executes the similarity matching between train and test data.
470         *
471         * After this function is finished we have this.attributes with the correct matching between train and test data attributes.
472         *
473         * @param type
474         * @param cutoff
475         */
476                public void matchAttributes(String type, double cutoff) {
477                   
478                    MWBMatchingAlgorithm mwbm = new MWBMatchingAlgorithm(this.train.numAttributes(), this.test.numAttributes());
479                   
480                    if (type.equals("spearman")) {
481                        this.spearmansRankCorrelation(cutoff, mwbm);
482                    }else if(type.equals("ks")) {
483                        this.kolmogorovSmirnovTest(cutoff, mwbm);
484                    }else if(type.equals("percentile")) {
485                        this.percentiles(cutoff, mwbm);
486                    }else {
487                        throw new RuntimeException("unknown matching method");
488                    }
489                   
490                    // resulting maximal match gets assigned to this.attributes
491            int[] result = mwbm.getMatching();
492            for( int i = 0; i < result.length; i++) {
493               
494                // -1 means that it is not in the set of maximal matching
495                if( i != -1 && result[i] != -1) {
496                    //Console.traceln(Level.INFO, "Found maximal bipartite match between: "+ i + " and " + result[i]);
497                    this.attributes.put(i, result[i]);
498                }
499            }
500        }
501       
502       
503                /**
504                 * Calculates the Percentiles of the source and target metrics.
505                 *
506                 * @param cutoff
507                 */
508                public void percentiles(double cutoff, MWBMatchingAlgorithm mwbm) {
509                    for( int i = 0; i < this.train.numAttributes(); i++ ) {
510                for( int j = 0; j < this.test.numAttributes(); j++ ) {
511                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexes in weka
512                    // and the result of the mwbm computation
513                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
514                   
515                    // class attributes are not relevant
516                    if (this.test.classIndex() == j) {
517                        continue;
518                    }
519                    if (this.train.classIndex() == i) {
520                        continue;
521                    }
522
523                    // get percentiles
524                    double train[] = this.train_values.get(i);
525                    double test[] = this.test_values.get(j);
526                   
527                    Arrays.sort(train);
528                    Arrays.sort(test);
529                   
530                    // percentiles
531                    double train_p;
532                    double test_p;
533                    double score = 0.0;
534                    for( int p=1; p <= 9; p++ ) {
535                        train_p = train[(int)Math.ceil(train.length * (p/100))];
536                        test_p = test[(int)Math.ceil(test.length * (p/100))];
537                   
538                        if( train_p > test_p ) {
539                            score += test_p / train_p;
540                        }else {
541                            score += train_p / test_p;
542                        }
543                    }
544                   
545                    if( score > cutoff ) {
546                        this.p_sum += score;
547                        mwbm.setWeight(i, j, score);
548                    }
549                }
550            }
551                }
552                 
553                /**
554                 * Calculate Spearmans rank correlation coefficient as matching score.
555                 * The number of instances for the source and target needs to be the same so we randomly sample from the bigger one.
556                 *
557                 * @param cutoff
558                 * @param mwbmatching
559                 */
560                public void spearmansRankCorrelation(double cutoff, MWBMatchingAlgorithm mwbm) {
561                    double p = 0;
562
563                        SpearmansCorrelation t = new SpearmansCorrelation();
564
565                        // size has to be the same so we randomly sample the number of the smaller sample from the big sample
566                        if (this.train.size() > this.test.size()) {
567                            this.sample(this.train, this.test, this.train_values);
568                        }else if (this.test.size() > this.train.size()) {
569                            this.sample(this.test, this.train, this.test_values);
570                        }
571                       
572            // try out possible attribute combinations
573            for (int i=0; i < this.train.numAttributes(); i++) {
574                for (int j=0; j < this.test.numAttributes(); j++) {
575                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka
576                    // and the result of the mwbm computation
577                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
578                   
579                    // class attributes are not relevant
580                    if (this.test.classIndex() == j) {
581                        continue;
582                    }
583                    if (this.train.classIndex() == i) {
584                        continue;
585                    }
586                   
587                                        p = t.correlation(this.train_values.get(i), this.test_values.get(j));
588                                        if (p > cutoff) {
589                                            this.p_sum += p;
590                        mwbm.setWeight(i, j, p);
591                                        }
592                                }
593                    }
594        }
595
596                /**
597                 * Helper method to sample instances for the Spearman rank correlation coefficient method.
598                 *
599                 * @param bigger
600                 * @param smaller
601                 * @param values
602                 */
603        private void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
604            // we want to at keep the indices we select the same
605            int indices_to_draw = smaller.size();
606            ArrayList<Integer> indices = new ArrayList<Integer>();
607            Random rand = new Random();
608            while (indices_to_draw > 0) {
609               
610                int index = rand.nextInt(bigger.size()-1);
611               
612                if (!indices.contains(index)) {
613                    indices.add(index);
614                    indices_to_draw--;
615                }
616            }
617           
618            // now reduce our values to the indices we choose above for every attribute
619            for (int att=0; att < bigger.numAttributes()-1; att++) {
620               
621                // get double for the att
622                double[] vals = values.get(att);
623                double[] new_vals = new double[indices.size()];
624               
625                int i = 0;
626                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) {
627                    new_vals[i] = vals[it.next()];
628                    i++;
629                }
630               
631                values.set(att, new_vals);
632            }
633                }
634               
635               
636                /**
637                 * We run the kolmogorov-smirnov test on the data from our test an traindata
638                 * if the p value is above the cutoff we include it in the results
639                 * p value tends to be 0 when the distributions of the data are significantly different
640                 * but we want them to be the same
641                 *
642                 * @param cutoff
643                 * @return p-val
644                 */
645                public void kolmogorovSmirnovTest(double cutoff, MWBMatchingAlgorithm mwbm) {
646                        double p = 0;
647           
648                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
649                        for (int i=0; i < this.train.numAttributes(); i++) {
650                                for ( int j=0; j < this.test.numAttributes(); j++) {
651                    // negative infinity counts as not present, we do this so we don't have to map between attribute indexs in weka
652                    // and the result of the mwbm computation
653                    mwbm.setWeight(i, j, Double.NEGATIVE_INFINITY);
654                   
655                    // class attributes are not relevant
656                    if (this.test.classIndex() == j) {
657                        continue;
658                    }
659                    if (this.train.classIndex() == i) {
660                        continue;
661                    }
662                   
663                    // this may invoke exactP on small sample sizes which will not terminate in all cases
664                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j), false);
665
666                    // this uses approximateP everytime
667                                        p = t.approximateP(t.kolmogorovSmirnovStatistic(this.train_values.get(i), this.test_values.get(j)), this.train_values.get(i).length, this.test_values.get(j).length);
668                                        if (p > cutoff) {
669                        this.p_sum += p;
670                        mwbm.setWeight(i, j, p);
671                                        }
672                                }
673                        }
674            }
675    }
676
677    /*
678     * Copyright (c) 2007, Massachusetts Institute of Technology
679     * Copyright (c) 2005-2006, Regents of the University of California
680     * All rights reserved.
681     *
682     * Redistribution and use in source and binary forms, with or without
683     * modification, are permitted provided that the following conditions
684     * are met:
685     *
686     * * Redistributions of source code must retain the above copyright
687     *   notice, this list of conditions and the following disclaimer.
688     *
689     * * Redistributions in binary form must reproduce the above copyright
690     *   notice, this list of conditions and the following disclaimer in
691     *   the documentation and/or other materials provided with the
692     *   distribution. 
693     *
694     * * Neither the name of the University of California, Berkeley nor
695     *   the names of its contributors may be used to endorse or promote
696     *   products derived from this software without specific prior
697     *   written permission.
698     *
699     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
700     * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
701     * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
702     * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
703     * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
704     * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
705     * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
706     * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
707     * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
708     * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
709     * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
710     * OF THE POSSIBILITY OF SUCH DAMAGE.
711     */
712
713
714
715    /**
716     * An engine for finding the maximum-weight matching in a complete
717     * bipartite graph.  Suppose we have two sets <i>S</i> and <i>T</i>,
718     * both of size <i>n</i>.  For each <i>i</i> in <i>S</i> and <i>j</i>
719     * in <i>T</i>, we have a weight <i>w<sub>ij</sub></i>.  A perfect
720     * matching <i>X</i> is a subset of <i>S</i> x <i>T</i> such that each
721     * <i>i</i> in <i>S</i> occurs in exactly one element of <i>X</i>, and
722     * each <i>j</i> in <i>T</i> occurs in exactly one element of
723     * <i>X</i>.  Thus, <i>X</i> can be thought of as a one-to-one
724     * function from <i>S</i> to <i>T</i>.  The weight of <i>X</i> is the
725     * sum, over (<i>i</i>, <i>j</i>) in <i>X</i>, of
726     * <i>w<sub>ij</sub></i>.  A BipartiteMatcher takes the number
727     * <i>n</i> and the weights <i>w<sub>ij</sub></i>, and finds a perfect
728     * matching of maximum weight.
729     *
730     * It uses the Hungarian algorithm of Kuhn (1955), as improved and
731     * presented by E. L. Lawler in his book <cite>Combinatorial
732     * Optimization: Networks and Matroids</cite> (Holt, Rinehart and
733     * Winston, 1976, p. 205-206).  The running time is
734     * O(<i>n</i><sup>3</sup>).  The weights can be any finite real
735     * numbers; Lawler's algorithm assumes positive weights, so if
736     * necessary we add a constant <i>c</i> to all the weights before
737     * running the algorithm.  This increases the weight of every perfect
738     * matching by <i>nc</i>, which doesn't change which perfect matchings
739     * have maximum weight.
740     *
741     * If a weight is set to Double.NEGATIVE_INFINITY, then the algorithm will
742     * behave as if that edge were not in the graph.  If all the edges incident on
743     * a given node have weight Double.NEGATIVE_INFINITY, then the final result
744     * will not be a perfect matching, and an exception will be thrown. 
745     */
746     class MWBMatchingAlgorithm {
747        /**
748         * Creates a BipartiteMatcher without specifying the graph size.  Calling
749         * any other method before calling reset will yield an
750         * IllegalStateException.
751         */
752       
753         /**
754         * Tolerance for comparisons to zero, to account for
755         * floating-point imprecision.  We consider a positive number to
756         * be essentially zero if it is strictly less than TOL.
757         */
758        private static final double TOL = 1e-10;
759        //Number of left side nodes
760        int n;
761
762        //Number of right side nodes
763        int m;
764
765        double[][] weights;
766        double minWeight;
767        double maxWeight;
768
769        // If (i, j) is in the mapping, then sMatches[i] = j and tMatches[j] = i. 
770        // If i is unmatched, then sMatches[i] = -1 (and likewise for tMatches).
771        int[] sMatches;
772        int[] tMatches;
773
774        static final int NO_LABEL = -1;
775        static final int EMPTY_LABEL = -2;
776
777        int[] sLabels;
778        int[] tLabels;
779
780        double[] u;
781        double[] v;
782       
783        double[] pi;
784
785        List<Integer> eligibleS = new ArrayList<Integer>();
786        List<Integer> eligibleT = new ArrayList<Integer>();
787       
788       
789        public MWBMatchingAlgorithm() {
790        n = -1;
791        m = -1;
792        }
793
794        /**
795         * Creates a BipartiteMatcher and prepares it to run on an n x m graph. 
796         * All the weights are initially set to 1. 
797         */
798        public MWBMatchingAlgorithm(int n, int m) {
799        reset(n, m);
800        }
801
802        /**
803         * Resets the BipartiteMatcher to run on an n x m graph.  The weights are
804         * all reset to 1.
805         */
806        private void reset(int n, int m) {
807            if (n < 0 || m < 0) {
808                throw new IllegalArgumentException("Negative num nodes: " + n + " or " + m);
809            }
810            this.n = n;
811            this.m = m;
812
813            weights = new double[n][m];
814            for (int i = 0; i < n; i++) {
815                for (int j = 0; j < m; j++) {
816                weights[i][j] = 1;
817                }
818            }
819            minWeight = 1;
820            maxWeight = Double.NEGATIVE_INFINITY;
821
822            sMatches = new int[n];
823            tMatches = new int[m];
824            sLabels = new int[n];
825            tLabels = new int[m];
826            u = new double[n];
827            v = new double[m];
828            pi = new double[m];
829           
830        }
831        /**
832         * Sets the weight w<sub>ij</sub> to the given value w.
833         *
834         * @throws IllegalArgumentException if i or j is outside the range [0, n).
835         */
836        public void setWeight(int i, int j, double w) {
837        if (n == -1 || m == -1) {
838            throw new IllegalStateException("Graph size not specified.");
839        }
840        if ((i < 0) || (i >= n)) {
841            throw new IllegalArgumentException("i-value out of range: " + i);
842        }
843        if ((j < 0) || (j >= m)) {
844            throw new IllegalArgumentException("j-value out of range: " + j);
845        }
846        if (Double.isNaN(w)) {
847            throw new IllegalArgumentException("Illegal weight: " + w);
848        }
849
850        weights[i][j] = w;
851        if ((w > Double.NEGATIVE_INFINITY) && (w < minWeight)) {
852            minWeight = w;
853        }
854        if (w > maxWeight) {
855            maxWeight = w;
856        }
857        }
858
859        /**
860         * Returns a maximum-weight perfect matching relative to the weights
861         * specified with setWeight.  The matching is represented as an array arr
862         * of length n, where arr[i] = j if (i,j) is in the matching.
863         */
864        public int[] getMatching() {
865        if (n == -1 || m == -1 ) {
866            throw new IllegalStateException("Graph size not specified.");
867        }
868        if (n == 0) {
869            return new int[0];
870        }
871        ensurePositiveWeights();
872
873        // Step 0: Initialization
874        eligibleS.clear();
875        eligibleT.clear();
876        for (Integer i = 0; i < n; i++) {
877            sMatches[i] = -1;
878
879            u[i] = maxWeight; // ambiguous on p. 205 of Lawler, but see p. 202
880
881            // this is really first run of Step 1.0
882            sLabels[i] = EMPTY_LABEL;
883            eligibleS.add(i);
884        }
885
886        for (int j = 0; j < m; j++) {
887            tMatches[j] = -1;
888
889            v[j] = 0;
890            pi[j] = Double.POSITIVE_INFINITY;
891
892            // this is really first run of Step 1.0
893            tLabels[j] = NO_LABEL;
894        }
895       
896        while (true) {
897            // Augment the matching until we can't augment any more given the
898            // current settings of the dual variables. 
899            while (true) {
900            // Steps 1.1-1.4: Find an augmenting path
901            int lastNode = findAugmentingPath();
902            if (lastNode == -1) {
903                break; // no augmenting path
904            }
905                   
906            // Step 2: Augmentation
907            flipPath(lastNode);
908            for (int i = 0; i < n; i++)
909                sLabels[i] = NO_LABEL;
910           
911            for (int j = 0; j < m; j++) {
912                pi[j] = Double.POSITIVE_INFINITY;
913                tLabels[j] = NO_LABEL;
914            }
915           
916           
917           
918            // This is Step 1.0
919            eligibleS.clear();
920            for (int i = 0; i < n; i++) {
921                if (sMatches[i] == -1) {
922                sLabels[i] = EMPTY_LABEL;
923                eligibleS.add(new Integer(i));
924                }
925            }
926
927           
928            eligibleT.clear();
929            }
930
931            // Step 3: Change the dual variables
932
933            // delta1 = min_i u[i]
934            double delta1 = Double.POSITIVE_INFINITY;
935            for (int i = 0; i < n; i++) {
936            if (u[i] < delta1) {
937                delta1 = u[i];
938            }
939            }
940
941            // delta2 = min_{j : pi[j] > 0} pi[j]
942            double delta2 = Double.POSITIVE_INFINITY;
943            for (int j = 0; j < m; j++) {
944            if ((pi[j] >= TOL) && (pi[j] < delta2)) {
945                delta2 = pi[j];
946            }
947            }
948
949            if (delta1 < delta2) {
950            // In order to make another pi[j] equal 0, we'd need to
951            // make some u[i] negative. 
952            break; // we have a maximum-weight matching
953            }
954               
955            changeDualVars(delta2);
956        }
957
958        int[] matching = new int[n];
959        for (int i = 0; i < n; i++) {
960            matching[i] = sMatches[i];
961        }
962        return matching;
963        }
964
965        /**
966         * Tries to find an augmenting path containing only edges (i,j) for which
967         * u[i] + v[j] = weights[i][j].  If it succeeds, returns the index of the
968         * last node in the path.  Otherwise, returns -1.  In any case, updates
969         * the labels and pi values.
970         */
971        int findAugmentingPath() {
972        while ((!eligibleS.isEmpty()) || (!eligibleT.isEmpty())) {
973            if (!eligibleS.isEmpty()) {
974            int i = ((Integer) eligibleS.get(eligibleS.size() - 1)).
975                intValue();
976            eligibleS.remove(eligibleS.size() - 1);
977            for (int j = 0; j < m; j++) {
978                // If pi[j] has already been decreased essentially
979                // to zero, then j is already labeled, and we
980                // can't decrease pi[j] any more.  Omitting the
981                // pi[j] >= TOL check could lead us to relabel j
982                // unnecessarily, since the diff we compute on the
983                // next line may end up being less than pi[j] due
984                // to floating point imprecision.
985                if ((tMatches[j] != i) && (pi[j] >= TOL)) {
986                double diff = u[i] + v[j] - weights[i][j];
987                if (diff < pi[j]) {
988                    tLabels[j] = i;
989                    pi[j] = diff;
990                    if (pi[j] < TOL) {
991                    eligibleT.add(new Integer(j));
992                    }
993                }
994                }
995            }
996            } else {
997            int j = ((Integer) eligibleT.get(eligibleT.size() - 1)).
998                intValue();
999            eligibleT.remove(eligibleT.size() - 1);
1000            if (tMatches[j] == -1) {
1001                return j; // we've found an augmenting path
1002            }
1003
1004            int i = tMatches[j];
1005            sLabels[i] = j;
1006            eligibleS.add(new Integer(i)); // ok to add twice
1007            }
1008        }
1009
1010        return -1;
1011        }
1012
1013        /**
1014         * Given an augmenting path ending at lastNode, "flips" the path.  This
1015         * means that an edge on the path is in the matching after the flip if
1016         * and only if it was not in the matching before the flip.  An augmenting
1017         * path connects two unmatched nodes, so the result is still a matching.
1018         */
1019        void flipPath(int lastNode) {
1020            while (lastNode != EMPTY_LABEL) {
1021                int parent = tLabels[lastNode];
1022   
1023                // Add (parent, lastNode) to matching.  We don't need to
1024                // explicitly remove any edges from the matching because:
1025                //  * We know at this point that there is no i such that
1026                //    sMatches[i] = lastNode. 
1027                //  * Although there might be some j such that tMatches[j] =
1028                //    parent, that j must be sLabels[parent], and will change
1029                //    tMatches[j] in the next time through this loop. 
1030                sMatches[parent] = lastNode;
1031                tMatches[lastNode] = parent;
1032                           
1033                lastNode = sLabels[parent];
1034            }
1035        }
1036
1037        void changeDualVars(double delta) {
1038            for (int i = 0; i < n; i++) {
1039                if (sLabels[i] != NO_LABEL) {
1040                u[i] -= delta;
1041                }
1042            }
1043               
1044            for (int j = 0; j < m; j++) {
1045                if (pi[j] < TOL) {
1046                v[j] += delta;
1047                } else if (tLabels[j] != NO_LABEL) {
1048                pi[j] -= delta;
1049                if (pi[j] < TOL) {
1050                    eligibleT.add(new Integer(j));
1051                }
1052                }
1053            }
1054        }
1055
1056        /**
1057         * Ensures that all weights are either Double.NEGATIVE_INFINITY,
1058         * or strictly greater than zero.
1059         */
1060        private void ensurePositiveWeights() {
1061            // minWeight is the minimum non-infinite weight
1062            if (minWeight < TOL) {
1063                for (int i = 0; i < n; i++) {
1064                for (int j = 0; j < m; j++) {
1065                    weights[i][j] = weights[i][j] - minWeight + 1;
1066                }
1067                }
1068   
1069                maxWeight = maxWeight - minWeight + 1;
1070                minWeight = 1;
1071            }
1072        }
1073
1074        @SuppressWarnings("unused")
1075        private void printWeights() {
1076            for (int i = 0; i < n; i++) {
1077                for (int j = 0; j < m; j++) {
1078                System.out.print(weights[i][j] + " ");
1079                }
1080                System.out.println("");
1081            }
1082        }
1083    }
1084}
Note: See TracBrowser for help on using the repository browser.