Ignore:
Timestamp:
09/24/15 10:59:05 (9 years ago)
Author:
sherbold
Message:
  • formatted code and added copyrights
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    2438 
    2539/** 
    26  * Trainer with reimplementation of WHERE clustering algorithm from: 
    27  * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman,  
    28  * Forrest Shull, Burak Turhan, Thomas Zimmermann,  
    29  * "Local versus Global Lessons for Defect Prediction and Effort Estimation,"  
    30  * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013   
     40 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher, 
     41 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann, 
     42 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on 
     43 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 
    3144 *  
    32  * With WekaLocalFQTraining we do the following: 
    33  * 1) Run the Fastmap algorithm on all training data, let it calculate the 2 most significant  
    34  *    dimensions and projections of each instance to these dimensions 
    35  * 2) With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and median(y) values. 
    36  * 3) We cluster the QuadTree nodes together if they have similar density (50%) 
    37  * 4) We save the clusters and their training data 
    38  * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster 
    39  * 6) We train a Weka classifier for each cluster with the clusters training data 
    40  * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance. 
    41  * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster. 
    42  * 8) We classify the Instance with the classifier and traindata from the Cluster we found in 7. 
     45 * With WekaLocalFQTraining we do the following: 1) Run the Fastmap algorithm on all training data, 
     46 * let it calculate the 2 most significant dimensions and projections of each instance to these 
     47 * dimensions 2) With these 2 dimensions we span a QuadTree which gets recursively split on 
     48 * median(x) and median(y) values. 3) We cluster the QuadTree nodes together if they have similar 
     49 * density (50%) 4) We save the clusters and their training data 5) We only use clusters with > 
     50 * ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this 
     51 * cluster 6) We train a Weka classifier for each cluster with the clusters training data 7) We 
     52 * recalculate Fastmap distances for a single instance with the old pivots and then try to find a 
     53 * cluster containing the coords of the instance. 7.1.) If we can not find a cluster (due to coords 
     54 * outside of all clusters) we find the nearest cluster. 8) We classify the Instance with the 
     55 * classifier and traindata from the Cluster we found in 7. 
    4356 */ 
    4457public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy { 
    45          
    46         private final TraindatasetCluster classifier = new TraindatasetCluster(); 
    47          
    48         @Override 
    49         public Classifier getClassifier() { 
    50                 return classifier; 
    51         } 
    52          
    53         @Override 
    54         public void apply(Instances traindata) { 
    55                 PrintStream errStr      = System.err; 
    56                 System.setErr(new PrintStream(new NullOutputStream())); 
    57                 try { 
    58                         classifier.buildClassifier(traindata); 
    59                 } catch (Exception e) { 
    60                         throw new RuntimeException(e); 
    61                 } finally { 
    62                         System.setErr(errStr); 
    63                 } 
    64         } 
    65          
    66          
    67         public class TraindatasetCluster extends AbstractClassifier { 
    68                  
    69                 private static final long serialVersionUID = 1L; 
    70                  
    71                 /* classifier per cluster */ 
    72                 private HashMap<Integer, Classifier> cclassifier; 
    73                  
    74                 /* instances per cluster */ 
    75                 private HashMap<Integer, Instances> ctraindata;  
    76                  
    77                 /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/ 
    78                 private HashMap<Integer, Instance> cpivots; 
    79                  
    80                 /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/ 
    81                 private int[][] cpivotindices; 
    82  
    83                 /* holds the sizes of the cluster multiple "boxes" per cluster */ 
    84                 private HashMap<Integer, ArrayList<Double[][]>> csize; 
    85                  
    86                 /* debug vars */ 
    87                 @SuppressWarnings("unused") 
    88                 private boolean show_biggest = true; 
    89                  
    90                 @SuppressWarnings("unused") 
    91                 private int CFOUND = 0; 
    92                 @SuppressWarnings("unused") 
    93                 private int CNOTFOUND = 0; 
    94                  
    95                  
    96                 private Instance createInstance(Instances instances, Instance instance) { 
    97                         // attributes for feeding instance to classifier 
    98                         Set<String> attributeNames = new HashSet<>(); 
    99                         for( int j=0; j<instances.numAttributes(); j++ ) { 
    100                                 attributeNames.add(instances.attribute(j).name()); 
    101                         } 
    102                          
    103                         double[] values = new double[instances.numAttributes()]; 
    104                         int index = 0; 
    105                         for( int j=0; j<instance.numAttributes(); j++ ) { 
    106                                 if( attributeNames.contains(instance.attribute(j).name())) { 
    107                                         values[index] = instance.value(j); 
    108                                         index++; 
    109                                 } 
    110                         } 
    111                          
    112                         Instances tmp = new Instances(instances); 
    113                         tmp.clear(); 
    114                         Instance instCopy = new DenseInstance(instance.weight(), values); 
    115                         instCopy.setDataset(tmp); 
    116                          
    117                         return instCopy; 
    118                 } 
    119                  
    120                 /** 
    121                  * Because Fastmap saves only the image not the values of the attributes it used 
    122                  * we can not use the old data directly to classify single instances to clusters. 
    123                  *  
    124                  * To classify a single instance we do a new fastmap computation with only the instance and 
    125                  * the old pivot elements. 
    126                  *  
    127                  * After that we find the cluster with our fastmap result for x and y. 
    128                  */ 
    129                 @Override 
    130                 public double classifyInstance(Instance instance) { 
    131                          
    132                         double ret = 0; 
    133                         try { 
    134                                 // classinstance gets passed to classifier 
    135                                 Instances traindata = ctraindata.get(0); 
    136                                 Instance classInstance = createInstance(traindata, instance); 
    137  
    138                                 // this one keeps the class attribute 
    139                                 Instances traindata2 = ctraindata.get(1);   
    140                                  
    141                                 // remove class attribute before clustering 
    142                                 Remove filter = new Remove(); 
    143                                 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
    144                                 filter.setInputFormat(traindata); 
    145                                 traindata = Filter.useFilter(traindata, filter); 
    146                                 Instance clusterInstance = createInstance(traindata, instance); 
    147                                  
    148                                 Fastmap FMAP = new Fastmap(2); 
    149                                 EuclideanDistance dist = new EuclideanDistance(traindata); 
    150                                  
    151                                 // we set our pivot indices [x=0,y=1][dimension] 
    152                                 int[][] npivotindices = new int[2][2]; 
    153                                 npivotindices[0][0] = 1; 
    154                                 npivotindices[1][0] = 2; 
    155                                 npivotindices[0][1] = 3; 
    156                                 npivotindices[1][1] = 4; 
    157                                  
    158                                 // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 
    159                                 // the instance we want to classify comes first after that the pivot elements in the order defined above 
    160                                 double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1]; 
    161                                 distmat[0][0] = 0; 
    162                                 distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    163                                 distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    164                                 distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    165                                 distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    166                                  
    167                                 distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance); 
    168                                 distmat[1][1] = 0; 
    169                                 distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    170                                 distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    171                                 distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    172                                  
    173                                 distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance); 
    174                                 distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    175                                 distmat[2][2] = 0; 
    176                                 distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    177                                 distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    178                                  
    179                                 distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance); 
    180                                 distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    181                                 distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    182                                 distmat[3][3] = 0; 
    183                                 distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    184  
    185                                 distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance); 
    186                                 distmat[4][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    187                                 distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    188                                 distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    189                                 distmat[4][4] = 0; 
    190                                  
    191                                  
    192                                 /* debug output: show biggest distance found within the new distance matrix 
    193                                 double biggest = 0; 
    194                                 for(int i=0; i < distmat.length; i++) { 
    195                                         for(int j=0; j < distmat[0].length; j++) { 
    196                                                 if(biggest < distmat[i][j]) { 
    197                                                         biggest = distmat[i][j]; 
    198                                                 } 
    199                                         } 
    200                                 } 
    201                                 if(this.show_biggest) { 
    202                                         Console.traceln(Level.INFO, String.format(""+clusterInstance)); 
    203                                         Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
    204                                         this.show_biggest = false; 
    205                                 } 
    206                                 */ 
    207  
    208                                 FMAP.setDistmat(distmat); 
    209                                 FMAP.setPivots(npivotindices); 
    210                                 FMAP.calculate(); 
    211                                 double[][] x = FMAP.getX(); 
    212                                 double[] proj = x[0]; 
    213  
    214                                 // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix 
    215                                 /* 
    216                                 Console.traceln(Level.INFO, "distmat:"); 
    217                             for(int i=0; i<distmat.length; i++){ 
    218                                 for(int j=0; j<distmat[0].length; j++){ 
    219                                         Console.trace(Level.INFO, String.format("%20s", distmat[i][j])); 
    220                                 } 
    221                                 Console.traceln(Level.INFO, ""); 
    222                             } 
    223                              
    224                             Console.traceln(Level.INFO, "vector:"); 
    225                             for(int i=0; i < proj.length; i++) { 
    226                                 Console.trace(Level.INFO, String.format("%20s", proj[i])); 
    227                             } 
    228                             Console.traceln(Level.INFO, ""); 
    229                              
    230                                 Console.traceln(Level.INFO, "resultmat:"); 
    231                             for(int i=0; i<x.length; i++){ 
    232                                 for(int j=0; j<x[0].length; j++){ 
    233                                         Console.trace(Level.INFO, String.format("%20s", x[i][j])); 
    234                                 } 
    235                                 Console.traceln(Level.INFO, ""); 
    236                             } 
    237                             */ 
    238                                  
    239                                 // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the  
    240                                 // cluster in which we are 
    241                                 int cnumber; 
    242                                 int found_cnumber = -1; 
    243                                 Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 
    244                                 while ( clusternumber.hasNext() && found_cnumber == -1) { 
    245                                         cnumber = clusternumber.next(); 
    246                                          
    247                                         // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed) 
    248                                         // or we are too far away from any cluster because of the fastmap calculation with the initial pivot objects 
    249                                         for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) {  
    250                                                 Double[][] current = this.csize.get(cnumber).get(box); 
    251                                                  
    252                                                 if(proj[0] >= current[0][0] && proj[0] <= current[0][1] &&  // x  
    253                                                    proj[1] >= current[1][0] && proj[1] <= current[1][1]) {  // y 
    254                                                         found_cnumber = cnumber; 
    255                                                 } 
    256                                         } 
    257                                 } 
    258                                  
    259                                 // we want to count how often we are really inside a cluster 
    260                                 //if ( found_cnumber == -1 ) { 
    261                                 //      CNOTFOUND += 1; 
    262                                 //}else { 
    263                                 //      CFOUND += 1; 
    264                                 //} 
    265  
    266                                 // now it can happen that we do not find a cluster because we deleted it previously (too few instances) 
    267                                 // or we get bigger distance measures from weka so that we are completely outside of our clusters. 
    268                                 // in these cases we just find the nearest cluster to our instance and use it for classification. 
    269                                 // to do that we use the EuclideanDistance again to compare our distance to all other Instances 
    270                                 // then we take the cluster of the closest weka instance 
    271                                 dist = new EuclideanDistance(traindata2); 
    272                                 if( !this.ctraindata.containsKey(found_cnumber) ) {  
    273                                         double min_distance = Double.MAX_VALUE; 
    274                                         clusternumber = ctraindata.keySet().iterator(); 
    275                                         while ( clusternumber.hasNext() ) { 
    276                                                 cnumber = clusternumber.next(); 
    277                                                 for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 
    278                                                         if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) { 
    279                                                                 found_cnumber = cnumber; 
    280                                                                 min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i)); 
    281                                                         } 
    282                                                 } 
    283                                         } 
    284                                 } 
    285                                  
    286                                 // here we have the cluster where an instance has the minimum distance between itself and the 
    287                                 // instance we want to classify 
    288                                 // if we still have not found a cluster we exit because something is really wrong 
    289                                 if( found_cnumber == -1 ) { 
    290                                         Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!")); 
    291                                         throw new RuntimeException("cluster not found with full search"); 
    292                                 } 
    293                                  
    294                                 // classify the passed instance with the cluster we found and its training data 
    295                                 ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 
    296                                  
    297                         }catch( Exception e ) { 
    298                                 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
    299                                 throw new RuntimeException(e); 
    300                         } 
    301                         return ret; 
    302                 } 
    303                  
    304                 @Override 
    305                 public void buildClassifier(Instances traindata) throws Exception { 
    306                          
    307                         //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND)); 
    308                         this.show_biggest = true; 
    309                          
    310                         cclassifier = new HashMap<Integer, Classifier>(); 
    311                         ctraindata = new HashMap<Integer, Instances>(); 
    312                         cpivots = new HashMap<Integer, Instance>(); 
    313                         cpivotindices = new int[2][2]; 
    314                          
    315                         // 1. copy traindata 
    316                         Instances train = new Instances(traindata); 
    317                         Instances train2 = new Instances(traindata);  // this one keeps the class attribute 
    318                          
    319                         // 2. remove class attribute for clustering 
    320                         Remove filter = new Remove(); 
    321                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    322                         filter.setInputFormat(train); 
    323                         train = Filter.useFilter(train, filter); 
    324                          
    325                         // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 
    326                         double biggest = 0; 
    327                         EuclideanDistance dist = new EuclideanDistance(train); 
    328                         double[][] distmat = new double[train.size()][train.size()]; 
    329                         for( int i=0; i < train.size(); i++ ) { 
    330                                 for( int j=0; j < train.size(); j++ ) { 
    331                                         distmat[i][j] = dist.distance(train.get(i), train.get(j)); 
    332                                         if( distmat[i][j] > biggest ) { 
    333                                                 biggest = distmat[i][j]; 
    334                                         } 
    335                                 } 
    336                         } 
    337                         //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
    338                          
    339                         // 4. run fastmap for 2 dimensions on the distance matrix 
    340                         Fastmap FMAP = new Fastmap(2); 
    341                         FMAP.setDistmat(distmat); 
    342                         FMAP.calculate(); 
    343                          
    344                         cpivotindices = FMAP.getPivots(); 
    345                          
    346                         double[][] X = FMAP.getX(); 
    347                         distmat = new double[0][0]; 
    348                         System.gc(); 
    349                          
    350                         // quadtree payload generation 
    351                         ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 
    352                      
    353                         // we need these for the sizes of the quadrants 
    354                         double[] big = {0,0}; 
    355                         double[] small = {Double.MAX_VALUE,Double.MAX_VALUE}; 
    356                          
    357                         // set quadtree payload values and get max and min x and y values for size 
    358                     for( int i=0; i<X.length; i++ ){ 
    359                         if(X[i][0] >= big[0]) { 
    360                                 big[0] = X[i][0]; 
    361                         } 
    362                         if(X[i][1] >= big[1]) { 
    363                                 big[1] = X[i][1]; 
    364                         } 
    365                         if(X[i][0] <= small[0]) { 
    366                                 small[0] = X[i][0]; 
    367                         } 
    368                         if(X[i][1] <= small[1]) { 
    369                                 small[1] = X[i][1]; 
    370                         } 
    371                         QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 
    372                         qtp.add(tmp); 
    373                     } 
    374                      
    375                     //Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 
    376                      
    377                     // 5. generate quadtree 
    378                     QuadTree TREE = new QuadTree(null, qtp); 
    379                     QuadTree.size = train.size(); 
    380                     QuadTree.alpha = Math.sqrt(train.size()); 
    381                     QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
    382                     QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 
    383                      
    384                     //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ "")); 
    385                      
    386                     // set the size and then split the tree recursively at the median value for x, y 
    387                     TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 
    388                      
    389                     // recursive split und grid clustering eher static 
    390                     TREE.recursiveSplit(TREE); 
    391                      
    392                     // generate list of nodes sorted by density (childs only) 
    393                     ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 
    394                      
    395                     // recursive grid clustering (tree pruning), the values are stored in ccluster 
    396                     TREE.gridClustering(l); 
    397                      
    398                     // wir iterieren durch die cluster und sammeln uns die instanzen daraus 
    399                     //ctraindata.clear(); 
    400                     for( int i=0; i < QuadTree.ccluster.size(); i++ ) { 
    401                         ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 
    402                          
    403                         // i is the clusternumber 
    404                         // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
    405                         //if(current.size() > QuadTree.alpha) { 
    406                         if( current.size() > 4 ) { 
    407                                 for( int j=0; j < current.size(); j++ ) { 
    408                                         if( !ctraindata.containsKey(i) ) { 
    409                                                 ctraindata.put(i, new Instances(train2)); 
    410                                                 ctraindata.get(i).delete(); 
    411                                         } 
    412                                         ctraindata.get(i).add(current.get(j).getInst()); 
    413                                 } 
    414                         }else{ 
    415                                 Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances")); 
    416                         } 
    417                     } 
    418                          
    419                         // here we keep things we need later on 
    420                         // QuadTree sizes for later use (matching new instances) 
    421                         this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 
    422                  
    423                         // pivot elements 
    424                         //this.cpivots.clear(); 
    425                         for( int i=0; i < FMAP.PA[0].length; i++ ) { 
    426                                 this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy()); 
    427                         } 
    428                         for( int j=0; j < FMAP.PA[0].length; j++ ) { 
    429                                 this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy()); 
    430                         } 
    431                          
    432                          
    433                         /* debug output 
    434                         int pnumber; 
    435                         Iterator<Integer> pivotnumber = cpivots.keySet().iterator(); 
    436                         while ( pivotnumber.hasNext() ) { 
    437                                 pnumber = pivotnumber.next(); 
    438                                 Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber))); 
    439                         } 
    440                         */ 
    441                          
    442                     // train one classifier per cluster, we get the cluster number from the traindata 
    443                     int cnumber; 
    444                         Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
    445                         //cclassifier.clear(); 
    446                          
    447                         //int traindata_count = 0; 
    448                         while ( clusternumber.hasNext() ) { 
    449                                 cnumber = clusternumber.next(); 
    450                                 cclassifier.put(cnumber,setupClassifier());  // this is the classifier used for the cluster  
    451                                 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
    452                                 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
    453                                 //traindata_count += ctraindata.get(cnumber).size(); 
    454                                 //Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber +"  with "+ ctraindata.get(cnumber).size() +" traindata instances")); 
    455                         } 
    456                          
    457                         // add all traindata 
    458                         //Console.traceln(Level.INFO, String.format("traindata in all clusters: " + traindata_count)); 
    459                 } 
    460         } 
    461          
    462  
    463         /** 
    464          * Payload for the QuadTree. 
    465          * x and y are the calculated Fastmap values. 
    466          * T is a weka instance. 
    467          */ 
    468         public class QuadTreePayload<T> { 
    469  
    470                 public double x; 
    471                 public double y; 
    472                 private T inst; 
    473                  
    474                 public QuadTreePayload(double x, double y, T value) { 
    475                         this.x = x; 
    476                         this.y = y; 
    477                         this.inst = value; 
    478                 } 
    479                  
    480                 public T getInst() { 
    481                         return this.inst; 
    482                 } 
    483         } 
    484          
    485          
    486         /** 
    487          * Fastmap implementation 
    488          *  
    489          * Faloutsos, C., & Lin, K. I. (1995).  
    490          * FastMap: A fast algorithm for indexing, data-mining and visualization of traditional and multimedia datasets  
    491          * (Vol. 24, No. 2, pp. 163-174). ACM. 
    492          */ 
    493         public class Fastmap { 
    494                  
    495                 /*N x k Array, at the end, the i-th row will be the image of the i-th object*/ 
    496                 private double[][] X; 
    497                  
    498                 /*2 x k pivot Array one pair per recursive call*/ 
    499                 private int[][] PA; 
    500                  
    501                 /*Objects we got (distance matrix)*/ 
    502                 private double[][] O; 
    503                  
    504                 /*column of X currently updated (also the dimension)*/ 
    505                 private int col = 0; 
    506                  
    507                 /*number of dimensions we want*/ 
    508                 private int target_dims = 0; 
    509                  
    510                 // if we already have the pivot elements 
    511                 private boolean pivot_set = false; 
    512                  
    513  
    514                 public Fastmap(int k) { 
    515                         this.target_dims = k; 
    516                 } 
    517                  
    518                 /** 
    519                  * Sets the distance matrix 
    520                  * and params that depend on this 
    521                  * @param O 
    522                  */ 
    523                 public void setDistmat(double[][] O) { 
    524                         this.O = O; 
    525                         int N = O.length; 
    526                         this.X = new double[N][this.target_dims]; 
    527                         this.PA = new int[2][this.target_dims]; 
    528                 } 
    529                  
    530                 /** 
    531                  * Set pivot elements, we need that to classify instances 
    532                  * after the calculation is complete (because we then want to reuse 
    533                  * only the pivot elements). 
    534                  *  
    535                  * @param pi 
    536                  */ 
    537                 public void setPivots(int[][] pi) { 
    538                         this.pivot_set = true; 
    539                         this.PA = pi; 
    540                 } 
    541                  
    542                 /** 
    543                  * Return the pivot elements that were chosen during the calculation 
    544                  *  
    545                  * @return 
    546                  */ 
    547                 public int[][] getPivots() { 
    548                         return this.PA; 
    549                 } 
    550                  
    551                 /** 
    552                  * The distance function for euclidean distance 
    553                  *  
    554                  * Acts according to equation 4 of the fastmap paper 
    555                  *   
    556                  * @param x x index of x image (if k==0 x object) 
    557                  * @param y y index of y image (if k==0 y object) 
    558                  * @param kdimensionality 
    559                  * @return distance 
    560                  */ 
    561                 private double dist(int x, int y, int k) { 
    562                          
    563                         // basis is object distance, we get this from our distance matrix 
    564                         double tmp = this.O[x][y] * this.O[x][y];  
    565                          
    566                         // decrease by projections 
    567                         for( int i=0; i < k; i++ ) { 
    568                                 double tmp2 = (this.X[x][i] - this.X[y][i]); 
    569                                 tmp -= tmp2 * tmp2; 
    570                         } 
    571                          
    572                         return Math.abs(tmp); 
    573                 } 
    574  
    575                 /** 
    576                  * Find the object farthest from the given index 
    577                  * This method is a helper Method for findDistandObjects 
    578                  *  
    579                  * @param index of the object  
    580                  * @return index of the farthest object from the given index 
    581                  */ 
    582                 private int findFarthest(int index) { 
    583                         double furthest = Double.MIN_VALUE; 
    584                         int ret = 0; 
    585                          
    586                         for( int i=0; i < O.length; i++ ) { 
    587                                 double dist = this.dist(i, index, this.col); 
    588                                 if( i != index && dist > furthest ) { 
    589                                         furthest = dist; 
    590                                         ret = i; 
    591                                 } 
    592                         } 
    593                         return ret; 
    594                 } 
    595                  
    596                 /** 
    597                  * Finds the pivot objects  
    598                  *  
    599                  * This method is basically algorithm 1 of the fastmap paper. 
    600                  *  
    601                  * @return 2 indexes of the choosen pivot objects 
    602                  */ 
    603                 private int[] findDistantObjects() { 
    604                         // 1. choose object randomly 
    605                         Random r = new Random(); 
    606                         int obj = r.nextInt(this.O.length); 
    607                          
    608                         // 2. find farthest object from randomly chosen object 
    609                         int idx1 = this.findFarthest(obj); 
    610                          
    611                         // 3. find farthest object from previously farthest object 
    612                         int idx2 = this.findFarthest(idx1); 
    613  
    614                         return new int[] {idx1, idx2}; 
    615                 } 
    616          
    617                 /** 
    618                  * Calculates the new k-vector values (projections) 
    619                  *  
    620                  * This is basically algorithm 2 of the fastmap paper. 
    621                  * We just added the possibility to pre-set the pivot elements because 
    622                  * we need to classify single instances after the computation is already done. 
    623                  *  
    624                  * @param dims dimensionality 
    625                  */ 
    626                 public void calculate() { 
    627                          
    628                         for( int k=0; k < this.target_dims; k++ ) { 
    629                                 // 2) choose pivot objects 
    630                                 if ( !this.pivot_set ) { 
    631                                         int[] pivots = this.findDistantObjects(); 
    632                  
    633                                         // 3) record ids of pivot objects  
    634                                         this.PA[0][this.col] = pivots[0]; 
    635                                         this.PA[1][this.col] = pivots[1]; 
    636                                 } 
    637                                  
    638                                 // 4) inter object distances are zero (this.X is initialized with 0 so we just continue) 
    639                                 if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) { 
    640                                         continue; 
    641                                 } 
    642                                  
    643                                 // 5) project the objects on the line between the pivots 
    644                                 double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 
    645                                 for( int i=0; i < this.O.length; i++ ) { 
    646                                          
    647                                         double dix = this.dist(i, this.PA[0][this.col], this.col); 
    648                                         double diy = this.dist(i, this.PA[1][this.col], this.col); 
    649                                          
    650                                         double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 
    651                                          
    652                                         // save the projection 
    653                                         this.X[i][this.col] = tmp; 
    654                                 } 
    655                                  
    656                                 this.col += 1; 
    657                         } 
    658                 } 
    659                  
    660                 /** 
    661                  * returns the result matrix of the projections 
    662                  *  
    663                  * @return calculated result 
    664                  */ 
    665                 public double[][] getX() { 
    666                         return this.X; 
    667                 } 
    668         } 
     58 
     59    private final TraindatasetCluster classifier = new TraindatasetCluster(); 
     60 
     61    @Override 
     62    public Classifier getClassifier() { 
     63        return classifier; 
     64    } 
     65 
     66    @Override 
     67    public void apply(Instances traindata) { 
     68        PrintStream errStr = System.err; 
     69        System.setErr(new PrintStream(new NullOutputStream())); 
     70        try { 
     71            classifier.buildClassifier(traindata); 
     72        } 
     73        catch (Exception e) { 
     74            throw new RuntimeException(e); 
     75        } 
     76        finally { 
     77            System.setErr(errStr); 
     78        } 
     79    } 
     80 
     81    public class TraindatasetCluster extends AbstractClassifier { 
     82 
     83        private static final long serialVersionUID = 1L; 
     84 
     85        /* classifier per cluster */ 
     86        private HashMap<Integer, Classifier> cclassifier; 
     87 
     88        /* instances per cluster */ 
     89        private HashMap<Integer, Instances> ctraindata; 
     90 
     91        /* 
     92         * holds the instances and indices of the pivot objects of the Fastmap calculation in 
     93         * buildClassifier 
     94         */ 
     95        private HashMap<Integer, Instance> cpivots; 
     96 
     97        /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */ 
     98        private int[][] cpivotindices; 
     99 
     100        /* holds the sizes of the cluster multiple "boxes" per cluster */ 
     101        private HashMap<Integer, ArrayList<Double[][]>> csize; 
     102 
     103        /* debug vars */ 
     104        @SuppressWarnings("unused") 
     105        private boolean show_biggest = true; 
     106 
     107        @SuppressWarnings("unused") 
     108        private int CFOUND = 0; 
     109        @SuppressWarnings("unused") 
     110        private int CNOTFOUND = 0; 
     111 
     112        private Instance createInstance(Instances instances, Instance instance) { 
     113            // attributes for feeding instance to classifier 
     114            Set<String> attributeNames = new HashSet<>(); 
     115            for (int j = 0; j < instances.numAttributes(); j++) { 
     116                attributeNames.add(instances.attribute(j).name()); 
     117            } 
     118 
     119            double[] values = new double[instances.numAttributes()]; 
     120            int index = 0; 
     121            for (int j = 0; j < instance.numAttributes(); j++) { 
     122                if (attributeNames.contains(instance.attribute(j).name())) { 
     123                    values[index] = instance.value(j); 
     124                    index++; 
     125                } 
     126            } 
     127 
     128            Instances tmp = new Instances(instances); 
     129            tmp.clear(); 
     130            Instance instCopy = new DenseInstance(instance.weight(), values); 
     131            instCopy.setDataset(tmp); 
     132 
     133            return instCopy; 
     134        } 
     135 
     136        /** 
     137         * Because Fastmap saves only the image not the values of the attributes it used we can not 
     138         * use the old data directly to classify single instances to clusters. 
     139         *  
     140         * To classify a single instance we do a new fastmap computation with only the instance and 
     141         * the old pivot elements. 
     142         *  
     143         * After that we find the cluster with our fastmap result for x and y. 
     144         */ 
     145        @Override 
     146        public double classifyInstance(Instance instance) { 
     147 
     148            double ret = 0; 
     149            try { 
     150                // classinstance gets passed to classifier 
     151                Instances traindata = ctraindata.get(0); 
     152                Instance classInstance = createInstance(traindata, instance); 
     153 
     154                // this one keeps the class attribute 
     155                Instances traindata2 = ctraindata.get(1); 
     156 
     157                // remove class attribute before clustering 
     158                Remove filter = new Remove(); 
     159                filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
     160                filter.setInputFormat(traindata); 
     161                traindata = Filter.useFilter(traindata, filter); 
     162                Instance clusterInstance = createInstance(traindata, instance); 
     163 
     164                Fastmap FMAP = new Fastmap(2); 
     165                EuclideanDistance dist = new EuclideanDistance(traindata); 
     166 
     167                // we set our pivot indices [x=0,y=1][dimension] 
     168                int[][] npivotindices = new int[2][2]; 
     169                npivotindices[0][0] = 1; 
     170                npivotindices[1][0] = 2; 
     171                npivotindices[0][1] = 3; 
     172                npivotindices[1][1] = 4; 
     173 
     174                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 
     175                // the instance we want to classify comes first after that the pivot elements in the 
     176                // order defined above 
     177                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1]; 
     178                distmat[0][0] = 0; 
     179                distmat[0][1] = 
     180                    dist.distance(clusterInstance, 
     181                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     182                distmat[0][2] = 
     183                    dist.distance(clusterInstance, 
     184                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     185                distmat[0][3] = 
     186                    dist.distance(clusterInstance, 
     187                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     188                distmat[0][4] = 
     189                    dist.distance(clusterInstance, 
     190                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     191 
     192                distmat[1][0] = 
     193                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     194                                  clusterInstance); 
     195                distmat[1][1] = 0; 
     196                distmat[1][2] = 
     197                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     198                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     199                distmat[1][3] = 
     200                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     201                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     202                distmat[1][4] = 
     203                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     204                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     205 
     206                distmat[2][0] = 
     207                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     208                                  clusterInstance); 
     209                distmat[2][1] = 
     210                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     211                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     212                distmat[2][2] = 0; 
     213                distmat[2][3] = 
     214                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     215                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     216                distmat[2][4] = 
     217                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     218                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     219 
     220                distmat[3][0] = 
     221                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     222                                  clusterInstance); 
     223                distmat[3][1] = 
     224                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     225                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     226                distmat[3][2] = 
     227                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     228                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     229                distmat[3][3] = 0; 
     230                distmat[3][4] = 
     231                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     232                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     233 
     234                distmat[4][0] = 
     235                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     236                                  clusterInstance); 
     237                distmat[4][1] = 
     238                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     239                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     240                distmat[4][2] = 
     241                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     242                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     243                distmat[4][3] = 
     244                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     245                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     246                distmat[4][4] = 0; 
     247 
     248                /* 
     249                 * debug output: show biggest distance found within the new distance matrix double 
     250                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j < 
     251                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j]; 
     252                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO, 
     253                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO, 
     254                 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; } 
     255                 */ 
     256 
     257                FMAP.setDistmat(distmat); 
     258                FMAP.setPivots(npivotindices); 
     259                FMAP.calculate(); 
     260                double[][] x = FMAP.getX(); 
     261                double[] proj = x[0]; 
     262 
     263                // debug output: show the calculated distance matrix, our result vektor for the 
     264                // instance and the complete result matrix 
     265                /* 
     266                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){ 
     267                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO, 
     268                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); } 
     269                 *  
     270                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) { 
     271                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); } 
     272                 * Console.traceln(Level.INFO, ""); 
     273                 *  
     274                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int 
     275                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s", 
     276                 * x[i][j])); } Console.traceln(Level.INFO, ""); } 
     277                 */ 
     278 
     279                // now we iterate over all clusters (well, boxes of sizes per cluster really) and 
     280                // save the number of the 
     281                // cluster in which we are 
     282                int cnumber; 
     283                int found_cnumber = -1; 
     284                Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 
     285                while (clusternumber.hasNext() && found_cnumber == -1) { 
     286                    cnumber = clusternumber.next(); 
     287 
     288                    // now iterate over the boxes of the cluster and hope we find one (cluster could 
     289                    // have been removed) 
     290                    // or we are too far away from any cluster because of the fastmap calculation 
     291                    // with the initial pivot objects 
     292                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) { 
     293                        Double[][] current = this.csize.get(cnumber).get(box); 
     294 
     295                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x 
     296                            proj[1] >= current[1][0] && proj[1] <= current[1][1]) 
     297                        { // y 
     298                            found_cnumber = cnumber; 
     299                        } 
     300                    } 
     301                } 
     302 
     303                // we want to count how often we are really inside a cluster 
     304                // if ( found_cnumber == -1 ) { 
     305                // CNOTFOUND += 1; 
     306                // }else { 
     307                // CFOUND += 1; 
     308                // } 
     309 
     310                // now it can happen that we do not find a cluster because we deleted it previously 
     311                // (too few instances) 
     312                // or we get bigger distance measures from weka so that we are completely outside of 
     313                // our clusters. 
     314                // in these cases we just find the nearest cluster to our instance and use it for 
     315                // classification. 
     316                // to do that we use the EuclideanDistance again to compare our distance to all 
     317                // other Instances 
     318                // then we take the cluster of the closest weka instance 
     319                dist = new EuclideanDistance(traindata2); 
     320                if (!this.ctraindata.containsKey(found_cnumber)) { 
     321                    double min_distance = Double.MAX_VALUE; 
     322                    clusternumber = ctraindata.keySet().iterator(); 
     323                    while (clusternumber.hasNext()) { 
     324                        cnumber = clusternumber.next(); 
     325                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) { 
     326                            if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) 
     327                            { 
     328                                found_cnumber = cnumber; 
     329                                min_distance = 
     330                                    dist.distance(instance, ctraindata.get(cnumber).get(i)); 
     331                            } 
     332                        } 
     333                    } 
     334                } 
     335 
     336                // here we have the cluster where an instance has the minimum distance between 
     337                // itself and the 
     338                // instance we want to classify 
     339                // if we still have not found a cluster we exit because something is really wrong 
     340                if (found_cnumber == -1) { 
     341                    Console.traceln(Level.INFO, String 
     342                        .format("ERROR matching instance to cluster with full search!")); 
     343                    throw new RuntimeException("cluster not found with full search"); 
     344                } 
     345 
     346                // classify the passed instance with the cluster we found and its training data 
     347                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 
     348 
     349            } 
     350            catch (Exception e) { 
     351                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
     352                throw new RuntimeException(e); 
     353            } 
     354            return ret; 
     355        } 
     356 
     357        @Override 
     358        public void buildClassifier(Instances traindata) throws Exception { 
     359 
     360            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + 
     361            // CNOTFOUND)); 
     362            this.show_biggest = true; 
     363 
     364            cclassifier = new HashMap<Integer, Classifier>(); 
     365            ctraindata = new HashMap<Integer, Instances>(); 
     366            cpivots = new HashMap<Integer, Instance>(); 
     367            cpivotindices = new int[2][2]; 
     368 
     369            // 1. copy traindata 
     370            Instances train = new Instances(traindata); 
     371            Instances train2 = new Instances(traindata); // this one keeps the class attribute 
     372 
     373            // 2. remove class attribute for clustering 
     374            Remove filter = new Remove(); 
     375            filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     376            filter.setInputFormat(train); 
     377            train = Filter.useFilter(train, filter); 
     378 
     379            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 
     380            double biggest = 0; 
     381            EuclideanDistance dist = new EuclideanDistance(train); 
     382            double[][] distmat = new double[train.size()][train.size()]; 
     383            for (int i = 0; i < train.size(); i++) { 
     384                for (int j = 0; j < train.size(); j++) { 
     385                    distmat[i][j] = dist.distance(train.get(i), train.get(j)); 
     386                    if (distmat[i][j] > biggest) { 
     387                        biggest = distmat[i][j]; 
     388                    } 
     389                } 
     390            } 
     391            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
     392 
     393            // 4. run fastmap for 2 dimensions on the distance matrix 
     394            Fastmap FMAP = new Fastmap(2); 
     395            FMAP.setDistmat(distmat); 
     396            FMAP.calculate(); 
     397 
     398            cpivotindices = FMAP.getPivots(); 
     399 
     400            double[][] X = FMAP.getX(); 
     401            distmat = new double[0][0]; 
     402            System.gc(); 
     403 
     404            // quadtree payload generation 
     405            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 
     406 
     407            // we need these for the sizes of the quadrants 
     408            double[] big = 
     409                { 0, 0 }; 
     410            double[] small = 
     411                { Double.MAX_VALUE, Double.MAX_VALUE }; 
     412 
     413            // set quadtree payload values and get max and min x and y values for size 
     414            for (int i = 0; i < X.length; i++) { 
     415                if (X[i][0] >= big[0]) { 
     416                    big[0] = X[i][0]; 
     417                } 
     418                if (X[i][1] >= big[1]) { 
     419                    big[1] = X[i][1]; 
     420                } 
     421                if (X[i][0] <= small[0]) { 
     422                    small[0] = X[i][0]; 
     423                } 
     424                if (X[i][1] <= small[1]) { 
     425                    small[1] = X[i][1]; 
     426                } 
     427                QuadTreePayload<Instance> tmp = 
     428                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 
     429                qtp.add(tmp); 
     430            } 
     431 
     432            // Console.traceln(Level.INFO, 
     433            // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 
     434 
     435            // 5. generate quadtree 
     436            QuadTree TREE = new QuadTree(null, qtp); 
     437            QuadTree.size = train.size(); 
     438            QuadTree.alpha = Math.sqrt(train.size()); 
     439            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
     440            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 
     441 
     442            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + 
     443            // " size, Alpha: "+ QuadTree.alpha+ "")); 
     444 
     445            // set the size and then split the tree recursively at the median value for x, y 
     446            TREE.setSize(new double[] 
     447                { small[0], big[0] }, new double[] 
     448                { small[1], big[1] }); 
     449 
     450            // recursive split und grid clustering eher static 
     451            TREE.recursiveSplit(TREE); 
     452 
     453            // generate list of nodes sorted by density (childs only) 
     454            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 
     455 
     456            // recursive grid clustering (tree pruning), the values are stored in ccluster 
     457            TREE.gridClustering(l); 
     458 
     459            // wir iterieren durch die cluster und sammeln uns die instanzen daraus 
     460            // ctraindata.clear(); 
     461            for (int i = 0; i < QuadTree.ccluster.size(); i++) { 
     462                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 
     463 
     464                // i is the clusternumber 
     465                // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
     466                // if(current.size() > QuadTree.alpha) { 
     467                if (current.size() > 4) { 
     468                    for (int j = 0; j < current.size(); j++) { 
     469                        if (!ctraindata.containsKey(i)) { 
     470                            ctraindata.put(i, new Instances(train2)); 
     471                            ctraindata.get(i).delete(); 
     472                        } 
     473                        ctraindata.get(i).add(current.get(j).getInst()); 
     474                    } 
     475                } 
     476                else { 
     477                    Console.traceln(Level.INFO, 
     478                                    String.format("drop cluster, only: " + current.size() + 
     479                                        " instances")); 
     480                } 
     481            } 
     482 
     483            // here we keep things we need later on 
     484            // QuadTree sizes for later use (matching new instances) 
     485            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 
     486 
     487            // pivot elements 
     488            // this.cpivots.clear(); 
     489            for (int i = 0; i < FMAP.PA[0].length; i++) { 
     490                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy()); 
     491            } 
     492            for (int j = 0; j < FMAP.PA[0].length; j++) { 
     493                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy()); 
     494            } 
     495 
     496            /* 
     497             * debug output int pnumber; Iterator<Integer> pivotnumber = 
     498             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber = 
     499             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ 
     500             * " inst: "+cpivots.get(pnumber))); } 
     501             */ 
     502 
     503            // train one classifier per cluster, we get the cluster number from the traindata 
     504            int cnumber; 
     505            Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
     506            // cclassifier.clear(); 
     507 
     508            // int traindata_count = 0; 
     509            while (clusternumber.hasNext()) { 
     510                cnumber = clusternumber.next(); 
     511                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the 
     512                                                             // cluster 
     513                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
     514                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
     515                // traindata_count += ctraindata.get(cnumber).size(); 
     516                // Console.traceln(Level.INFO, 
     517                // String.format("building classifier in cluster "+cnumber +"  with "+ 
     518                // ctraindata.get(cnumber).size() +" traindata instances")); 
     519            } 
     520 
     521            // add all traindata 
     522            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " + 
     523            // traindata_count)); 
     524        } 
     525    } 
     526 
     527    /** 
     528     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance. 
     529     */ 
     530    public class QuadTreePayload<T> { 
     531 
     532        public double x; 
     533        public double y; 
     534        private T inst; 
     535 
     536        public QuadTreePayload(double x, double y, T value) { 
     537            this.x = x; 
     538            this.y = y; 
     539            this.inst = value; 
     540        } 
     541 
     542        public T getInst() { 
     543            return this.inst; 
     544        } 
     545    } 
     546 
     547    /** 
     548     * Fastmap implementation 
     549     *  
     550     * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and 
     551     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM. 
     552     */ 
     553    public class Fastmap { 
     554 
     555        /* N x k Array, at the end, the i-th row will be the image of the i-th object */ 
     556        private double[][] X; 
     557 
     558        /* 2 x k pivot Array one pair per recursive call */ 
     559        private int[][] PA; 
     560 
     561        /* Objects we got (distance matrix) */ 
     562        private double[][] O; 
     563 
     564        /* column of X currently updated (also the dimension) */ 
     565        private int col = 0; 
     566 
     567        /* number of dimensions we want */ 
     568        private int target_dims = 0; 
     569 
     570        // if we already have the pivot elements 
     571        private boolean pivot_set = false; 
     572 
     573        public Fastmap(int k) { 
     574            this.target_dims = k; 
     575        } 
     576 
     577        /** 
     578         * Sets the distance matrix and params that depend on this 
     579         *  
     580         * @param O 
     581         */ 
     582        public void setDistmat(double[][] O) { 
     583            this.O = O; 
     584            int N = O.length; 
     585            this.X = new double[N][this.target_dims]; 
     586            this.PA = new int[2][this.target_dims]; 
     587        } 
     588 
     589        /** 
     590         * Set pivot elements, we need that to classify instances after the calculation is complete 
     591         * (because we then want to reuse only the pivot elements). 
     592         *  
     593         * @param pi 
     594         */ 
     595        public void setPivots(int[][] pi) { 
     596            this.pivot_set = true; 
     597            this.PA = pi; 
     598        } 
     599 
     600        /** 
     601         * Return the pivot elements that were chosen during the calculation 
     602         *  
     603         * @return 
     604         */ 
     605        public int[][] getPivots() { 
     606            return this.PA; 
     607        } 
     608 
     609        /** 
     610         * The distance function for euclidean distance 
     611         *  
     612         * Acts according to equation 4 of the fastmap paper 
     613         *  
     614         * @param x 
     615         *            x index of x image (if k==0 x object) 
     616         * @param y 
     617         *            y index of y image (if k==0 y object) 
     618         * @param kdimensionality 
     619         * @return distance 
     620         */ 
     621        private double dist(int x, int y, int k) { 
     622 
     623            // basis is object distance, we get this from our distance matrix 
     624            double tmp = this.O[x][y] * this.O[x][y]; 
     625 
     626            // decrease by projections 
     627            for (int i = 0; i < k; i++) { 
     628                double tmp2 = (this.X[x][i] - this.X[y][i]); 
     629                tmp -= tmp2 * tmp2; 
     630            } 
     631 
     632            return Math.abs(tmp); 
     633        } 
     634 
     635        /** 
     636         * Find the object farthest from the given index This method is a helper Method for 
     637         * findDistandObjects 
     638         *  
     639         * @param index 
     640         *            of the object 
     641         * @return index of the farthest object from the given index 
     642         */ 
     643        private int findFarthest(int index) { 
     644            double furthest = Double.MIN_VALUE; 
     645            int ret = 0; 
     646 
     647            for (int i = 0; i < O.length; i++) { 
     648                double dist = this.dist(i, index, this.col); 
     649                if (i != index && dist > furthest) { 
     650                    furthest = dist; 
     651                    ret = i; 
     652                } 
     653            } 
     654            return ret; 
     655        } 
     656 
     657        /** 
     658         * Finds the pivot objects 
     659         *  
     660         * This method is basically algorithm 1 of the fastmap paper. 
     661         *  
     662         * @return 2 indexes of the choosen pivot objects 
     663         */ 
     664        private int[] findDistantObjects() { 
     665            // 1. choose object randomly 
     666            Random r = new Random(); 
     667            int obj = r.nextInt(this.O.length); 
     668 
     669            // 2. find farthest object from randomly chosen object 
     670            int idx1 = this.findFarthest(obj); 
     671 
     672            // 3. find farthest object from previously farthest object 
     673            int idx2 = this.findFarthest(idx1); 
     674 
     675            return new int[] 
     676                { idx1, idx2 }; 
     677        } 
     678 
     679        /** 
     680         * Calculates the new k-vector values (projections) 
     681         *  
     682         * This is basically algorithm 2 of the fastmap paper. We just added the possibility to 
     683         * pre-set the pivot elements because we need to classify single instances after the 
     684         * computation is already done. 
     685         *  
     686         * @param dims 
     687         *            dimensionality 
     688         */ 
     689        public void calculate() { 
     690 
     691            for (int k = 0; k < this.target_dims; k++) { 
     692                // 2) choose pivot objects 
     693                if (!this.pivot_set) { 
     694                    int[] pivots = this.findDistantObjects(); 
     695 
     696                    // 3) record ids of pivot objects 
     697                    this.PA[0][this.col] = pivots[0]; 
     698                    this.PA[1][this.col] = pivots[1]; 
     699                } 
     700 
     701                // 4) inter object distances are zero (this.X is initialized with 0 so we just 
     702                // continue) 
     703                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) { 
     704                    continue; 
     705                } 
     706 
     707                // 5) project the objects on the line between the pivots 
     708                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 
     709                for (int i = 0; i < this.O.length; i++) { 
     710 
     711                    double dix = this.dist(i, this.PA[0][this.col], this.col); 
     712                    double diy = this.dist(i, this.PA[1][this.col], this.col); 
     713 
     714                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 
     715 
     716                    // save the projection 
     717                    this.X[i][this.col] = tmp; 
     718                } 
     719 
     720                this.col += 1; 
     721            } 
     722        } 
     723 
     724        /** 
     725         * returns the result matrix of the projections 
     726         *  
     727         * @return calculated result 
     728         */ 
     729        public double[][] getX() { 
     730            return this.X; 
     731        } 
     732    } 
    669733} 
Note: See TracChangeset for help on using the changeset viewer.