source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java @ 46

Last change on this file since 46 was 45, checked in by atrautsch, 9 years ago

metric matching configurable

File size: 13.4 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.HashMap;
19import java.util.Iterator;
20import java.util.Map;
21import java.util.logging.Level;
22import java.util.Random;
23
24import org.apache.commons.collections4.list.SetUniqueList;
25import org.apache.commons.math3.stat.inference.ChiSquareTest;
26import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
27import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
28
29import de.ugoe.cs.util.console.Console;
30import weka.classifiers.AbstractClassifier;
31import weka.classifiers.Classifier;
32import weka.core.Attribute;
33import weka.core.DenseInstance;
34import weka.core.FastVector;
35import weka.core.Instance;
36import weka.core.Instances;
37
38public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy {
39
40    private SetUniqueList<Instances> traindataSet;
41    private MetricMatch mm;
42    private final Classifier classifier = new MetricMatchingClassifier();
43   
44    private String method;
45    private float threshold;
46   
47    /**
48     * We wrap the classifier here because of classifyInstance
49     * @return
50     */
51    @Override
52    public Classifier getClassifier() {
53        return this.classifier;
54    }
55   
56   
57    @Override
58    public String getName() {
59        return "MetricMatching_" + classifierName;
60    }
61
62
63    @Override
64    public void setMethod(String method) {
65        this.method = method;
66    }
67
68
69    @Override
70    public void setThreshold(String threshold) {
71        this.threshold = Float.parseFloat(threshold);
72    }
73
74        /**
75         * We need the testdata instances to do a metric matching, so in this special case we get this data
76         * before evaluation
77         */
78        @Override
79        public void apply(SetUniqueList<Instances> traindataSet, Instances testdata) {
80                this.traindataSet = traindataSet;
81
82                int rank = 5; // we want at least 5 matching attributes
83                int num = 0;
84                int biggest_num = 0;
85                MetricMatch tmp;
86                MetricMatch biggest = null;
87                for (Instances traindata : this.traindataSet) {
88                        num++;
89                        tmp = new MetricMatch(traindata, testdata);
90                        //tmp.kolmogorovSmirnovTest(0.05);
91                       
92                        if( this.method.equals("spearman") ) {
93                            tmp.spearmansRankCorrelation(this.threshold);
94                        }
95                        else if( this.method.equals("kolmogorov") ) {
96                            tmp.kolmogorovSmirnovTest(this.threshold);
97                        }
98                        else {
99                            throw new RuntimeException("unknown method");
100                        }
101
102                        // we only select the training data from our set with the most matching attributes
103                        if(tmp.getRank() > rank) {
104                                rank = tmp.getRank();
105                                biggest = tmp;
106                                biggest_num = num;
107                        }
108                }
109               
110                if( biggest == null ) {
111                    throw new RuntimeException("not enough matching attributes found");
112                }
113
114                // we use the best match
115               
116                this.mm = biggest;
117                Instances ilist = this.mm.getMatchedTrain();
118                Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributs, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets");
119               
120                // we have to build the classifier here:
121                try {
122                   
123                        //
124                    if( this.classifier == null ) {
125                        Console.traceln(Level.SEVERE, "Classifier is null");
126                    }
127                        //Console.traceln(Level.INFO, "Building classifier with the matched training data with " + ilist.size() + " instances and "+ ilist.numAttributes() + " attributes");
128                        this.classifier.buildClassifier(ilist);
129                        ((MetricMatchingClassifier) this.classifier).setMetricMatching(this.mm);
130                }catch(Exception e) {
131                        e.printStackTrace();
132                        throw new RuntimeException(e);
133                }
134        }
135
136       
137        /**
138         * encapsulates the classifier configured with WekaBase
139         */
140        public class MetricMatchingClassifier extends AbstractClassifier {
141
142                private static final long serialVersionUID = -1342172153473770935L;
143                private MetricMatch mm;
144                private Classifier classifier;
145               
146                @Override
147                public void buildClassifier(Instances traindata) throws Exception {
148                        this.classifier = setupClassifier();  // parent method from WekaBase
149                        this.classifier.buildClassifier(traindata);
150                }
151
152                public void setMetricMatching(MetricMatch mm) {
153                        this.mm = mm;
154                }
155               
156                /**
157                 * Here we can not do the metric matching because we only get one instance
158                 */
159                public double classifyInstance(Instance testdata) {
160                        // todo: maybe we can pull the instance out of our matched testdata?
161                        Instance ntest = this.mm.getMatchedTestInstance(testdata);
162
163                        double ret = 0.0;
164                        try {
165                                ret = this.classifier.classifyInstance(ntest);
166                        }catch(Exception e) {
167                                e.printStackTrace();
168                                throw new RuntimeException(e);
169                        }
170                       
171                        return ret;
172                }
173        }
174       
175        /**
176         * Encapsulates MetricMatching on Instances Arrays
177         */
178    public class MetricMatch {
179                 Instances train;
180                 Instances test;
181                 
182                 HashMap<Integer, Integer> attributes = new HashMap<Integer,Integer>();
183                 
184                 ArrayList<double[]> train_values;
185                 ArrayList<double[]> test_values;
186                 
187                 // todo: this constructor does not work
188                 public MetricMatch() {
189                 }
190                 
191                 public MetricMatch(Instances train, Instances test) {
192                         this.train = train;
193                         this.test = test;
194                         
195                         // 1. convert metrics of testdata and traindata to later use in test
196                         this.train_values = new ArrayList<double[]>();
197                         for (int i = 0; i < this.train.numAttributes()-1; i++) {
198                                this.train_values.add(train.attributeToDoubleArray(i));
199                         }
200                       
201                         this.test_values = new ArrayList<double[]>();
202                         for( int i=0; i < this.test.numAttributes()-1; i++ ) {
203                                this.test_values.add(this.test.attributeToDoubleArray(i));
204                         }
205                 }
206                 
207                 /**
208                  * returns the number of matched attributes
209                  * as a way of scoring traindata sets individually
210                  *
211                  * @return
212                  */
213                 public int getRank() {
214                         return this.attributes.size();
215                 }
216                 
217                 public int getNumInstances() {
218                     return this.train_values.get(0).length;
219                 }
220                 
221                 public Instance getMatchedTestInstance(Instance test) {
222                         // create new instance with our matched number of attributes + 1 (the class attribute)
223                         //Console.traceln(Level.INFO, "getting matched instance");
224                         Instances testdata = this.getMatchedTest();
225                         
226                         //Instance ni = new DenseInstance(this.attmatch.size()+1);
227                         Instance ni = new DenseInstance(this.attributes.size()+1);
228                         ni.setDataset(testdata);
229                         
230                         //Console.traceln(Level.INFO, "Attributes to match: " + this.attmatch.size() + "");
231                         
232                         Iterator it = this.attributes.entrySet().iterator();
233                         int j = 0;
234                         while(it.hasNext()) {
235                                 Map.Entry values = (Map.Entry)it.next();
236                                 ni.setValue(testdata.attribute(j), test.value((int)values.getValue()));
237                                 j++;
238                                 
239                         }
240                         
241                         ni.setClassValue(test.value(test.classAttribute()));
242                         
243                         //System.out.println(ni);
244                         return ni;
245                 }
246
247         /**
248          * returns a new instances array with the metric matched training data
249          *
250          * @return instances
251          */
252                 public Instances getMatchedTrain() {
253                         return this.getMatchedInstances("train", this.train);
254                 }
255                 
256                 /**
257                  * returns a new instances array with the metric matched test data
258                  *
259                  * @return instances
260                  */
261                 public Instances getMatchedTest() {
262                         return this.getMatchedInstances("test", this.test);
263                 }
264                 
265                 // https://weka.wikispaces.com/Programmatic+Use
266                 private Instances getMatchedInstances(String name, Instances data) {
267                         // construct our new attributes
268                         Attribute[] attrs = new Attribute[this.attributes.size()+1];
269                         FastVector fwTrain = new FastVector(this.attributes.size());
270                         for(int i=0; i < this.attributes.size(); i++) {
271                                 attrs[i] = new Attribute(String.valueOf(i));
272                                 fwTrain.addElement(attrs[i]);
273                         }
274                         // add our ClassAttribute (which is not numeric!)
275                         ArrayList<String> acl= new ArrayList<String>();
276                         acl.add("0");
277                         acl.add("1");
278                         
279                         fwTrain.addElement(new Attribute("bug", acl));
280                         Instances newTrain = new Instances(name, fwTrain, data.size());
281                         newTrain.setClassIndex(newTrain.numAttributes()-1);
282                         
283                         for(int i=0; i < data.size(); i++) {
284                                 Instance ni = new DenseInstance(this.attributes.size()+1);
285                               
286                                 Iterator it = this.attributes.entrySet().iterator();
287                                 int j = 0;
288                                 while(it.hasNext()) {
289                                         Map.Entry values = (Map.Entry)it.next();
290                                         int value = (int)values.getValue();
291                                         
292                                         // key ist traindata
293                                         if(name.equals("train")) {
294                                                 value = (int)values.getKey();
295                                         }
296                                         ni.setValue(newTrain.attribute(j), data.instance(i).value(value));
297                                         j++;
298                                 }
299                                 ni.setValue(ni.numAttributes()-1, data.instance(i).value(data.classAttribute()));
300                                 
301                                 newTrain.add(ni);
302                         }
303                         
304                         return newTrain;
305                 }
306                 
307                 /**
308                  * calculate Spearmans rank correlation coefficient as matching score
309                  *
310                  * @param cutoff
311                  */
312                 public void spearmansRankCorrelation(double cutoff) {
313                         double p = 0;                   
314                         SpearmansCorrelation t = new SpearmansCorrelation();
315
316                         // size has to be the same so we randomly sample the number of the smaller sample from the big sample
317                         if( this.train.size() > this.test.size() ) {
318                             this.sample(this.train, this.test, this.train_values);
319                         }else if( this.test.size() > this.train.size() ) {
320                             this.sample(this.test, this.train, this.test_values);
321                         }
322                         
323                         // try out possible attribute combinations
324            for( int i=0; i < this.train.numAttributes()-1; i++ ) {
325                for ( int j=0; j < this.test.numAttributes()-1; j++ ) {
326                    // class attributes are not relevant
327                    if ( this.train.classIndex() == i ) {
328                        continue;
329                    }
330                    if ( this.test.classIndex() == j ) {
331                        continue;
332                    }
333                   
334                   
335                                        if( !this.attributes.containsKey(i) ) {
336                                                p = t.correlation(this.train_values.get(i), this.test_values.get(j));
337                                                if( p > cutoff ) {
338                                                        this.attributes.put(i, j);
339                                                }
340                                        }
341                                }
342                    }
343        }
344
345               
346        public void sample(Instances bigger, Instances smaller, ArrayList<double[]> values) {
347            // we want to at keep the indices we select the same
348            int indices_to_draw = smaller.size();
349            ArrayList<Integer> indices = new ArrayList<Integer>();
350            Random rand = new Random();
351            while( indices_to_draw > 0) {
352               
353                int index = rand.nextInt(bigger.size()-1);
354               
355                if( !indices.contains(index) ) {
356                    indices.add(index);
357                    indices_to_draw--;
358                }
359            }
360           
361            // now reduce our values to the indices we choose above for every attribute
362            for(int att=0; att < bigger.numAttributes()-1; att++ ) {
363               
364                // get double for the att
365                double[] vals = values.get(att);
366                double[] new_vals = new double[indices.size()];
367               
368                int i = 0;
369                for( Iterator<Integer> it = indices.iterator(); it.hasNext(); ) {
370                    new_vals[i] = vals[it.next()];
371                    i++;
372                }
373               
374                values.set(att, new_vals);
375            }
376                }
377               
378               
379                /**
380                 * We run the kolmogorov-smirnov test on the data from our test an traindata
381                 * if the p value is above the cutoff we include it in the results
382                 * p value tends to be 0 when the distributions of the data are significantly different
383                 * but we want them to be the same
384                 *
385                 * @param cutoff
386                 * @return p-val
387                 */
388                public void kolmogorovSmirnovTest(double cutoff) {
389                        double p = 0;
390                       
391                        KolmogorovSmirnovTest t = new KolmogorovSmirnovTest();
392
393                        // todo: this should be symmetrical we don't have to compare i to j and then j to i
394                        // todo: this relies on the last attribute being the class,
395                        //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")");
396                        for( int i=0; i < this.train.numAttributes()-1; i++ ) {
397                                for ( int j=0; j < this.test.numAttributes()-1; j++) {
398                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
399                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
400                    // class attributes are not relevant
401                    if ( this.train.classIndex() == i ) {
402                        continue;
403                    }
404                    if ( this.test.classIndex() == j ) {
405                        continue;
406                    }
407                                        // PRoblem: exactP is forced for small sample sizes and it never finishes
408                                        if( !this.attributes.containsKey(i) ) {
409                                               
410                                                // todo: output the values and complain on the math.commons mailinglist
411                                                p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j));
412                                                if( p > cutoff ) {
413                                                        this.attributes.put(i, j);
414                                                }
415                                        }
416                                }
417                        }
418
419                        //Console.traceln(Level.INFO, "Found " + this.attmatch.size() + " matching attributes");
420                }
421         }
422}
Note: See TracBrowser for help on using the repository browser.