source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java @ 74

Last change on this file since 74 was 49, checked in by atrautsch, 9 years ago

PAnalyzer implemented

File size: 19.0 KB
Line 
1// Copyright 2015 Georg-August-Universit�t G�ttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.Comparator;
21import java.util.HashMap;
22import java.util.Iterator;
23import java.util.LinkedHashMap;
24import java.util.LinkedList;
25import java.util.List;
26import java.util.Map;
27import java.util.logging.Level;
28
29import javax.management.RuntimeErrorException;
30
31import java.util.Random;
32
33import org.apache.commons.collections4.list.SetUniqueList;
34import org.apache.commons.math3.stat.inference.ChiSquareTest;
35import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
36import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
37
38import de.ugoe.cs.util.console.Console;
39import weka.attributeSelection.SignificanceAttributeEval;
40import weka.classifiers.AbstractClassifier;
41import weka.classifiers.Classifier;
42import weka.core.Attribute;
43import weka.core.DenseInstance;
44import weka.core.FastVector;
45import weka.core.Instance;
46import weka.core.Instances;
47
48
49public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy {
50
51    private SetUniqueList<Instances> traindataSet;
52    private MetricMatch mm;
53    private final Classifier classifier = new MetricMatchingClassifier();
54   
55    private String method;
56    private float threshold;
57   
58    /**
59     * We wrap the classifier here because of classifyInstance
60     * @return
61     */
62    @Override
63    public Classifier getClassifier() {
64        return this.classifier;
65    }
66
67
68    @Override
69    public void setMethod(String method) {
70        this.method = method;
71    }
72
73
74    @Override
75    public void setThreshold(String threshold) {
76        this.threshold = Float.parseFloat(threshold);
77    }
78
79        /**
80         * We need the testdata instances to do a metric matching, so in this special case we get this data
81         * before evaluation
82         */
83        @Override
84        public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
85                this.traindataSet = traindataSet;
86
87                int rank = 0; // we want at least 5 matching attributes
88                int num = 0;
89                int biggest_num = 0;
90                MetricMatch tmp;
91                MetricMatch biggest = null;
92                for (Instances traindata : this.traindataSet) {
93                        num++;
94                        tmp = new MetricMatch(traindata, testdata);
95
96                        // metric selection may create error, continue to next training set
97                        try {
98                                tmp.attributeSelection();
99                        }catch(Exception e) {
100                                e.printStackTrace();
101                                throw new RuntimeException(e);
102                        }
103                       
104                        if (this.method.equals("spearman")) {
105                            tmp.spearmansRankCorrelation(this.threshold);
106                        }
107                        else if (this.method.equals("kolmogorov")) {
108                            tmp.kolmogorovSmirnovTest(this.threshold);
109                        }
110                        else if( this.method.equals("percentile") ) {
111                            tmp.percentiles(this.threshold);
112                        }
113                        else {
114                            throw new RuntimeException("unknown method");
115                        }
116
117                        // we only select the training data from our set with the most matching attributes
118                        if (tmp.getRank() > rank) {
119                                rank = tmp.getRank();
120                                biggest = tmp;
121                                biggest_num = num;
122                        }
123                }
124               
125                if (biggest == null) {
126                    throw new RuntimeException("not enough matching attributes found");
127                }
128
129                // we use the best match
130               
131                this.mm = biggest;
132                Instances ilist = this.mm.getMatchedTrain();
133                Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributes, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets");
134               
135                // replace traindataSEt
136                //traindataSet = new SetUniqueList<Instances>();
137                traindataSet.clear();
138                traindataSet.add(ilist);
139               
140                // we have to build the classifier here:
141                try {
142                   
143                        //
144                    if (this.classifier == null) {
145                        Console.traceln(Level.SEVERE, "Classifier is null");
146                    }
147                        //Console.traceln(Level.INFO, "Building classifier with the matched training data with " + ilist.size() + " instances and "+ ilist.numAttributes() + " attributes");
148                        this.classifier.buildClassifier(ilist);
149                        ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
150                }catch(Exception e) {
151                        e.printStackTrace();
152                        throw new RuntimeException(e);
153                }
154        }
155
156       
157        /**
158         * encapsulates the classifier configured with WekaBase
159         */
160        public class MetricMatchingClassifier extends AbstractClassifier {
161
162                private static final long serialVersionUID = -1342172153473770935L;
163                private MetricMatch mm;
164                private Classifier classifier;
165               
166                @Override
167                public void buildClassifier(Instances traindata) throws Exception {
168                        this.classifier = setupClassifier();  // parent method from WekaBase
169                        this.classifier.buildClassifier(traindata);
170                }
171
172                public void setMetricMatching(MetricMatch mm) {
173                        this.mm = mm;
174                }
175               
176                /**
177                 * Here we can not do the metric matching because we only get one instance
178                 */
179                public double classifyInstance(Instance testdata) {
180                        // todo: maybe we can pull the instance out of our matched testdata?
181                        Instance ntest = this.mm.getMatchedTestInstance(testdata);
182
183                        double ret = 0.0;
184                        try {
185                                ret = this.classifier.classifyInstance(ntest);
186                        }catch(Exception e) {
187                                e.printStackTrace();
188                                throw new RuntimeException(e);
189                        }
190                       
191                        return ret;
192                }
193        }
194       
195        /**
196         * Encapsulates MetricMatching on Instances Arrays
197         */
198    public class MetricMatch {
199                 Instances train;
200                 Instances test;
201                 
202                 HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>();
203                 
204                 ArrayList<double[]> train_values;
205                 ArrayList<double[]> test_values;
206                 
207                 // todo: this constructor does not work
208                 public MetricMatch() {
209                 }
210                 
211                 public MetricMatch(Instances train, Instances test) {
212                         this.train = new Instances(train);  // expensive! but we are dropping the attributes
213                         this.test = new Instances(test);
214                         
215                         // 1. convert metrics of testdata and traindata to later use in test
216                         this.train_values = new ArrayList<double[]>();
217                         for (int i = 0; i < this.train.numAttributes()-1; i++) {
218                                this.train_values.add(train.attributeToDoubleArray(i));
219                         }
220                       
221                         this.test_values = new ArrayList<double[]>();
222                         for (int i=0; i < this.test.numAttributes()-1; i++) {
223                                this.test_values.add(this.test.attributeToDoubleArray(i));
224                         }
225                 }
226                 
227                 /**
228                  * returns the number of matched attributes
229                  * as a way of scoring traindata sets individually
230                  *
231                  * @return
232                  */
233                 public int getRank() {
234                         return this.attributes.size();
235                 }
236                 
237                 public int getNumInstances() {
238                     return this.train_values.get(0).length;
239                 }
240                 
241                 public Instance getMatchedTestInstance(Instance test) {
242                         // create new instance with our matched number of attributes + 1 (the class attribute)
243                         //Console.traceln(Level.INFO, "getting matched instance");
244                         Instances testdata = this.getMatchedTest();
245                         
246                         //Instance ni = new DenseInstance(this.attmatch.size()+1);
247                         Instance ni = new DenseInstance(this.attributes.size()+1);
248                         ni.setDataset(testdata);
249                         
250                         //Console.traceln(Level.INFO, "Attributes to match: " + this.attmatch.size() + "");
251                         
252                         Iterator it = this.attributes.entrySet().iterator();
253                         int j = 0;
254                         while (it.hasNext()) {
255                                 Map.Entry values = (Map.Entry)it.next();
256                                 ni.setValue(testdata.attribute(j), test.value((int)values.getValue()));
257                                 j++;
258                                 
259                         }
260                         
261                         ni.setClassValue(test.value(test.classAttribute()));
262                         
263                         //System.out.println(ni);
264                         return ni;
265                 }
266
267         /**
268          * returns a new instances array with the metric matched training data
269          *
270          * @return instances
271          */
272                 public Instances getMatchedTrain() {
273                         return this.getMatchedInstances("train", this.train);
274                 }
275                 
276                 /**
277                  * returns a new instances array with the metric matched test data
278                  *
279                  * @return instances
280                  */
281                 public Instances getMatchedTest() {
282                         return this.getMatchedInstances("test", this.test);
283                 }
284                 
285                 // https://weka.wikispaces.com/Programmatic+Use
286                 private Instances getMatchedInstances(String name, Instances data) {
287                         // construct our new attributes
288                         Attribute[] attrs = new Attribute[this.attributes.size()+1];
289                         FastVector fwTrain = new FastVector(this.attributes.size());
290                         for (int i=0; i < this.attributes.size(); i++) {
291                                 attrs[i] = new Attribute(String.valueOf(i));
292                                 fwTrain.addElement(attrs[i]);
293                         }
294                         // add our ClassAttribute (which is not numeric!)
295                         ArrayList<String> acl= new ArrayList<String>();
296                         acl.add("0");
297                         acl.add("1");
298                         
299                         fwTrain.addElement(new Attribute("bug", acl));
300                         Instances newTrain = new Instances(name, fwTrain, data.size());
301                         newTrain.setClassIndex(newTrain.numAttributes()-1);
302                         
303                         for (int i=0; i < data.size(); i++) {
304                                 Instance ni = new DenseInstance(this.attributes.size()+1);
305                               
306                                 Iterator it = this.attributes.entrySet().iterator();
307                                 int j = 0;
308                                 while (it.hasNext()) {
309                                         Map.Entry values = (Map.Entry)it.next();
310                                         int value = (int)values.getValue();
311                                         
312                                         // key ist traindata
313                                         if (name.equals("train")) {
314                                                 value = (int)values.getKey();
315                                         }
316                                         ni.setValue(newTrain.attribute(j), data.instance(i).value(value));
317                                         j++;
318                                 }
319                                 ni.setValue(ni.numAttributes()-1, data.instance(i).value(data.classAttribute()));
320                                 
321                                 newTrain.add(ni);
322                         }
323                         
324                    return newTrain;
325        }
326                 
327                 
328                /**
329                 * performs the attribute selection
330                 * we perform attribute significance tests and drop attributes
331                 */
332                public void attributeSelection() throws Exception {
333                        //Console.traceln(Level.INFO, "Attribute Selection on Training Attributes ("+this.train.numAttributes()+")");
334                        this.attributeSelection(this.train);
335                        //Console.traceln(Level.INFO, "-----");
336                        //Console.traceln(Level.INFO, "Attribute Selection on Test Attributes ("+this.test.numAttributes()+")");
337                        this.attributeSelection(this.test);
338                        //Console.traceln(Level.INFO, "-----");
339                }
340               
341                private void attributeSelection(Instances which) throws Exception {
342                        // 1. step we have to categorize the attributes
343                        //http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html
344                       
345                        SignificanceAttributeEval et = new SignificanceAttributeEval();
346                        et.buildEvaluator(which);
347                        //double tmp[] = new double[this.train.numAttributes()];
348                        HashMap<String,Double> saeval = new HashMap<String,Double>();
349                        // evaluate all training attributes
350                        // select top 15% of metrics
351                        for(int i=0; i < which.numAttributes(); i++) {
352                                if(which.classIndex() != i) {
353                                        saeval.put(which.attribute(i).name(), et.evaluateAttribute(i));
354                                }
355                                //Console.traceln(Level.SEVERE, "Significance Attribute Eval: " + tmp);
356                        }
357                       
358                        HashMap<String, Double> sorted = sortByValues(saeval);
359                       
360                        // die letzen 15% wollen wir haben
361                        float last = ((float)saeval.size() / 100) * 15;
362                        int drop_first = saeval.size() - (int)last;
363                       
364                        //Console.traceln(Level.INFO, "Dropping "+ drop_first + " of " + sorted.size() + " attributes (we only keep the best 15% "+last+")");
365                        /*
366                        Iterator it = sorted.entrySet().iterator();
367                    while (it.hasNext()) {
368                        Map.Entry pair = (Map.Entry)it.next();
369                        Console.traceln(Level.INFO, "key: "+(int)pair.getKey()+", value: " + (double)pair.getValue() + "");
370                    }*/
371                       
372                        // drop attributes above last
373                        Iterator it = sorted.entrySet().iterator();
374                    while (drop_first > 0) {
375                        Map.Entry pair = (Map.Entry)it.next();
376                        if(which.attribute((String)pair.getKey()).index() != which.classIndex()) {
377                               
378                                which.deleteAttributeAt(which.attribute((String)pair.getKey()).index());
379                                //Console.traceln(Level.INFO, "dropping attribute: "+ (String)pair.getKey());
380                        }
381                        drop_first-=1;
382                 
383                   
384                    }
385//                  //Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)");
386                }
387               
388               
389                private HashMap sortByValues(HashMap map) {
390               List list = new LinkedList(map.entrySet());
391
392               Collections.sort(list, new Comparator() {
393                    public int compare(Object o1, Object o2) {
394                       return ((Comparable) ((Map.Entry) (o1)).getValue())
395                          .compareTo(((Map.Entry) (o2)).getValue());
396                    }
397               });
398
399
400               HashMap sortedHashMap = new LinkedHashMap();
401               for (Iterator it = list.iterator(); it.hasNext();) {
402                      Map.Entry entry = (Map.Entry) it.next();
403                      sortedHashMap.put(entry.getKey(), entry.getValue());
404               }
405               return sortedHashMap;
406                }
407                 
408                /**
409                 * Calculates the Percentiles of the source and target metrics.
410                 *
411                 * @param cutoff
412                 */
413                public void percentiles(double cutoff) {
414                    for( int i = 0; i < this.train.numAttributes()-1; i++ ) {
415                for( int j = 0; j < this.test.numAttributes()-1; j++ ) {
416                    // class attributes are not relevant
417                    if( this.train.classIndex() == i ) {
418                        continue;
419                    }
420                    if( this.test.classIndex() == j ) {
421                        continue;
422                    }
423                   
424                   
425                    if( !this.attributes.containsKey(i) ) {
426                        // get percentiles
427                        double train[] = this.train_values.get(i);
428                        double test[] = this.test_values.get(j);
429                       
430                        Arrays.sort(train);
431                        Arrays.sort(test);
432                       
433                        // percentiles
434                        double train_p;
435                        double test_p;
436                        double score = 0.0;
437                        for( double p=0.1; p < 1; p+=0.1 ) {
438                            train_p = train[(int)Math.ceil(train.length * p)];
439                            test_p = test[(int)Math.ceil(test.length * p)];
440                       
441                            if( train_p > test_p ) {
442                                score += test_p / train_p;
443                            }else {
444                                score += train_p / test_p;
445                            }
446                        }
447                       
448                        if( score > cutoff ) {
449                            this.attributes.put(i, j);
450                        }
451                    }
452                }
453            }
454                }
455                 
456                 /**
457                  * calculate Spearmans rank correlation coefficient as matching score
458                  *
459                  * @param cutoff
460                  */
461                 public void spearmansRankCorrelation(double cutoff) {
462                         double p = 0;                   
463                         SpearmansCorrelation t = new SpearmansCorrelation();
464
465                         // size has to be the same so we randomly sample the number of the smaller sample from the big sample
466                         if (this.train.size() > this.test.size()) {
467                             this.sample(this.train, this.test, this.train_values);
468                         }else if (this.test.size() > this.train.size()) {
469                             this.sample(this.test, this.train, this.test_values);
470                         }
471                         
472                         // try out possible attribute combinations
473            for (int i=0; i < this.train.numAttributes()-1; i++) {
474                for (int j=0; j < this.test.numAttributes()-1; j++) {
475                    // class attributes are not relevant
476                    if (this.train.classIndex() == i) {
477                        continue;
478                    }
479                    if (this.test.classIndex() == j) {
480                        continue;
481                    }
482                   
483                   
484                                        if (!this.attributes.containsKey(i)) {
485                                                p = t.correlation(this.train_values.get(i), this.test_values.get(j));
486                                                if (p > cutoff) {
487                                                        this.attributes.put(i, j);
488                                                }
489                                        }
490                                }
491                    }
492        }
493
494               
495        public void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
496            // we want to at keep the indices we select the same
497            int indices_to_draw = smaller.size();
498            ArrayList<Integer> indices = new ArrayList<Integer>();
499            Random rand = new Random();
500            while (indices_to_draw > 0) {
501               
502                int index = rand.nextInt(bigger.size()-1);
503               
504                if (!indices.contains(index)) {
505                    indices.add(index);
506                    indices_to_draw--;
507                }
508            }
509           
510            // now reduce our values to the indices we choose above for every attribute
511            for (int att=0; att < bigger.numAttributes()-1; att++) {
512               
513                // get double for the att
514                double[] vals = values.get(att);
515                double[] new_vals = new double[indices.size()];
516               
517                int i = 0;
518                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) {
519                    new_vals[i] = vals[it.next()];
520                    i++;
521                }
522               
523                values.set(att, new_vals);
524            }
525                }
526               
527               
528                /**
529                 * We run the kolmogorov-smirnov test on the data from our test an traindata
530                 * if the p value is above the cutoff we include it in the results
531                 * p value tends to be 0 when the distributions of the data are significantly different
532                 * but we want them to be the same
533                 *
534                 * @param cutoff
535                 * @return p-val
536                 */
537                public void kolmogorovSmirnovTest(double cutoff) {
538                        double p = 0;
539                       
540                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
541
542                        // todo: this should be symmetrical we don't have to compare i to j and then j to i
543                        // todo: this relies on the last attribute being the class,
544                        //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")");
545                        for (int i=0; i < this.train.numAttributes()-1; i++) {
546                                for ( int j=0; j < this.test.numAttributes()-1; j++) {
547                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
548                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
549                    // class attributes are not relevant
550                    if ( this.train.classIndex() == i ) {
551                        continue;
552                    }
553                    if ( this.test.classIndex() == j ) {
554                        continue;
555                    }
556                                        // PRoblem: exactP is forced for small sample sizes and it never finishes
557                                        if (!this.attributes.containsKey(i)) {
558                                               
559                                                // todo: output the values and complain on the math.commons mailinglist
560                                                p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
561                                                if (p > cutoff) {
562                                                        this.attributes.put(i, j);
563                                                }
564                                        }
565                                }
566                        }
567
568                        //Console.traceln(Level.INFO, "Found " + this.attmatch.size() + " matching attributes");
569                }
570         }
571}
Note: See TracBrowser for help on using the repository browser.