Changeset 47 for trunk


Ignore:
Timestamp:
12/12/15 10:57:31 (9 years ago)
Author:
atrautsch
Message:

selection

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/MetricMatchingTraining.java

    r45 r47  
    1616 
    1717import java.util.ArrayList; 
     18import java.util.Arrays; 
     19import java.util.Collections; 
     20import java.util.Comparator; 
    1821import java.util.HashMap; 
    1922import java.util.Iterator; 
     23import java.util.LinkedHashMap; 
     24import java.util.LinkedList; 
     25import java.util.List; 
    2026import java.util.Map; 
    2127import java.util.logging.Level; 
     28 
     29import javax.management.RuntimeErrorException; 
     30 
    2231import java.util.Random; 
    2332 
     
    2837 
    2938import de.ugoe.cs.util.console.Console; 
     39import weka.attributeSelection.SignificanceAttributeEval; 
    3040import weka.classifiers.AbstractClassifier; 
    3141import weka.classifiers.Classifier; 
     
    3545import weka.core.Instance; 
    3646import weka.core.Instances; 
     47 
    3748 
    3849public class MetricMatchingTraining extends WekaBaseTraining implements ISetWiseTestdataAwareTrainingStrategy { 
     
    5364        return this.classifier; 
    5465    } 
    55      
    56      
    57     @Override 
    58     public String getName() { 
    59         return "MetricMatching_" + classifierName; 
    60     } 
    6166 
    6267 
     
    9095                        //tmp.kolmogorovSmirnovTest(0.05); 
    9196                         
    92                         if( this.method.equals("spearman") ) { 
     97                        try { 
     98                                tmp.attributeSelection(); 
     99                        }catch(Exception e) { 
     100                                 
     101                        } 
     102                         
     103                        if (this.method.equals("spearman")) { 
    93104                            tmp.spearmansRankCorrelation(this.threshold); 
    94105                        } 
    95                         else if( this.method.equals("kolmogorov") ) { 
     106                        else if (this.method.equals("kolmogorov")) { 
    96107                            tmp.kolmogorovSmirnovTest(this.threshold); 
    97108                        } 
     
    101112 
    102113                        // we only select the training data from our set with the most matching attributes 
    103                         if(tmp.getRank() > rank) { 
     114                        if (tmp.getRank() > rank) { 
    104115                                rank = tmp.getRank(); 
    105116                                biggest = tmp; 
     
    108119                } 
    109120                 
    110                 if( biggest == null ) { 
     121                if (biggest == null) { 
    111122                    throw new RuntimeException("not enough matching attributes found"); 
    112123                } 
     
    117128                Instances ilist = this.mm.getMatchedTrain(); 
    118129                Console.traceln(Level.INFO, "Chosing the trainingdata set num "+biggest_num +" with " + rank + " matching attributs, " + ilist.size() + " instances out of a possible set of " + traindataSet.size() + " sets"); 
     130                 
     131                // replace traindataSEt 
     132                //traindataSet = new SetUniqueList<Instances>(); 
     133                traindataSet.clear(); 
     134                traindataSet.add(ilist); 
    119135                 
    120136                // we have to build the classifier here: 
     
    122138                     
    123139                        // 
    124                     if( this.classifier == null ) { 
     140                    if (this.classifier == null) { 
    125141                        Console.traceln(Level.SEVERE, "Classifier is null"); 
    126142                    } 
     
    200216                         
    201217                         this.test_values = new ArrayList<double[]>(); 
    202                          for( int i=0; i < this.test.numAttributes()-1; i++ ) { 
     218                         for (int i=0; i < this.test.numAttributes()-1; i++) { 
    203219                                this.test_values.add(this.test.attributeToDoubleArray(i)); 
    204220                         } 
     
    232248                         Iterator it = this.attributes.entrySet().iterator(); 
    233249                         int j = 0; 
    234                          while(it.hasNext()) { 
     250                         while (it.hasNext()) { 
    235251                                 Map.Entry values = (Map.Entry)it.next(); 
    236252                                 ni.setValue(testdata.attribute(j), test.value((int)values.getValue())); 
     
    268284                         Attribute[] attrs = new Attribute[this.attributes.size()+1]; 
    269285                         FastVector fwTrain = new FastVector(this.attributes.size()); 
    270                          for(int i=0; i < this.attributes.size(); i++) { 
     286                         for (int i=0; i < this.attributes.size(); i++) { 
    271287                                 attrs[i] = new Attribute(String.valueOf(i)); 
    272288                                 fwTrain.addElement(attrs[i]); 
     
    281297                         newTrain.setClassIndex(newTrain.numAttributes()-1); 
    282298                          
    283                          for(int i=0; i < data.size(); i++) { 
     299                         for (int i=0; i < data.size(); i++) { 
    284300                                 Instance ni = new DenseInstance(this.attributes.size()+1); 
    285301                                 
    286302                                 Iterator it = this.attributes.entrySet().iterator(); 
    287303                                 int j = 0; 
    288                                  while(it.hasNext()) { 
     304                                 while (it.hasNext()) { 
    289305                                         Map.Entry values = (Map.Entry)it.next(); 
    290306                                         int value = (int)values.getValue(); 
    291307                                          
    292308                                         // key ist traindata 
    293                                          if(name.equals("train")) { 
     309                                         if (name.equals("train")) { 
    294310                                                 value = (int)values.getKey(); 
    295311                                         } 
     
    302318                         } 
    303319                          
    304                          return newTrain; 
    305                  } 
     320                    return newTrain; 
     321        } 
     322                  
     323                  
     324                /** 
     325                 * performs the attribute selection 
     326                 * we perform attribute significance tests and drop attributes 
     327                 */ 
     328                public void attributeSelection() throws Exception { 
     329                        this.attributeSelection(this.train); 
     330                        this.attributeSelection(this.test); 
     331                } 
     332                 
     333                private void attributeSelection(Instances which) throws Exception { 
     334                        // 1. step we have to categorize the attributes 
     335                        //http://weka.sourceforge.net/doc.packages/probabilisticSignificanceAE/weka/attributeSelection/SignificanceAttributeEval.html 
     336                         
     337                        SignificanceAttributeEval et = new SignificanceAttributeEval(); 
     338                        et.buildEvaluator(which); 
     339                        //double tmp[] = new double[this.train.numAttributes()]; 
     340                        HashMap<Integer,Double> saeval = new HashMap<Integer,Double>(); 
     341                        // evaluate all training attributes 
     342                        // select top 15% of metrics 
     343                        for(int i=0; i < which.numAttributes() - 1; i++) {  
     344                                //tmp[i] = et.evaluateAttribute(i); 
     345                                saeval.put(i, et.evaluateAttribute(i)); 
     346                                //Console.traceln(Level.SEVERE, "Significance Attribute Eval: " + tmp); 
     347                        } 
     348                         
     349                        HashMap<Integer, Double> sorted = sortByValues(saeval); 
     350                         
     351                        // die letzen 15% wollen wir haben 
     352                        int last = (saeval.size() / 100) * 15; 
     353                        int drop_first = saeval.size() - last; 
     354                         
     355                        // drop attributes above last 
     356                        Iterator it = sorted.entrySet().iterator(); 
     357                    while (it.hasNext()) { 
     358                        Map.Entry pair = (Map.Entry)it.next(); 
     359                        if(drop_first > 0) { 
     360                                which.deleteAttributeAt((int)pair.getKey()); 
     361                        } 
     362                        drop_first--; 
     363                    }    
     364                } 
     365                 
     366                private HashMap sortByValues(HashMap map) { 
     367               List list = new LinkedList(map.entrySet()); 
     368               // Defined Custom Comparator here 
     369               Collections.sort(list, new Comparator() { 
     370                    public int compare(Object o1, Object o2) { 
     371                       return ((Comparable) ((Map.Entry) (o1)).getValue()) 
     372                          .compareTo(((Map.Entry) (o2)).getValue()); 
     373                    } 
     374               }); 
     375 
     376               // Here I am copying the sorted list in HashMap 
     377               // using LinkedHashMap to preserve the insertion order 
     378               HashMap sortedHashMap = new LinkedHashMap(); 
     379               for (Iterator it = list.iterator(); it.hasNext();) { 
     380                      Map.Entry entry = (Map.Entry) it.next(); 
     381                      sortedHashMap.put(entry.getKey(), entry.getValue()); 
     382               }  
     383               return sortedHashMap; 
     384                } 
     385                  
    306386                  
    307387                 /** 
     
    315395 
    316396                         // size has to be the same so we randomly sample the number of the smaller sample from the big sample 
    317                          if( this.train.size() > this.test.size() ) { 
     397                         if (this.train.size() > this.test.size()) { 
    318398                             this.sample(this.train, this.test, this.train_values); 
    319                          }else if( this.test.size() > this.train.size() ) { 
     399                         }else if (this.test.size() > this.train.size()) { 
    320400                             this.sample(this.test, this.train, this.test_values); 
    321401                         } 
    322402                          
    323403                         // try out possible attribute combinations 
    324             for( int i=0; i < this.train.numAttributes()-1; i++ ) { 
    325                 for ( int j=0; j < this.test.numAttributes()-1; j++ ) { 
     404            for (int i=0; i < this.train.numAttributes()-1; i++) { 
     405                for (int j=0; j < this.test.numAttributes()-1; j++) { 
    326406                    // class attributes are not relevant  
    327                     if ( this.train.classIndex() == i ) { 
     407                    if (this.train.classIndex() == i) { 
    328408                        continue; 
    329409                    } 
    330                     if ( this.test.classIndex() == j ) { 
     410                    if (this.test.classIndex() == j) { 
    331411                        continue; 
    332412                    } 
    333413                     
    334414                     
    335                                         if( !this.attributes.containsKey(i) ) { 
     415                                        if (!this.attributes.containsKey(i)) { 
    336416                                                p = t.correlation(this.train_values.get(i), this.test_values.get(j)); 
    337                                                 if( p > cutoff ) { 
     417                                                if (p > cutoff) { 
    338418                                                        this.attributes.put(i, j); 
    339419                                                } 
     
    349429            ArrayList<Integer> indices = new ArrayList<Integer>(); 
    350430            Random rand = new Random(); 
    351             while( indices_to_draw > 0) { 
     431            while (indices_to_draw > 0) { 
    352432                 
    353433                int index = rand.nextInt(bigger.size()-1); 
    354434                 
    355                 if( !indices.contains(index) ) { 
     435                if (!indices.contains(index)) { 
    356436                    indices.add(index); 
    357437                    indices_to_draw--; 
     
    360440             
    361441            // now reduce our values to the indices we choose above for every attribute 
    362             for(int att=0; att < bigger.numAttributes()-1; att++ ) { 
     442            for (int att=0; att < bigger.numAttributes()-1; att++) { 
    363443                 
    364444                // get double for the att 
     
    367447                 
    368448                int i = 0; 
    369                 for( Iterator<Integer> it = indices.iterator(); it.hasNext(); ) { 
     449                for (Iterator<Integer> it = indices.iterator(); it.hasNext();) { 
    370450                    new_vals[i] = vals[it.next()]; 
    371451                    i++; 
     
    394474                        // todo: this relies on the last attribute being the class,  
    395475                        //Console.traceln(Level.INFO, "Starting Kolmogorov-Smirnov test for traindata size: " + this.train.size() + " attributes("+this.train.numAttributes()+") and testdata size: " + this.test.size() + " attributes("+this.test.numAttributes()+")"); 
    396                         for( int i=0; i < this.train.numAttributes()-1; i++ ) { 
     476                        for (int i=0; i < this.train.numAttributes()-1; i++) { 
    397477                                for ( int j=0; j < this.test.numAttributes()-1; j++) { 
    398478                                        //p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 
     
    406486                    } 
    407487                                        // PRoblem: exactP is forced for small sample sizes and it never finishes 
    408                                         if( !this.attributes.containsKey(i) ) { 
     488                                        if (!this.attributes.containsKey(i)) { 
    409489                                                 
    410490                                                // todo: output the values and complain on the math.commons mailinglist 
    411491                                                p = t.kolmogorovSmirnovTest(this.train_values.get(i), this.test_values.get(j)); 
    412                                                 if( p > cutoff ) { 
     492                                                if (p > cutoff) { 
    413493                                                        this.attributes.put(i, j); 
    414494                                                } 
Note: See TracChangeset for help on using the changeset viewer.