source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java @ 48

Last change on this file since 48 was 48, checked in by atrautsch, 9 years ago

final attribute selection

File size: 17.1 KB
Line 
1// Copyright 2015 Georg-August-Universit�t G�ttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.Comparator;
21import java.util.HashMap;
22import java.util.Iterator;
23import java.util.LinkedHashMap;
24import java.util.LinkedList;
25import java.util.List;
26import java.util.Map;
27import java.util.logging.Level;
28
29import javax.management.RuntimeErrorException;
30
31import java.util.Random;
32
33import org.apache.commons.collections4.list.SetUniqueList;
34import org.apache.commons.math3.stat.inference.ChiSquareTest;
35import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
36import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
37
38import de.ugoe.cs.util.console.Console;
39import weka.attributeSelection.SignificanceAttributeEval;
40import weka.classifiers.AbstractClassifier;
41import weka.classifiers.Classifier;
42import weka.core.Attribute;
43import weka.core.DenseInstance;
44import weka.core.FastVector;
45import weka.core.Instance;
46import weka.core.Instances;
47
48
49public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy {
50
51    private SetUniqueList<Instances> traindataSet;
52    private MetricMatch mm;
53    private final Classifier classifier = new MetricMatchingClassifier();
54   
55    private String method;
56    private float threshold;
57   
58    /**
59     * We wrap the classifier here because of classifyInstance
60     * @return
61     */
62    @Override
63    public Classifier getClassifier() {
64        return this.classifier;
65    }
66
67
68    @Override
69    public void setMethod(String method) {
70        this.method = method;
71    }
72
73
74    @Override
75    public void setThreshold(String threshold) {
76        this.threshold = Float.parseFloat(threshold);
77    }
78
79        /**
80         * We need the testdata instances to do a metric matching, so in this special case we get this data
81         * before evaluation
82         */
83        @Override
84        public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
85                this.traindataSet = traindataSet;
86
87                int rank = 0; // we want at least 5 matching attributes
88                int num = 0;
89                int biggest_num = 0;
90                MetricMatch tmp;
91                MetricMatch biggest = null;
92                for (Instances traindata : this.traindataSet) {
93                        num++;
94                        tmp = new MetricMatch(traindata, testdata);
95
96                        // metric selection may create error, continue to next training set
97                        try {
98                                tmp.attributeSelection();
99                        }catch(Exception e) {
100                                e.printStackTrace();
101                                throw new RuntimeException(e);
102                        }
103                       
104                        if (this.method.equals("spearman")) {
105                            tmp.spearmansRankCorrelation(this.threshold);
106                        }
107                        else if (this.method.equals("kolmogorov")) {
108                            tmp.kolmogorovSmirnovTest(this.threshold);
109                        }
110                        else {
111                            throw new RuntimeException("unknown method");
112                        }
113
114                        // we only select the training data from our set with the most matching attributes
115                        if (tmp.getRank() > rank) {
116                                rank = tmp.getRank();
117                                biggest = tmp;
118                                biggest_num = num;
119                        }
120                }
121               
122                if (biggest == null) {
123                    throw new RuntimeException("not enough matching attributes found");
124                }
125
126                // we use the best match
127               
128                this.mm = biggest;
129                Instances ilist = this.mm.getMatchedTrain();
130                Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributes, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets");
131               
132                // replace traindataSEt
133                //traindataSet = new SetUniqueList<Instances>();
134                traindataSet.clear();
135                traindataSet.add(ilist);
136               
137                // we have to build the classifier here:
138                try {
139                   
140                        //
141                    if (this.classifier == null) {
142                        Console.traceln(Level.SEVERE, "Classifier is null");
143                    }
144                        //Console.traceln(Level.INFO, "Building classifier with the matched training data with " + ilist.size() + " instances and "+ ilist.numAttributes() + " attributes");
145                        this.classifier.buildClassifier(ilist);
146                        ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
147                }catch(Exception e) {
148                        e.printStackTrace();
149                        throw new RuntimeException(e);
150                }
151        }
152
153       
154        /**
155         * encapsulates the classifier configured with WekaBase
156         */
157        public class MetricMatchingClassifier extends AbstractClassifier {
158
159                private static final long serialVersionUID = -1342172153473770935L;
160                private MetricMatch mm;
161                private Classifier classifier;
162               
163                @Override
164                public void buildClassifier(Instances traindata) throws Exception {
165                        this.classifier = setupClassifier();  // parent method from WekaBase
166                        this.classifier.buildClassifier(traindata);
167                }
168
169                public void setMetricMatching(MetricMatch mm) {
170                        this.mm = mm;
171                }
172               
173                /**
174                 * Here we can not do the metric matching because we only get one instance
175                 */
176                public double classifyInstance(Instance testdata) {
177                        // todo: maybe we can pull the instance out of our matched testdata?
178                        Instance ntest = this.mm.getMatchedTestInstance(testdata);
179
180                        double ret = 0.0;
181                        try {
182                                ret = this.classifier.classifyInstance(ntest);
183                        }catch(Exception e) {
184                                e.printStackTrace();
185                                throw new RuntimeException(e);
186                        }
187                       
188                        return ret;
189                }
190        }
191       
192        /**
193         * Encapsulates MetricMatching on Instances Arrays
194         */
195    public class MetricMatch {
196                 Instances train;
197                 Instances test;
198                 
199                 HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>();
200                 
201                 ArrayList<double[]> train_values;
202                 ArrayList<double[]> test_values;
203                 
204                 // todo: this constructor does not work
205                 public MetricMatch() {
206                 }
207                 
208                 public MetricMatch(Instances train, Instances test) {
209                         this.train = new Instances(train);  // expensive! but we are dropping the attributes
210                         this.test = new Instances(test);
211                         
212                         // 1. convert metrics of testdata and traindata to later use in test
213                         this.train_values = new ArrayList<double[]>();
214                         for (int i = 0; i < this.train.numAttributes()-1; i++) {
215                                this.train_values.add(train.attributeToDoubleArray(i));
216                         }
217                       
218                         this.test_values = new ArrayList<double[]>();
219                         for (int i=0; i < this.test.numAttributes()-1; i++) {
220                                this.test_values.add(this.test.attributeToDoubleArray(i));
221                         }
222                 }
223                 
224                 /**
225                  * returns the number of matched attributes
226                  * as a way of scoring traindata sets individually
227                  *
228                  * @return
229                  */
230                 public int getRank() {
231                         return this.attributes.size();
232                 }
233                 
234                 public int getNumInstances() {
235                     return this.train_values.get(0).length;
236                 }
237                 
238                 public Instance getMatchedTestInstance(Instance test) {
239                         // create new instance with our matched number of attributes + 1 (the class attribute)
240                         //Console.traceln(Level.INFO, "getting matched instance");
241                         Instances testdata = this.getMatchedTest();
242                         
243                         //Instance ni = new DenseInstance(this.attmatch.size()+1);
244                         Instance ni = new DenseInstance(this.attributes.size()+1);
245                         ni.setDataset(testdata);
246                         
247                         //Console.traceln(Level.INFO, "Attributes to match: " + this.attmatch.size() + "");
248                         
249                         Iterator it = this.attributes.entrySet().iterator();
250                         int j = 0;
251                         while (it.hasNext()) {
252                                 Map.Entry values = (Map.Entry)it.next();
253                                 ni.setValue(testdata.attribute(j), test.value((int)values.getValue()));
254                                 j++;
255                                 
256                         }
257                         
258                         ni.setClassValue(test.value(test.classAttribute()));
259                         
260                         //System.out.println(ni);
261                         return ni;
262                 }
263
264         /**
265          * returns a new instances array with the metric matched training data
266          *
267          * @return instances
268          */
269                 public Instances getMatchedTrain() {
270                         return this.getMatchedInstances("train", this.train);
271                 }
272                 
273                 /**
274                  * returns a new instances array with the metric matched test data
275                  *
276                  * @return instances
277                  */
278                 public Instances getMatchedTest() {
279                         return this.getMatchedInstances("test", this.test);
280                 }
281                 
282                 // https://weka.wikispaces.com/Programmatic+Use
283                 private Instances getMatchedInstances(String name, Instances data) {
284                         // construct our new attributes
285                         Attribute[] attrs = new Attribute[this.attributes.size()+1];
286                         FastVector fwTrain = new FastVector(this.attributes.size());
287                         for (int i=0; i < this.attributes.size(); i++) {
288                                 attrs[i] = new Attribute(String.valueOf(i));
289                                 fwTrain.addElement(attrs[i]);
290                         }
291                         // add our ClassAttribute (which is not numeric!)
292                         ArrayList<String> acl= new ArrayList<String>();
293                         acl.add("0");
294                         acl.add("1");
295                         
296                         fwTrain.addElement(new Attribute("bug", acl));
297                         Instances newTrain = new Instances(name, fwTrain, data.size());
298                         newTrain.setClassIndex(newTrain.numAttributes()-1);
299                         
300                         for (int i=0; i < data.size(); i++) {
301                                 Instance ni = new DenseInstance(this.attributes.size()+1);
302                               
303                                 Iterator it = this.attributes.entrySet().iterator();
304                                 int j = 0;
305                                 while (it.hasNext()) {
306                                         Map.Entry values = (Map.Entry)it.next();
307                                         int value = (int)values.getValue();
308                                         
309                                         // key ist traindata
310                                         if (name.equals("train")) {
311                                                 value = (int)values.getKey();
312                                         }
313                                         ni.setValue(newTrain.attribute(j), data.instance(i).value(value));
314                                         j++;
315                                 }
316                                 ni.setValue(ni.numAttributes()-1, data.instance(i).value(data.classAttribute()));
317                                 
318                                 newTrain.add(ni);
319                         }
320                         
321                    return newTrain;
322        }
323                 
324                 
325                /**
326                 * performs the attribute selection
327                 * we perform attribute significance tests and drop attributes
328                 */
329                public void attributeSelection() throws Exception {
330                        //Console.traceln(Level.INFO, "Attribute Selection on Training Attributes ("+this.train.numAttributes()+")");
331                        this.attributeSelection(this.train);
332                        //Console.traceln(Level.INFO, "-----");
333                        //Console.traceln(Level.INFO, "Attribute Selection on Test Attributes ("+this.test.numAttributes()+")");
334                        this.attributeSelection(this.test);
335                        //Console.traceln(Level.INFO, "-----");
336                }
337               
338                private void attributeSelection(Instances which) throws Exception {
339                        // 1. step we have to categorize the attributes
340                        //http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html
341                       
342                        SignificanceAttributeEval et = new SignificanceAttributeEval();
343                        et.buildEvaluator(which);
344                        //double tmp[] = new double[this.train.numAttributes()];
345                        HashMap<String,Double> saeval = new HashMap<String,Double>();
346                        // evaluate all training attributes
347                        // select top 15% of metrics
348                        for(int i=0; i < which.numAttributes(); i++) {
349                                if(which.classIndex() != i) {
350                                        saeval.put(which.attribute(i).name(), et.evaluateAttribute(i));
351                                }
352                                //Console.traceln(Level.SEVERE, "Significance Attribute Eval: " + tmp);
353                        }
354                       
355                        HashMap<String, Double> sorted = sortByValues(saeval);
356                       
357                        // die letzen 15% wollen wir haben
358                        float last = ((float)saeval.size() / 100) * 15;
359                        int drop_first = saeval.size() - (int)last;
360                       
361                        //Console.traceln(Level.INFO, "Dropping "+ drop_first + " of " + sorted.size() + " attributes (we only keep the best 15% "+last+")");
362                        /*
363                        Iterator it = sorted.entrySet().iterator();
364                    while (it.hasNext()) {
365                        Map.Entry pair = (Map.Entry)it.next();
366                        Console.traceln(Level.INFO, "key: "+(int)pair.getKey()+", value: " + (double)pair.getValue() + "");
367                    }*/
368                       
369                        // drop attributes above last
370                        Iterator it = sorted.entrySet().iterator();
371                    while (drop_first > 0) {
372                        Map.Entry pair = (Map.Entry)it.next();
373                        if(which.attribute((String)pair.getKey()).index() != which.classIndex()) {
374                               
375                                which.deleteAttributeAt(which.attribute((String)pair.getKey()).index());
376                                //Console.traceln(Level.INFO, "dropping attribute: "+ (String)pair.getKey());
377                        }
378                        drop_first-=1;
379                 
380                   
381                    }
382//                  //Console.traceln(Level.INFO, "Now we have " + which.numAttributes() + " attributes left (incl. class attribute!)");
383                }
384               
385               
386                private HashMap sortByValues(HashMap map) {
387               List list = new LinkedList(map.entrySet());
388
389               Collections.sort(list, new Comparator() {
390                    public int compare(Object o1, Object o2) {
391                       return ((Comparable) ((Map.Entry) (o1)).getValue())
392                          .compareTo(((Map.Entry) (o2)).getValue());
393                    }
394               });
395
396
397               HashMap sortedHashMap = new LinkedHashMap();
398               for (Iterator it = list.iterator(); it.hasNext();) {
399                      Map.Entry entry = (Map.Entry) it.next();
400                      sortedHashMap.put(entry.getKey(), entry.getValue());
401               }
402               return sortedHashMap;
403                }
404                 
405                 
406                 /**
407                  * calculate Spearmans rank correlation coefficient as matching score
408                  *
409                  * @param cutoff
410                  */
411                 public void spearmansRankCorrelation(double cutoff) {
412                         double p = 0;                   
413                         SpearmansCorrelation t = new SpearmansCorrelation();
414
415                         // size has to be the same so we randomly sample the number of the smaller sample from the big sample
416                         if (this.train.size() > this.test.size()) {
417                             this.sample(this.train, this.test, this.train_values);
418                         }else if (this.test.size() > this.train.size()) {
419                             this.sample(this.test, this.train, this.test_values);
420                         }
421                         
422                         // try out possible attribute combinations
423            for (int i=0; i < this.train.numAttributes()-1; i++) {
424                for (int j=0; j < this.test.numAttributes()-1; j++) {
425                    // class attributes are not relevant
426                    if (this.train.classIndex() == i) {
427                        continue;
428                    }
429                    if (this.test.classIndex() == j) {
430                        continue;
431                    }
432                   
433                   
434                                        if (!this.attributes.containsKey(i)) {
435                                                p = t.correlation(this.train_values.get(i), this.test_values.get(j));
436                                                if (p > cutoff) {
437                                                        this.attributes.put(i, j);
438                                                }
439                                        }
440                                }
441                    }
442        }
443
444               
445        public void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
446            // we want to at keep the indices we select the same
447            int indices_to_draw = smaller.size();
448            ArrayList<Integer> indices = new ArrayList<Integer>();
449            Random rand = new Random();
450            while (indices_to_draw > 0) {
451               
452                int index = rand.nextInt(bigger.size()-1);
453               
454                if (!indices.contains(index)) {
455                    indices.add(index);
456                    indices_to_draw--;
457                }
458            }
459           
460            // now reduce our values to the indices we choose above for every attribute
461            for (int att=0; att < bigger.numAttributes()-1; att++) {
462               
463                // get double for the att
464                double[] vals = values.get(att);
465                double[] new_vals = new double[indices.size()];
466               
467                int i = 0;
468                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) {
469                    new_vals[i] = vals[it.next()];
470                    i++;
471                }
472               
473                values.set(att, new_vals);
474            }
475                }
476               
477               
478                /**
479                 * We run the kolmogorov-smirnov test on the data from our test an traindata
480                 * if the p value is above the cutoff we include it in the results
481                 * p value tends to be 0 when the distributions of the data are significantly different
482                 * but we want them to be the same
483                 *
484                 * @param cutoff
485                 * @return p-val
486                 */
487                public void kolmogorovSmirnovTest(double cutoff) {
488                        double p = 0;
489                       
490                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
491
492                        // todo: this should be symmetrical we don't have to compare i to j and then j to i
493                        // todo: this relies on the last attribute being the class,
494                        //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")");
495                        for (int i=0; i < this.train.numAttributes()-1; i++) {
496                                for ( int j=0; j < this.test.numAttributes()-1; j++) {
497                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
498                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
499                    // class attributes are not relevant
500                    if ( this.train.classIndex() == i ) {
501                        continue;
502                    }
503                    if ( this.test.classIndex() == j ) {
504                        continue;
505                    }
506                                        // PRoblem: exactP is forced for small sample sizes and it never finishes
507                                        if (!this.attributes.containsKey(i)) {
508                                               
509                                                // todo: output the values and complain on the math.commons mailinglist
510                                                p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
511                                                if (p > cutoff) {
512                                                        this.attributes.put(i, j);
513                                                }
514                                        }
515                                }
516                        }
517
518                        //Console.traceln(Level.INFO, "Found " + this.attmatch.size() + " matching attributes");
519                }
520         }
521}
Note: See TracBrowser for help on using the repository browser.