Ignore:
Timestamp:
09/24/15 10:59:05 (9 years ago)
Author:
sherbold
Message:
  • formatted code and added copyrights
Location:
trunk/CrossPare/src/de/ugoe/cs/cpdp/training
Files:
12 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/FixClass.java

    r31 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1428 * @author Steffen Herbold 
    1529 */ 
    16 public class FixClass extends AbstractClassifier implements ITrainingStrategy, IWekaCompatibleTrainer { 
     30public class FixClass extends AbstractClassifier implements ITrainingStrategy, 
     31    IWekaCompatibleTrainer 
     32{ 
    1733 
    18         private static final long serialVersionUID = 1L; 
     34    private static final long serialVersionUID = 1L; 
    1935 
    20         private double fixedClassValue = 0.0d; 
     36    private double fixedClassValue = 0.0d; 
    2137 
    22         /** 
    23         * Returns default capabilities of the classifier. 
    24         *  
    25         * @return the capabilities of this classifier 
    26         */ 
    27         @Override 
    28         public Capabilities getCapabilities() { 
    29                 Capabilities result = super.getCapabilities(); 
    30                 result.disableAll(); 
     38    /** 
     39    * Returns default capabilities of the classifier. 
     40    *  
     41    * @return the capabilities of this classifier 
     42    */ 
     43    @Override 
     44    public Capabilities getCapabilities() { 
     45        Capabilities result = super.getCapabilities(); 
     46        result.disableAll(); 
    3147 
    32                 // attributes 
    33                 result.enable(Capability.NOMINAL_ATTRIBUTES); 
    34                 result.enable(Capability.NUMERIC_ATTRIBUTES); 
    35                 result.enable(Capability.DATE_ATTRIBUTES); 
    36                 result.enable(Capability.STRING_ATTRIBUTES); 
    37                 result.enable(Capability.RELATIONAL_ATTRIBUTES); 
    38                 result.enable(Capability.MISSING_VALUES); 
     48        // attributes 
     49        result.enable(Capability.NOMINAL_ATTRIBUTES); 
     50        result.enable(Capability.NUMERIC_ATTRIBUTES); 
     51        result.enable(Capability.DATE_ATTRIBUTES); 
     52        result.enable(Capability.STRING_ATTRIBUTES); 
     53        result.enable(Capability.RELATIONAL_ATTRIBUTES); 
     54        result.enable(Capability.MISSING_VALUES); 
    3955 
    40                 // class 
    41                 result.enable(Capability.NOMINAL_CLASS); 
    42                 result.enable(Capability.NUMERIC_CLASS); 
    43                 result.enable(Capability.MISSING_CLASS_VALUES); 
     56        // class 
     57        result.enable(Capability.NOMINAL_CLASS); 
     58        result.enable(Capability.NUMERIC_CLASS); 
     59        result.enable(Capability.MISSING_CLASS_VALUES); 
    4460 
    45                 // instances 
    46                 result.setMinimumNumberInstances(0); 
     61        // instances 
     62        result.setMinimumNumberInstances(0); 
    4763 
    48                 return result; 
    49         } 
     64        return result; 
     65    } 
    5066 
    51         @Override 
    52         public void setOptions(String[] options) throws Exception { 
    53                 fixedClassValue = Double.parseDouble(Utils.getOption('C', options)); 
    54         } 
     67    @Override 
     68    public void setOptions(String[] options) throws Exception { 
     69        fixedClassValue = Double.parseDouble(Utils.getOption('C', options)); 
     70    } 
    5571 
    56         @Override 
    57         public double classifyInstance(Instance instance) { 
    58                 return fixedClassValue; 
    59         } 
     72    @Override 
     73    public double classifyInstance(Instance instance) { 
     74        return fixedClassValue; 
     75    } 
    6076 
    61         @Override 
    62         public void buildClassifier(Instances traindata) throws Exception { 
    63                 // do nothing 
    64         } 
     77    @Override 
     78    public void buildClassifier(Instances traindata) throws Exception { 
     79        // do nothing 
     80    } 
    6581 
    66         @Override 
    67         public void setParameter(String parameters) { 
    68                 try { 
    69                         this.setOptions(parameters.split(" ")); 
    70                 } catch (Exception e) { 
    71                         e.printStackTrace(); 
    72                 }                
    73         } 
     82    @Override 
     83    public void setParameter(String parameters) { 
     84        try { 
     85            this.setOptions(parameters.split(" ")); 
     86        } 
     87        catch (Exception e) { 
     88            e.printStackTrace(); 
     89        } 
     90    } 
    7491 
    75         @Override 
    76         public void apply(Instances traindata) { 
    77                 // do nothing! 
    78         } 
     92    @Override 
     93    public void apply(Instances traindata) { 
     94        // do nothing! 
     95    } 
    7996 
    80         @Override 
    81         public String getName() { 
    82                 return "FixClass"; 
    83         } 
     97    @Override 
     98    public String getName() { 
     99        return "FixClass"; 
     100    } 
    84101 
    85         @Override 
    86         public Classifier getClassifier() { 
    87                 return this; 
    88         } 
     102    @Override 
     103    public Classifier getClassifier() { 
     104        return this; 
     105    } 
    89106 
    90107} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ISetWiseTrainingStrategy.java

    r2 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    721// Bagging Strategy: separate models for each training data set 
    822public interface ISetWiseTrainingStrategy extends ITrainer { 
    9          
    10         void apply(SetUniqueList<Instances> traindataSet); 
    11          
    12         String getName(); 
     23 
     24    void apply(SetUniqueList<Instances> traindataSet); 
     25 
     26    String getName(); 
    1327} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ITrainer.java

    r2 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/ITrainingStrategy.java

    r6 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    418 
    519public interface ITrainingStrategy extends ITrainer { 
    6          
    7         void apply(Instances traindata); 
    8          
    9         String getName(); 
     20 
     21    void apply(Instances traindata); 
     22 
     23    String getName(); 
    1024} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/IWekaCompatibleTrainer.java

    r24 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    418 
    519public interface IWekaCompatibleTrainer extends ITrainer { 
    6          
    7         Classifier getClassifier(); 
    8          
    9         String getName(); 
     20 
     21    Classifier getClassifier(); 
     22 
     23    String getName(); 
    1024} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/QuadTree.java

    r23 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1226 * QuadTree implementation 
    1327 *  
    14  * QuadTree gets a list of instances and then recursively split them into 4 childs 
    15  * For this it uses the median of the 2 values x,y 
     28 * QuadTree gets a list of instances and then recursively split them into 4 childs For this it uses 
     29 * the median of the 2 values x,y 
    1630 */ 
    1731public class QuadTree { 
    18          
    19         /* 1 parent or null */ 
    20         private QuadTree parent = null; 
    21          
    22         /* 4 childs, 1 per quadrant */ 
    23         private QuadTree child_nw; 
    24         private QuadTree child_ne; 
    25         private QuadTree child_se; 
    26         private QuadTree child_sw; 
    27          
    28         /* list (only helps with generation of list of childs!) */ 
    29         private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 
    30          
    31         /* level only used for debugging */ 
    32         public int level = 0; 
    33          
    34         /* size of the quadrant */ 
    35         private double[] x; 
    36         private double[] y; 
    37          
    38         public static boolean verbose = false; 
    39         public static int size = 0; 
    40         public static double alpha = 0; 
    41          
    42         /* cluster payloads */ 
    43         public static ArrayList<ArrayList<QuadTreePayload<Instance>>> ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
    44          
    45         /* cluster sizes (index is cluster number, arraylist is list of boxes (x0,y0,x1,y1) */  
    46         public static HashMap<Integer, ArrayList<Double[][]>> csize = new HashMap<Integer, ArrayList<Double[][]>>(); 
    47          
    48         /* payload of this instance */ 
    49         private ArrayList<QuadTreePayload<Instance>> payload; 
    50  
    51          
    52         public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { 
    53                 this.parent = parent; 
    54                 this.payload = payload; 
    55         } 
    56          
    57          
    58         public String toString() { 
    59                 String n = ""; 
    60                 if(this.parent == null) { 
    61                         n += "rootnode "; 
    62                 } 
    63                 String level = new String(new char[this.level]).replace("\0", "-"); 
    64                 n += level + " instances: " + this.getNumbers(); 
    65                 return n; 
    66         } 
    67          
    68         /** 
    69          * Returns the payload, used for clustering 
    70          * in the clustering list we only have children with paylod 
    71          *  
    72          * @return payload 
    73          */ 
    74         public ArrayList<QuadTreePayload<Instance>> getPayload() { 
    75                 return this.payload; 
    76         } 
    77          
    78         /** 
    79          * Calculate the density of this quadrant 
    80          *  
    81          * density = number of instances / global size (all instances) 
    82          *  
    83          * @return density 
    84          */ 
    85         public double getDensity() { 
    86                 double dens = 0; 
    87                 dens = (double)this.getNumbers() / QuadTree.size; 
    88                 return dens; 
    89         } 
    90          
    91         public void setSize(double[] x, double[] y){ 
    92                 this.x = x; 
    93                 this.y = y; 
    94         } 
    95          
    96         public double[][] getSize() { 
    97                 return new double[][] {this.x, this.y};  
    98         } 
    99          
    100         public Double[][] getSizeDouble() { 
    101                 Double[] tmpX = new Double[2]; 
    102                 Double[] tmpY = new Double[2]; 
    103                  
    104                 tmpX[0] = this.x[0]; 
    105                 tmpX[1] = this.x[1]; 
    106                  
    107                 tmpY[0] = this.y[0]; 
    108                 tmpY[1] = this.y[1]; 
    109                  
    110                 return new Double[][] {tmpX, tmpY};  
    111         } 
    112          
    113         /** 
    114          * TODO: DRY, median ist immer dasselbe 
    115          *   
    116          * @return median for x 
    117          */ 
    118         private double getMedianForX() { 
    119                 double med_x =0 ; 
    120                  
    121                 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 
    122                 @Override 
    123                 public int compare(QuadTreePayload<Instance> x1, QuadTreePayload<Instance> x2) { 
    124                     return Double.compare(x1.x, x2.x); 
    125                 } 
    126             }); 
    127  
    128                 if(this.payload.size() % 2 == 0) { 
    129                         int mid = this.payload.size() / 2; 
    130                         med_x = (this.payload.get(mid).x + this.payload.get(mid+1).x) / 2; 
    131                 }else { 
    132                         int mid = this.payload.size() / 2; 
    133                         med_x = this.payload.get(mid).x; 
    134                 } 
    135                  
    136                 if(QuadTree.verbose) { 
    137                         System.out.println("sorted:"); 
    138                         for(int i = 0; i < this.payload.size(); i++) { 
    139                                 System.out.print(""+this.payload.get(i).x+","); 
    140                         } 
    141                         System.out.println("median x: " + med_x); 
    142                 } 
    143                 return med_x; 
    144         } 
    145          
    146         private double getMedianForY() { 
    147                 double med_y =0 ; 
    148                  
    149                 Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 
    150                 @Override 
    151                 public int compare(QuadTreePayload<Instance> y1, QuadTreePayload<Instance> y2) { 
    152                     return Double.compare(y1.y, y2.y); 
    153                 } 
    154             }); 
    155                  
    156                 if(this.payload.size() % 2 == 0) { 
    157                         int mid = this.payload.size() / 2; 
    158                         med_y = (this.payload.get(mid).y + this.payload.get(mid+1).y) / 2; 
    159                 }else { 
    160                         int mid = this.payload.size() / 2; 
    161                         med_y = this.payload.get(mid).y; 
    162                 } 
    163                  
    164                 if(QuadTree.verbose) { 
    165                         System.out.println("sorted:"); 
    166                         for(int i = 0; i < this.payload.size(); i++) { 
    167                                 System.out.print(""+this.payload.get(i).y+","); 
    168                         } 
    169                         System.out.println("median y: " + med_y); 
    170                 } 
    171                 return med_y; 
    172         } 
    173          
    174         /** 
    175          * Reurns the number of instances in the payload 
    176          *  
    177          * @return int number of instances 
    178          */ 
    179         public int getNumbers() { 
    180                 int number = 0; 
    181                 if(this.payload != null) { 
    182                         number = this.payload.size(); 
    183                 } 
    184                 return number; 
    185         } 
    186          
    187         /** 
    188          * Calculate median values of payload for x, y and split into 4 sectors 
    189          *  
    190          * @return Array of QuadTree nodes (4 childs) 
    191          * @throws Exception if we would run into an recursive loop 
    192          */ 
    193         public QuadTree[] split() throws Exception { 
    194                                  
    195                 double medx = this.getMedianForX(); 
    196                 double medy = this.getMedianForY(); 
    197                  
    198                 // Payload lists for each child 
    199                 ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 
    200                 ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); 
    201                 ArrayList<QuadTreePayload<Instance>> ne = new ArrayList<QuadTreePayload<Instance>>(); 
    202                 ArrayList<QuadTreePayload<Instance>> se = new ArrayList<QuadTreePayload<Instance>>(); 
    203                  
    204                 // sort the payloads to new payloads 
    205                 // here we have the problem that payloads with the same values are sorted 
    206                 // into the same slots and it could happen that medx and medy = size_x[1] and size_y[1] 
    207                 // in that case we would have an endless loop 
    208                 for(int i=0; i < this.payload.size(); i++) { 
    209                          
    210                         QuadTreePayload<Instance> item = this.payload.get(i); 
    211                          
    212                         // north west 
    213                         if(item.x <= medx && item.y >= medy) { 
    214                                 nw.add(item); 
    215                         } 
    216                          
    217                         // south west 
    218                         else if(item.x <= medx && item.y <= medy) { 
    219                                 sw.add(item); 
    220                         } 
    221  
    222                         // north east 
    223                         else if(item.x >= medx && item.y >= medy) { 
    224                                 ne.add(item); 
    225                         } 
    226                          
    227                         // south east 
    228                         else if(item.x >= medx && item.y <= medy) { 
    229                                 se.add(item); 
    230                         } 
    231                 } 
    232                  
    233                 // if we assign one child a payload equal to our own (see problem above) 
    234                 // we throw an exceptions which stops the recursion on this node 
    235                 if(nw.equals(this.payload)) { 
    236                         throw new Exception("payload equal"); 
    237                 } 
    238                 if(sw.equals(this.payload)) { 
    239                         throw new Exception("payload equal"); 
    240                 } 
    241                 if(ne.equals(this.payload)) { 
    242                         throw new Exception("payload equal"); 
    243                 } 
    244                 if(se.equals(this.payload)) { 
    245                         throw new Exception("payload equal"); 
    246                 } 
    247  
    248                 this.child_nw = new QuadTree(this, nw); 
    249                 this.child_nw.setSize(new double[] {this.x[0], medx}, new double[] {medy, this.y[1]}); 
    250                 this.child_nw.level = this.level + 1; 
    251                  
    252                 this.child_sw = new QuadTree(this, sw); 
    253                 this.child_sw.setSize(new double[] {this.x[0], medx}, new double[] {this.y[0], medy}); 
    254                 this.child_sw.level = this.level + 1; 
    255                  
    256                 this.child_ne = new QuadTree(this, ne); 
    257                 this.child_ne.setSize(new double[] {medx, this.x[1]}, new double[] {medy, this.y[1]}); 
    258                 this.child_ne.level = this.level + 1; 
    259                  
    260                 this.child_se = new QuadTree(this, se); 
    261                 this.child_se.setSize(new double[] {medx, this.x[1]}, new double[] {this.y[0], medy}); 
    262                 this.child_se.level = this.level + 1;    
    263                  
    264                 this.payload = null; 
    265                 return new QuadTree[] {this.child_nw, this.child_ne, this.child_se, this.child_sw}; 
    266         } 
    267          
    268         /**  
    269          * TODO: static method 
    270          *  
    271          * @param q 
    272          */ 
    273         public void recursiveSplit(QuadTree q) { 
    274                 if(QuadTree.verbose) { 
    275                         System.out.println("splitting: "+ q); 
    276                 } 
    277                 if(q.getNumbers() < QuadTree.alpha) { 
    278                         return; 
    279                 }else{ 
    280                         // exception is thrown if we would run into an endless loop (see comments in split()) 
    281                         try { 
    282                                 QuadTree[] childs = q.split();                   
    283                                 this.recursiveSplit(childs[0]); 
    284                                 this.recursiveSplit(childs[1]); 
    285                                 this.recursiveSplit(childs[2]); 
    286                                 this.recursiveSplit(childs[3]); 
    287                         }catch(Exception e) { 
    288                                 return; 
    289                         } 
    290                 } 
    291         } 
    292          
    293         /** 
    294          * returns an list of childs sorted by density 
    295          *  
    296          * @param q QuadTree 
    297          * @return list of QuadTrees 
    298          */ 
    299         private void generateList(QuadTree q) { 
    300                  
    301                 // we only have all childs or none at all 
    302                 if(q.child_ne == null) { 
    303                         this.l.add(q); 
    304                 } 
    305                  
    306                 if(q.child_ne != null) { 
    307                         this.generateList(q.child_ne); 
    308                 } 
    309                 if(q.child_nw != null) { 
    310                         this.generateList(q.child_nw); 
    311                 } 
    312                 if(q.child_se != null) { 
    313                         this.generateList(q.child_se); 
    314                 } 
    315                 if(q.child_sw != null) { 
    316                         this.generateList(q.child_sw); 
    317                 } 
    318         } 
    319          
    320         /** 
    321          * Checks if passed QuadTree is neighboring to us 
    322          *  
    323          * @param q QuadTree 
    324          * @return true if passed QuadTree is a neighbor 
    325          */ 
    326         public boolean isNeighbour(QuadTree q) { 
    327                 boolean is_neighbour = false; 
    328                  
    329                 double[][] our_size = this.getSize(); 
    330                 double[][] new_size = q.getSize(); 
    331                  
    332                 // X is i=0, Y is i=1 
    333                 for(int i =0; i < 2; i++) { 
    334                         // we are smaller than q 
    335                         // -------------- q 
    336                         //    ------- we 
    337                         if(our_size[i][0] >= new_size[i][0] && our_size[i][1] <= new_size[i][1]) { 
    338                                 is_neighbour = true; 
    339                         } 
    340                         // we overlap with q at some point 
    341                         //a) ---------------q 
    342                         //         ----------- we 
    343                         //b)     --------- q 
    344                         // --------- we 
    345                         if((our_size[i][0] >= new_size[i][0] && our_size[i][0] <= new_size[i][1]) || 
    346                            (our_size[i][1] >= new_size[i][0] && our_size[i][1] <= new_size[i][1])) { 
    347                                 is_neighbour = true; 
    348                         } 
    349                         // we are larger than q 
    350                         //    ---- q 
    351                         // ---------- we 
    352                         if(our_size[i][1] >= new_size[i][1] && our_size[i][0] <= new_size[i][0]) { 
    353                                 is_neighbour = true; 
    354                         } 
    355                 } 
    356                  
    357                 if(is_neighbour && QuadTree.verbose) { 
    358                         System.out.println(this + " neighbour of: " + q); 
    359                 } 
    360                  
    361                 return is_neighbour; 
    362         } 
    363          
    364         /** 
    365          * Perform pruning and clustering of the quadtree 
    366          *  
    367          * Pruning according to: 
    368          * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman,  
    369          * Forrest Shull, Burak Turhan, Thomas Zimmermann,  
    370          * "Local versus Global Lessons for Defect Prediction and Effort Estimation,"  
    371          * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013   
    372          *   
    373          * 1) get list of leaf quadrants 
    374          * 2) sort by their density 
    375          * 3) set stop_rule to 0.5 * highest Density in the list 
    376          * 4) merge all nodes with a density > stop_rule to the new cluster and remove all from list 
    377          * 5) repeat 
    378          *  
    379          * @param q List of QuadTree (children only) 
    380          */ 
    381         public void gridClustering(ArrayList<QuadTree> list) { 
    382                  
    383                 if(list.size() == 0) { 
    384                         return; 
    385                 } 
    386                  
    387                 double stop_rule; 
    388                 QuadTree biggest; 
    389                 QuadTree current; 
    390                  
    391                 // current clusterlist 
    392                 ArrayList<QuadTreePayload<Instance>> current_cluster; 
    393  
    394                 // remove list (for removal of items after scanning of the list) 
    395             ArrayList<Integer> remove = new ArrayList<Integer>(); 
    396                  
    397                 // 1. find biggest, and add it 
    398             biggest = list.get(list.size()-1); 
    399             stop_rule = biggest.getDensity() * 0.5; 
    400              
    401             current_cluster = new ArrayList<QuadTreePayload<Instance>>(); 
    402             current_cluster.addAll(biggest.getPayload()); 
    403  
    404             // remove the biggest because we are starting with it 
    405             remove.add(list.size()-1); 
    406              
    407             ArrayList<Double[][]> tmpSize = new ArrayList<Double[][]>(); 
    408             tmpSize.add(biggest.getSizeDouble()); 
    409              
    410                 // check the items for their density 
    411             for(int i=list.size()-1; i >= 0; i--) { 
    412                 current = list.get(i); 
    413                  
    414                         // 2. find neighbors with correct density 
    415                 // if density > stop_rule and is_neighbour add to cluster and remove from list 
    416                 if(current.getDensity() > stop_rule && !current.equals(biggest) && current.isNeighbour(biggest)) { 
    417                         current_cluster.addAll(current.getPayload()); 
    418                          
    419                         // add it to remove list (we cannot remove it inside the loop because it would move the index) 
    420                         remove.add(i); 
    421                          
    422                         // get the size 
    423                         tmpSize.add(current.getSizeDouble()); 
    424                 } 
    425                 } 
    426              
    427                 // 3. remove our removal candidates from the list 
    428             for(Integer item: remove) { 
    429                 list.remove((int)item); 
    430             } 
    431              
    432                 // 4. add to cluster 
    433             QuadTree.ccluster.add(current_cluster); 
    434                  
    435             // 5. add sizes of our current (biggest) this adds a number of sizes (all QuadTree Instances belonging to this cluster) 
    436             // we need that to classify test instances to a cluster later 
    437             Integer cnumber = new Integer(QuadTree.ccluster.size()-1); 
    438             if(QuadTree.csize.containsKey(cnumber) == false) { 
    439                 QuadTree.csize.put(cnumber, tmpSize); 
    440             } 
    441  
    442                 // repeat 
    443             this.gridClustering(list); 
    444         } 
    445          
    446         public void printInfo() { 
    447             System.out.println("we have " + ccluster.size() + " clusters"); 
    448              
    449             for(int i=0; i < ccluster.size(); i++) { 
    450                 System.out.println("cluster: "+i+ " size: "+ ccluster.get(i).size()); 
    451             } 
    452         } 
    453          
    454         /** 
    455          * Helper Method to get a sorted list (by density) for all 
    456          * children 
    457          *  
    458          * @param q QuadTree 
    459          * @return Sorted ArrayList of quadtrees 
    460          */ 
    461         public ArrayList<QuadTree> getList(QuadTree q) { 
    462                 this.generateList(q); 
    463                  
    464                 Collections.sort(this.l, new Comparator<QuadTree>() { 
    465                 @Override 
    466                 public int compare(QuadTree x1, QuadTree x2) { 
    467                     return Double.compare(x1.getDensity(), x2.getDensity()); 
    468                 } 
    469             }); 
    470                  
    471                 return this.l; 
    472         } 
     32 
     33    /* 1 parent or null */ 
     34    private QuadTree parent = null; 
     35 
     36    /* 4 childs, 1 per quadrant */ 
     37    private QuadTree child_nw; 
     38    private QuadTree child_ne; 
     39    private QuadTree child_se; 
     40    private QuadTree child_sw; 
     41 
     42    /* list (only helps with generation of list of childs!) */ 
     43    private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 
     44 
     45    /* level only used for debugging */ 
     46    public int level = 0; 
     47 
     48    /* size of the quadrant */ 
     49    private double[] x; 
     50    private double[] y; 
     51 
     52    public static boolean verbose = false; 
     53    public static int size = 0; 
     54    public static double alpha = 0; 
     55 
     56    /* cluster payloads */ 
     57    public static ArrayList<ArrayList<QuadTreePayload<Instance>>> ccluster = 
     58        new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
     59 
     60    /* cluster sizes (index is cluster number, arraylist is list of boxes (x0,y0,x1,y1) */ 
     61    public static HashMap<Integer, ArrayList<Double[][]>> csize = 
     62        new HashMap<Integer, ArrayList<Double[][]>>(); 
     63 
     64    /* payload of this instance */ 
     65    private ArrayList<QuadTreePayload<Instance>> payload; 
     66 
     67    public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { 
     68        this.parent = parent; 
     69        this.payload = payload; 
     70    } 
     71 
     72    public String toString() { 
     73        String n = ""; 
     74        if (this.parent == null) { 
     75            n += "rootnode "; 
     76        } 
     77        String level = new String(new char[this.level]).replace("\0", "-"); 
     78        n += level + " instances: " + this.getNumbers(); 
     79        return n; 
     80    } 
     81 
     82    /** 
     83     * Returns the payload, used for clustering in the clustering list we only have children with 
     84     * paylod 
     85     *  
     86     * @return payload 
     87     */ 
     88    public ArrayList<QuadTreePayload<Instance>> getPayload() { 
     89        return this.payload; 
     90    } 
     91 
     92    /** 
     93     * Calculate the density of this quadrant 
     94     *  
     95     * density = number of instances / global size (all instances) 
     96     *  
     97     * @return density 
     98     */ 
     99    public double getDensity() { 
     100        double dens = 0; 
     101        dens = (double) this.getNumbers() / QuadTree.size; 
     102        return dens; 
     103    } 
     104 
     105    public void setSize(double[] x, double[] y) { 
     106        this.x = x; 
     107        this.y = y; 
     108    } 
     109 
     110    public double[][] getSize() { 
     111        return new double[][] 
     112            { this.x, this.y }; 
     113    } 
     114 
     115    public Double[][] getSizeDouble() { 
     116        Double[] tmpX = new Double[2]; 
     117        Double[] tmpY = new Double[2]; 
     118 
     119        tmpX[0] = this.x[0]; 
     120        tmpX[1] = this.x[1]; 
     121 
     122        tmpY[0] = this.y[0]; 
     123        tmpY[1] = this.y[1]; 
     124 
     125        return new Double[][] 
     126            { tmpX, tmpY }; 
     127    } 
     128 
     129    /** 
     130     * TODO: DRY, median ist immer dasselbe 
     131     *  
     132     * @return median for x 
     133     */ 
     134    private double getMedianForX() { 
     135        double med_x = 0; 
     136 
     137        Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 
     138            @Override 
     139            public int compare(QuadTreePayload<Instance> x1, QuadTreePayload<Instance> x2) { 
     140                return Double.compare(x1.x, x2.x); 
     141            } 
     142        }); 
     143 
     144        if (this.payload.size() % 2 == 0) { 
     145            int mid = this.payload.size() / 2; 
     146            med_x = (this.payload.get(mid).x + this.payload.get(mid + 1).x) / 2; 
     147        } 
     148        else { 
     149            int mid = this.payload.size() / 2; 
     150            med_x = this.payload.get(mid).x; 
     151        } 
     152 
     153        if (QuadTree.verbose) { 
     154            System.out.println("sorted:"); 
     155            for (int i = 0; i < this.payload.size(); i++) { 
     156                System.out.print("" + this.payload.get(i).x + ","); 
     157            } 
     158            System.out.println("median x: " + med_x); 
     159        } 
     160        return med_x; 
     161    } 
     162 
     163    private double getMedianForY() { 
     164        double med_y = 0; 
     165 
     166        Collections.sort(this.payload, new Comparator<QuadTreePayload<Instance>>() { 
     167            @Override 
     168            public int compare(QuadTreePayload<Instance> y1, QuadTreePayload<Instance> y2) { 
     169                return Double.compare(y1.y, y2.y); 
     170            } 
     171        }); 
     172 
     173        if (this.payload.size() % 2 == 0) { 
     174            int mid = this.payload.size() / 2; 
     175            med_y = (this.payload.get(mid).y + this.payload.get(mid + 1).y) / 2; 
     176        } 
     177        else { 
     178            int mid = this.payload.size() / 2; 
     179            med_y = this.payload.get(mid).y; 
     180        } 
     181 
     182        if (QuadTree.verbose) { 
     183            System.out.println("sorted:"); 
     184            for (int i = 0; i < this.payload.size(); i++) { 
     185                System.out.print("" + this.payload.get(i).y + ","); 
     186            } 
     187            System.out.println("median y: " + med_y); 
     188        } 
     189        return med_y; 
     190    } 
     191 
     192    /** 
     193     * Reurns the number of instances in the payload 
     194     *  
     195     * @return int number of instances 
     196     */ 
     197    public int getNumbers() { 
     198        int number = 0; 
     199        if (this.payload != null) { 
     200            number = this.payload.size(); 
     201        } 
     202        return number; 
     203    } 
     204 
     205    /** 
     206     * Calculate median values of payload for x, y and split into 4 sectors 
     207     *  
     208     * @return Array of QuadTree nodes (4 childs) 
     209     * @throws Exception 
     210     *             if we would run into an recursive loop 
     211     */ 
     212    public QuadTree[] split() throws Exception { 
     213 
     214        double medx = this.getMedianForX(); 
     215        double medy = this.getMedianForY(); 
     216 
     217        // Payload lists for each child 
     218        ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 
     219        ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); 
     220        ArrayList<QuadTreePayload<Instance>> ne = new ArrayList<QuadTreePayload<Instance>>(); 
     221        ArrayList<QuadTreePayload<Instance>> se = new ArrayList<QuadTreePayload<Instance>>(); 
     222 
     223        // sort the payloads to new payloads 
     224        // here we have the problem that payloads with the same values are sorted 
     225        // into the same slots and it could happen that medx and medy = size_x[1] and size_y[1] 
     226        // in that case we would have an endless loop 
     227        for (int i = 0; i < this.payload.size(); i++) { 
     228 
     229            QuadTreePayload<Instance> item = this.payload.get(i); 
     230 
     231            // north west 
     232            if (item.x <= medx && item.y >= medy) { 
     233                nw.add(item); 
     234            } 
     235 
     236            // south west 
     237            else if (item.x <= medx && item.y <= medy) { 
     238                sw.add(item); 
     239            } 
     240 
     241            // north east 
     242            else if (item.x >= medx && item.y >= medy) { 
     243                ne.add(item); 
     244            } 
     245 
     246            // south east 
     247            else if (item.x >= medx && item.y <= medy) { 
     248                se.add(item); 
     249            } 
     250        } 
     251 
     252        // if we assign one child a payload equal to our own (see problem above) 
     253        // we throw an exceptions which stops the recursion on this node 
     254        if (nw.equals(this.payload)) { 
     255            throw new Exception("payload equal"); 
     256        } 
     257        if (sw.equals(this.payload)) { 
     258            throw new Exception("payload equal"); 
     259        } 
     260        if (ne.equals(this.payload)) { 
     261            throw new Exception("payload equal"); 
     262        } 
     263        if (se.equals(this.payload)) { 
     264            throw new Exception("payload equal"); 
     265        } 
     266 
     267        this.child_nw = new QuadTree(this, nw); 
     268        this.child_nw.setSize(new double[] 
     269            { this.x[0], medx }, new double[] 
     270            { medy, this.y[1] }); 
     271        this.child_nw.level = this.level + 1; 
     272 
     273        this.child_sw = new QuadTree(this, sw); 
     274        this.child_sw.setSize(new double[] 
     275            { this.x[0], medx }, new double[] 
     276            { this.y[0], medy }); 
     277        this.child_sw.level = this.level + 1; 
     278 
     279        this.child_ne = new QuadTree(this, ne); 
     280        this.child_ne.setSize(new double[] 
     281            { medx, this.x[1] }, new double[] 
     282            { medy, this.y[1] }); 
     283        this.child_ne.level = this.level + 1; 
     284 
     285        this.child_se = new QuadTree(this, se); 
     286        this.child_se.setSize(new double[] 
     287            { medx, this.x[1] }, new double[] 
     288            { this.y[0], medy }); 
     289        this.child_se.level = this.level + 1; 
     290 
     291        this.payload = null; 
     292        return new QuadTree[] 
     293            { this.child_nw, this.child_ne, this.child_se, this.child_sw }; 
     294    } 
     295 
     296    /** 
     297     * TODO: static method 
     298     *  
     299     * @param q 
     300     */ 
     301    public void recursiveSplit(QuadTree q) { 
     302        if (QuadTree.verbose) { 
     303            System.out.println("splitting: " + q); 
     304        } 
     305        if (q.getNumbers() < QuadTree.alpha) { 
     306            return; 
     307        } 
     308        else { 
     309            // exception is thrown if we would run into an endless loop (see comments in split()) 
     310            try { 
     311                QuadTree[] childs = q.split(); 
     312                this.recursiveSplit(childs[0]); 
     313                this.recursiveSplit(childs[1]); 
     314                this.recursiveSplit(childs[2]); 
     315                this.recursiveSplit(childs[3]); 
     316            } 
     317            catch (Exception e) { 
     318                return; 
     319            } 
     320        } 
     321    } 
     322 
     323    /** 
     324     * returns an list of childs sorted by density 
     325     *  
     326     * @param q 
     327     *            QuadTree 
     328     * @return list of QuadTrees 
     329     */ 
     330    private void generateList(QuadTree q) { 
     331 
     332        // we only have all childs or none at all 
     333        if (q.child_ne == null) { 
     334            this.l.add(q); 
     335        } 
     336 
     337        if (q.child_ne != null) { 
     338            this.generateList(q.child_ne); 
     339        } 
     340        if (q.child_nw != null) { 
     341            this.generateList(q.child_nw); 
     342        } 
     343        if (q.child_se != null) { 
     344            this.generateList(q.child_se); 
     345        } 
     346        if (q.child_sw != null) { 
     347            this.generateList(q.child_sw); 
     348        } 
     349    } 
     350 
     351    /** 
     352     * Checks if passed QuadTree is neighboring to us 
     353     *  
     354     * @param q 
     355     *            QuadTree 
     356     * @return true if passed QuadTree is a neighbor 
     357     */ 
     358    public boolean isNeighbour(QuadTree q) { 
     359        boolean is_neighbour = false; 
     360 
     361        double[][] our_size = this.getSize(); 
     362        double[][] new_size = q.getSize(); 
     363 
     364        // X is i=0, Y is i=1 
     365        for (int i = 0; i < 2; i++) { 
     366            // we are smaller than q 
     367            // -------------- q 
     368            // ------- we 
     369            if (our_size[i][0] >= new_size[i][0] && our_size[i][1] <= new_size[i][1]) { 
     370                is_neighbour = true; 
     371            } 
     372            // we overlap with q at some point 
     373            // a) ---------------q 
     374            // ----------- we 
     375            // b) --------- q 
     376            // --------- we 
     377            if ((our_size[i][0] >= new_size[i][0] && our_size[i][0] <= new_size[i][1]) || 
     378                (our_size[i][1] >= new_size[i][0] && our_size[i][1] <= new_size[i][1])) 
     379            { 
     380                is_neighbour = true; 
     381            } 
     382            // we are larger than q 
     383            // ---- q 
     384            // ---------- we 
     385            if (our_size[i][1] >= new_size[i][1] && our_size[i][0] <= new_size[i][0]) { 
     386                is_neighbour = true; 
     387            } 
     388        } 
     389 
     390        if (is_neighbour && QuadTree.verbose) { 
     391            System.out.println(this + " neighbour of: " + q); 
     392        } 
     393 
     394        return is_neighbour; 
     395    } 
     396 
     397    /** 
     398     * Perform pruning and clustering of the quadtree 
     399     *  
     400     * Pruning according to: Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman, 
     401     * Forrest Shull, Burak Turhan, Thomas Zimmermann, 
     402     * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions 
     403     * on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 
     404     *  
     405     * 1) get list of leaf quadrants 2) sort by their density 3) set stop_rule to 0.5 * highest 
     406     * Density in the list 4) merge all nodes with a density > stop_rule to the new cluster and 
     407     * remove all from list 5) repeat 
     408     *  
     409     * @param q 
     410     *            List of QuadTree (children only) 
     411     */ 
     412    public void gridClustering(ArrayList<QuadTree> list) { 
     413 
     414        if (list.size() == 0) { 
     415            return; 
     416        } 
     417 
     418        double stop_rule; 
     419        QuadTree biggest; 
     420        QuadTree current; 
     421 
     422        // current clusterlist 
     423        ArrayList<QuadTreePayload<Instance>> current_cluster; 
     424 
     425        // remove list (for removal of items after scanning of the list) 
     426        ArrayList<Integer> remove = new ArrayList<Integer>(); 
     427 
     428        // 1. find biggest, and add it 
     429        biggest = list.get(list.size() - 1); 
     430        stop_rule = biggest.getDensity() * 0.5; 
     431 
     432        current_cluster = new ArrayList<QuadTreePayload<Instance>>(); 
     433        current_cluster.addAll(biggest.getPayload()); 
     434 
     435        // remove the biggest because we are starting with it 
     436        remove.add(list.size() - 1); 
     437 
     438        ArrayList<Double[][]> tmpSize = new ArrayList<Double[][]>(); 
     439        tmpSize.add(biggest.getSizeDouble()); 
     440 
     441        // check the items for their density 
     442        for (int i = list.size() - 1; i >= 0; i--) { 
     443            current = list.get(i); 
     444 
     445            // 2. find neighbors with correct density 
     446            // if density > stop_rule and is_neighbour add to cluster and remove from list 
     447            if (current.getDensity() > stop_rule && !current.equals(biggest) && 
     448                current.isNeighbour(biggest)) 
     449            { 
     450                current_cluster.addAll(current.getPayload()); 
     451 
     452                // add it to remove list (we cannot remove it inside the loop because it would move 
     453                // the index) 
     454                remove.add(i); 
     455 
     456                // get the size 
     457                tmpSize.add(current.getSizeDouble()); 
     458            } 
     459        } 
     460 
     461        // 3. remove our removal candidates from the list 
     462        for (Integer item : remove) { 
     463            list.remove((int) item); 
     464        } 
     465 
     466        // 4. add to cluster 
     467        QuadTree.ccluster.add(current_cluster); 
     468 
     469        // 5. add sizes of our current (biggest) this adds a number of sizes (all QuadTree Instances 
     470        // belonging to this cluster) 
     471        // we need that to classify test instances to a cluster later 
     472        Integer cnumber = new Integer(QuadTree.ccluster.size() - 1); 
     473        if (QuadTree.csize.containsKey(cnumber) == false) { 
     474            QuadTree.csize.put(cnumber, tmpSize); 
     475        } 
     476 
     477        // repeat 
     478        this.gridClustering(list); 
     479    } 
     480 
     481    public void printInfo() { 
     482        System.out.println("we have " + ccluster.size() + " clusters"); 
     483 
     484        for (int i = 0; i < ccluster.size(); i++) { 
     485            System.out.println("cluster: " + i + " size: " + ccluster.get(i).size()); 
     486        } 
     487    } 
     488 
     489    /** 
     490     * Helper Method to get a sorted list (by density) for all children 
     491     *  
     492     * @param q 
     493     *            QuadTree 
     494     * @return Sorted ArrayList of quadtrees 
     495     */ 
     496    public ArrayList<QuadTree> getList(QuadTree q) { 
     497        this.generateList(q); 
     498 
     499        Collections.sort(this.l, new Comparator<QuadTree>() { 
     500            @Override 
     501            public int compare(QuadTree x1, QuadTree x2) { 
     502                return Double.compare(x1.getDensity(), x2.getDensity()); 
     503            } 
     504        }); 
     505 
     506        return this.l; 
     507    } 
    473508} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/RandomClass.java

    r38 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1125 * Assigns a random class label to the instance it is evaluated on. 
    1226 *  
    13  * The range of class labels are hardcoded in fixedClassValues. 
    14  * This can later be extended to take values from the XML configuration.  
     27 * The range of class labels are hardcoded in fixedClassValues. This can later be extended to take 
     28 * values from the XML configuration. 
    1529 */ 
    16 public class RandomClass extends AbstractClassifier implements ITrainingStrategy, IWekaCompatibleTrainer { 
     30public class RandomClass extends AbstractClassifier implements ITrainingStrategy, 
     31    IWekaCompatibleTrainer 
     32{ 
    1733 
    18         private static final long serialVersionUID = 1L; 
     34    private static final long serialVersionUID = 1L; 
    1935 
    20         private double[] fixedClassValues = {0.0d, 1.0d}; 
    21          
    22         @Override 
    23         public void setParameter(String parameters) { 
    24                 // do nothing, maybe take percentages for distribution later 
    25         } 
     36    private double[] fixedClassValues = 
     37        { 0.0d, 1.0d }; 
    2638 
    27         @Override 
    28         public void buildClassifier(Instances arg0) throws Exception { 
    29                 // do nothing 
    30         } 
     39    @Override 
     40    public void setParameter(String parameters) { 
     41        // do nothing, maybe take percentages for distribution later 
     42    } 
    3143 
    32         @Override 
    33         public Classifier getClassifier() { 
    34                 return this; 
    35         } 
     44    @Override 
     45    public void buildClassifier(Instances arg0) throws Exception { 
     46        // do nothing 
     47    } 
    3648 
    37         @Override 
    38         public void apply(Instances traindata) { 
    39                 // nothing to do 
    40         } 
     49    @Override 
     50    public Classifier getClassifier() { 
     51        return this; 
     52    } 
    4153 
    42         @Override 
    43         public String getName() { 
    44                 return "RandomClass"; 
    45         } 
    46          
    47         @Override 
    48         public double classifyInstance(Instance instance) { 
    49                 Random rand = new Random(); 
    50             int randomNum = rand.nextInt(this.fixedClassValues.length); 
    51                 return this.fixedClassValues[randomNum]; 
    52         } 
     54    @Override 
     55    public void apply(Instances traindata) { 
     56        // nothing to do 
     57    } 
     58 
     59    @Override 
     60    public String getName() { 
     61        return "RandomClass"; 
     62    } 
     63 
     64    @Override 
     65    public double classifyInstance(Instance instance) { 
     66        Random rand = new Random(); 
     67        int randomNum = rand.nextInt(this.fixedClassValues.length); 
     68        return this.fixedClassValues[randomNum]; 
     69    } 
    5370} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaggingTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1832/** 
    1933 * Programmatic WekaBaggingTraining 
    20  * 
    21  * first parameter is Trainer Name. 
    22  * second parameter is class name 
    2334 *  
    24  * all subsequent parameters are configuration params (for example for trees) 
    25  * Cross Validation params always come last and are prepended with -CVPARAM 
     35 * first parameter is Trainer Name. second parameter is class name 
     36 *  
     37 * all subsequent parameters are configuration params (for example for trees) Cross Validation 
     38 * params always come last and are prepended with -CVPARAM 
    2639 *  
    2740 * XML Configurations for Weka Classifiers: 
     41 *  
    2842 * <pre> 
    2943 * {@code 
     
    3751public class WekaBaggingTraining extends WekaBaseTraining implements ISetWiseTrainingStrategy { 
    3852 
    39         private final TraindatasetBagging classifier = new TraindatasetBagging(); 
    40          
    41         @Override 
    42         public Classifier getClassifier() { 
    43                 return classifier; 
    44         } 
    45          
    46         @Override 
    47         public void apply(SetUniqueList<Instances> traindataSet) { 
    48                 PrintStream errStr      = System.err; 
    49                 System.setErr(new PrintStream(new NullOutputStream())); 
    50                 try { 
    51                         classifier.buildClassifier(traindataSet); 
    52                 } catch (Exception e) { 
    53                         throw new RuntimeException(e); 
    54                 } finally { 
    55                         System.setErr(errStr); 
    56                 } 
    57         } 
    58          
    59         public class TraindatasetBagging extends AbstractClassifier { 
    60                  
    61                 private static final long serialVersionUID = 1L; 
     53    private final TraindatasetBagging classifier = new TraindatasetBagging(); 
    6254 
    63                 private List<Instances> trainingData = null; 
    64                  
    65                 private List<Classifier> classifiers = null; 
    66          
    67                 @Override 
    68                 public double classifyInstance(Instance instance) { 
    69                         if( classifiers==null ) { 
    70                                 return 0.0; 
    71                         } 
    72                          
    73                         double classification = 0.0; 
    74                         for( int i=0 ; i<classifiers.size(); i++ ) { 
    75                                 Classifier classifier = classifiers.get(i); 
    76                                 Instances traindata = trainingData.get(i); 
    77                                  
    78                                 Set<String> attributeNames = new HashSet<>(); 
    79                                 for( int j=0; j<traindata.numAttributes(); j++ ) { 
    80                                         attributeNames.add(traindata.attribute(j).name()); 
    81                                 } 
    82                                  
    83                                 double[] values = new double[traindata.numAttributes()]; 
    84                                 int index = 0; 
    85                                 for( int j=0; j<instance.numAttributes(); j++ ) { 
    86                                         if( attributeNames.contains(instance.attribute(j).name())) { 
    87                                                 values[index] = instance.value(j); 
    88                                                 index++; 
    89                                         } 
    90                                 } 
    91                                  
    92                                 Instances tmp = new Instances(traindata); 
    93                                 tmp.clear(); 
    94                                 Instance instCopy = new DenseInstance(instance.weight(), values); 
    95                                 instCopy.setDataset(tmp); 
    96                                 try { 
    97                                         classification += classifier.classifyInstance(instCopy); 
    98                                 } catch (Exception e) { 
    99                                         throw new RuntimeException("bagging classifier could not classify an instance", e); 
    100                                 } 
    101                         } 
    102                         classification /= classifiers.size(); 
    103                         return (classification>=0.5) ? 1.0 : 0.0; 
    104                 } 
    105                  
    106                 public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception { 
    107                         classifiers = new LinkedList<>(); 
    108                         trainingData = new LinkedList<>(); 
    109                         for( Instances traindata : traindataSet ) { 
    110                                 Classifier classifier = setupClassifier(); 
    111                                 classifier.buildClassifier(traindata); 
    112                                 classifiers.add(classifier); 
    113                                 trainingData.add(new Instances(traindata)); 
    114                         } 
    115                 } 
    116          
    117                 @Override 
    118                 public void buildClassifier(Instances traindata) throws Exception { 
    119                         classifiers = new LinkedList<>(); 
    120                         trainingData = new LinkedList<>(); 
    121                         final Classifier classifier = setupClassifier(); 
    122                         classifier.buildClassifier(traindata); 
    123                         classifiers.add(classifier); 
    124                         trainingData.add(new Instances(traindata)); 
    125                 } 
    126         } 
     55    @Override 
     56    public Classifier getClassifier() { 
     57        return classifier; 
     58    } 
     59 
     60    @Override 
     61    public void apply(SetUniqueList<Instances> traindataSet) { 
     62        PrintStream errStr = System.err; 
     63        System.setErr(new PrintStream(new NullOutputStream())); 
     64        try { 
     65            classifier.buildClassifier(traindataSet); 
     66        } 
     67        catch (Exception e) { 
     68            throw new RuntimeException(e); 
     69        } 
     70        finally { 
     71            System.setErr(errStr); 
     72        } 
     73    } 
     74 
     75    public class TraindatasetBagging extends AbstractClassifier { 
     76 
     77        private static final long serialVersionUID = 1L; 
     78 
     79        private List<Instances> trainingData = null; 
     80 
     81        private List<Classifier> classifiers = null; 
     82 
     83        @Override 
     84        public double classifyInstance(Instance instance) { 
     85            if (classifiers == null) { 
     86                return 0.0; 
     87            } 
     88 
     89            double classification = 0.0; 
     90            for (int i = 0; i < classifiers.size(); i++) { 
     91                Classifier classifier = classifiers.get(i); 
     92                Instances traindata = trainingData.get(i); 
     93 
     94                Set<String> attributeNames = new HashSet<>(); 
     95                for (int j = 0; j < traindata.numAttributes(); j++) { 
     96                    attributeNames.add(traindata.attribute(j).name()); 
     97                } 
     98 
     99                double[] values = new double[traindata.numAttributes()]; 
     100                int index = 0; 
     101                for (int j = 0; j < instance.numAttributes(); j++) { 
     102                    if (attributeNames.contains(instance.attribute(j).name())) { 
     103                        values[index] = instance.value(j); 
     104                        index++; 
     105                    } 
     106                } 
     107 
     108                Instances tmp = new Instances(traindata); 
     109                tmp.clear(); 
     110                Instance instCopy = new DenseInstance(instance.weight(), values); 
     111                instCopy.setDataset(tmp); 
     112                try { 
     113                    classification += classifier.classifyInstance(instCopy); 
     114                } 
     115                catch (Exception e) { 
     116                    throw new RuntimeException("bagging classifier could not classify an instance", 
     117                                               e); 
     118                } 
     119            } 
     120            classification /= classifiers.size(); 
     121            return (classification >= 0.5) ? 1.0 : 0.0; 
     122        } 
     123 
     124        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception { 
     125            classifiers = new LinkedList<>(); 
     126            trainingData = new LinkedList<>(); 
     127            for (Instances traindata : traindataSet) { 
     128                Classifier classifier = setupClassifier(); 
     129                classifier.buildClassifier(traindata); 
     130                classifiers.add(classifier); 
     131                trainingData.add(new Instances(traindata)); 
     132            } 
     133        } 
     134 
     135        @Override 
     136        public void buildClassifier(Instances traindata) throws Exception { 
     137            classifiers = new LinkedList<>(); 
     138            trainingData = new LinkedList<>(); 
     139            final Classifier classifier = setupClassifier(); 
     140            classifier.buildClassifier(traindata); 
     141            classifiers.add(classifier); 
     142            trainingData.add(new Instances(traindata)); 
     143        } 
     144    } 
    127145} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1529 * Allows specification of the Weka classifier and its params in the XML experiment configuration. 
    1630 *  
    17  * Important conventions of the XML format:  
    18  * Cross Validation params always come last and are prepended with -CVPARAM 
    19  * Example: <trainer name="WekaTraining" param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/> 
     31 * Important conventions of the XML format: Cross Validation params always come last and are 
     32 * prepended with -CVPARAM Example: <trainer name="WekaTraining" 
     33 * param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/> 
    2034 */ 
    2135public abstract class WekaBaseTraining implements IWekaCompatibleTrainer { 
    22          
    23         protected Classifier classifier = null; 
    24         protected String classifierClassName; 
    25         protected String classifierName; 
    26         protected String[] classifierParams; 
    27          
    28         @Override 
    29         public void setParameter(String parameters) { 
    30                 String[] params = parameters.split(" "); 
    3136 
    32                 // first part of the params is the classifierName (e.g. SMORBF) 
    33                 classifierName = params[0]; 
    34                  
    35                 // the following parameters can be copied from weka! 
    36                  
    37                 // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 
    38                 classifierClassName = params[1]; 
    39          
    40                 // rest are params to the specified classifier (e.g. -K weka.classifiers.functions.supportVector.RBFKernel) 
    41                 classifierParams = Arrays.copyOfRange(params, 2, params.length); 
    42                  
    43                 classifier = setupClassifier(); 
    44         } 
     37    protected Classifier classifier = null; 
     38    protected String classifierClassName; 
     39    protected String classifierName; 
     40    protected String[] classifierParams; 
    4541 
    46         @Override 
    47         public Classifier getClassifier() { 
    48                 return classifier; 
    49         } 
     42    @Override 
     43    public void setParameter(String parameters) { 
     44        String[] params = parameters.split(" "); 
    5045 
    51         public Classifier setupClassifier() { 
    52                 Classifier cl = null; 
    53                 try{ 
    54                         @SuppressWarnings("rawtypes") 
    55                         Class c = Class.forName(classifierClassName); 
    56                         Classifier obj = (Classifier) c.newInstance(); 
    57                          
    58                         // Filter out -CVPARAM, these are special because they do not belong to the Weka classifier class as parameters 
    59                         String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 
    60                         String[] cvparam = {}; 
    61                         boolean cv = false; 
    62                         for ( int i=0; i < classifierParams.length; i++ ) { 
    63                                 if(classifierParams[i].equals("-CVPARAM")) { 
    64                                         // rest of array are cvparam 
    65                                         cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length); 
    66                                          
    67                                         // before this we have normal params 
    68                                         param = Arrays.copyOfRange(classifierParams, 0, i); 
    69                                          
    70                                         cv = true; 
    71                                         break; 
    72                                 } 
    73                         } 
    74                          
    75                         // set classifier params 
    76                         ((OptionHandler)obj).setOptions(param); 
    77                         cl = obj; 
    78                          
    79                         // we have cross val params 
    80                         // cant check on cvparam.length here, it may not be initialized                  
    81                         if(cv) { 
    82                                 final CVParameterSelection ps = new CVParameterSelection(); 
    83                                 ps.setClassifier(obj); 
    84                                 ps.setNumFolds(5); 
    85                                 //ps.addCVParameter("I 5 25 5"); 
    86                                 for( int i=1 ; i<cvparam.length/4 ; i++ ) { 
    87                                         ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4*i)).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 
    88                                 } 
    89                                  
    90                                 cl = ps; 
    91                         } 
     46        // first part of the params is the classifierName (e.g. SMORBF) 
     47        classifierName = params[0]; 
    9248 
    93                 }catch(ClassNotFoundException e) { 
    94                         Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 
    95                         e.printStackTrace(); 
    96                 } catch (InstantiationException e) { 
    97                         Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString())); 
    98                         e.printStackTrace(); 
    99                 } catch (IllegalAccessException e) { 
    100                         Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString())); 
    101                         e.printStackTrace(); 
    102                 } catch (Exception e) { 
    103                         Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 
    104                         e.printStackTrace(); 
    105                 } 
    106                  
    107                 return cl; 
    108         } 
     49        // the following parameters can be copied from weka! 
    10950 
    110         @Override 
    111         public String getName() { 
    112                 return classifierName; 
    113         } 
    114          
     51        // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 
     52        classifierClassName = params[1]; 
     53 
     54        // rest are params to the specified classifier (e.g. -K 
     55        // weka.classifiers.functions.supportVector.RBFKernel) 
     56        classifierParams = Arrays.copyOfRange(params, 2, params.length); 
     57 
     58        classifier = setupClassifier(); 
     59    } 
     60 
     61    @Override 
     62    public Classifier getClassifier() { 
     63        return classifier; 
     64    } 
     65 
     66    public Classifier setupClassifier() { 
     67        Classifier cl = null; 
     68        try { 
     69            @SuppressWarnings("rawtypes") 
     70            Class c = Class.forName(classifierClassName); 
     71            Classifier obj = (Classifier) c.newInstance(); 
     72 
     73            // Filter out -CVPARAM, these are special because they do not belong to the Weka 
     74            // classifier class as parameters 
     75            String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 
     76            String[] cvparam = { }; 
     77            boolean cv = false; 
     78            for (int i = 0; i < classifierParams.length; i++) { 
     79                if (classifierParams[i].equals("-CVPARAM")) { 
     80                    // rest of array are cvparam 
     81                    cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length); 
     82 
     83                    // before this we have normal params 
     84                    param = Arrays.copyOfRange(classifierParams, 0, i); 
     85 
     86                    cv = true; 
     87                    break; 
     88                } 
     89            } 
     90 
     91            // set classifier params 
     92            ((OptionHandler) obj).setOptions(param); 
     93            cl = obj; 
     94 
     95            // we have cross val params 
     96            // cant check on cvparam.length here, it may not be initialized 
     97            if (cv) { 
     98                final CVParameterSelection ps = new CVParameterSelection(); 
     99                ps.setClassifier(obj); 
     100                ps.setNumFolds(5); 
     101                // ps.addCVParameter("I 5 25 5"); 
     102                for (int i = 1; i < cvparam.length / 4; i++) { 
     103                    ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i)) 
     104                        .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 
     105                } 
     106 
     107                cl = ps; 
     108            } 
     109 
     110        } 
     111        catch (ClassNotFoundException e) { 
     112            Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 
     113            e.printStackTrace(); 
     114        } 
     115        catch (InstantiationException e) { 
     116            Console.traceln(Level.WARNING, 
     117                            String.format("Instantiation Exception: %s", e.toString())); 
     118            e.printStackTrace(); 
     119        } 
     120        catch (IllegalAccessException e) { 
     121            Console.traceln(Level.WARNING, 
     122                            String.format("Illegal Access Exception: %s", e.toString())); 
     123            e.printStackTrace(); 
     124        } 
     125        catch (Exception e) { 
     126            Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 
     127            e.printStackTrace(); 
     128        } 
     129 
     130        return cl; 
     131    } 
     132 
     133    @Override 
     134    public String getName() { 
     135        return classifierName; 
     136    } 
     137 
    115138} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    2438 * WekaLocalEMTraining 
    2539 *  
    26  * Local Trainer with EM Clustering for data partitioning. 
    27  * Currently supports only EM Clustering. 
    28  *  
    29  * 1. Cluster training data 
    30  * 2. for each cluster train a classifier with training data from cluster 
     40 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering. 
     41 *  
     42 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster 
    3143 * 3. match test data instance to a cluster, then classify with classifier from the cluster 
    3244 *  
    33  * XML configuration: 
    34  * <!-- because of clustering --> 
    35  * <preprocessor name="Normalization" param=""/> 
    36  *  
    37  * <!-- cluster trainer --> 
    38  * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 
     45 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/> 
     46 *  
     47 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining" 
     48 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 
    3949 */ 
    4050public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy { 
    4151 
    42         private final TraindatasetCluster classifier = new TraindatasetCluster(); 
    43          
    44         @Override 
    45         public Classifier getClassifier() { 
    46                 return classifier; 
    47         } 
    48          
    49         @Override 
    50         public void apply(Instances traindata) { 
    51                 PrintStream errStr      = System.err; 
    52                 System.setErr(new PrintStream(new NullOutputStream())); 
    53                 try { 
    54                         classifier.buildClassifier(traindata); 
    55                 } catch (Exception e) { 
    56                         throw new RuntimeException(e); 
    57                 } finally { 
    58                         System.setErr(errStr); 
    59                 } 
    60         } 
    61          
    62  
    63         public class TraindatasetCluster extends AbstractClassifier { 
    64                  
    65                 private static final long serialVersionUID = 1L; 
    66  
    67                 private EM clusterer = null; 
    68  
    69                 private HashMap<Integer, Classifier> cclassifier; 
    70                 private HashMap<Integer, Instances> ctraindata;  
    71                  
    72                  
    73                 /** 
    74                  * Helper method that gives us a clean instance copy with  
    75                  * the values of the instancelist of the first parameter.  
    76                  *  
    77                  * @param instancelist with attributes 
    78                  * @param instance with only values 
    79                  * @return copy of the instance 
    80                  */ 
    81                 private Instance createInstance(Instances instances, Instance instance) { 
    82                         // attributes for feeding instance to classifier 
    83                         Set<String> attributeNames = new HashSet<>(); 
    84                         for( int j=0; j<instances.numAttributes(); j++ ) { 
    85                                 attributeNames.add(instances.attribute(j).name()); 
    86                         } 
    87                          
    88                         double[] values = new double[instances.numAttributes()]; 
    89                         int index = 0; 
    90                         for( int j=0; j<instance.numAttributes(); j++ ) { 
    91                                 if( attributeNames.contains(instance.attribute(j).name())) { 
    92                                         values[index] = instance.value(j); 
    93                                         index++; 
    94                                 } 
    95                         } 
    96                          
    97                         Instances tmp = new Instances(instances); 
    98                         tmp.clear(); 
    99                         Instance instCopy = new DenseInstance(instance.weight(), values); 
    100                         instCopy.setDataset(tmp); 
    101                          
    102                         return instCopy; 
    103                 } 
    104                  
    105                 @Override 
    106                 public double classifyInstance(Instance instance) { 
    107                         double ret = 0; 
    108                         try { 
    109                                 // 1. copy the instance (keep the class attribute) 
    110                                 Instances traindata = ctraindata.get(0); 
    111                                 Instance classInstance = createInstance(traindata, instance); 
    112                                  
    113                                 // 2. remove class attribute before clustering 
    114                                 Remove filter = new Remove(); 
    115                                 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
    116                                 filter.setInputFormat(traindata); 
    117                                 traindata = Filter.useFilter(traindata, filter); 
    118                                  
    119                                 // 3. copy the instance (without the class attribute) for clustering 
    120                                 Instance clusterInstance = createInstance(traindata, instance); 
    121                                  
    122                                 // 4. match instance without class attribute to a cluster number 
    123                                 int cnum = clusterer.clusterInstance(clusterInstance); 
    124                                  
    125                                 // 5. classify instance with class attribute to the classifier of that cluster number 
    126                                 ret = cclassifier.get(cnum).classifyInstance(classInstance); 
    127                                  
    128                         }catch( Exception e ) { 
    129                                 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
    130                                 throw new RuntimeException(e); 
    131                         } 
    132                         return ret; 
    133                 } 
    134  
    135                 @Override 
    136                 public void buildClassifier(Instances traindata) throws Exception { 
    137                          
    138                         // 1. copy training data 
    139                         Instances train = new Instances(traindata); 
    140                          
    141                         // 2. remove class attribute for clustering 
    142                         Remove filter = new Remove(); 
    143                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    144                         filter.setInputFormat(train); 
    145                         train = Filter.useFilter(train, filter); 
    146                          
    147                         // new objects 
    148                         cclassifier = new HashMap<Integer, Classifier>(); 
    149                         ctraindata = new HashMap<Integer, Instances>(); 
    150                                                  
    151                         Instances ctrain; 
    152                         int maxNumClusters = train.size(); 
    153                         boolean sufficientInstancesInEachCluster; 
    154                         do { // while(onlyTarget) 
    155                                 sufficientInstancesInEachCluster = true; 
    156                                 clusterer = new EM(); 
    157                                 clusterer.setMaximumNumberOfClusters(maxNumClusters); 
    158                                 clusterer.buildClusterer(train); 
    159                                  
    160                                 // 4. get cluster membership of our traindata 
    161                                 //AddCluster cfilter = new AddCluster(); 
    162                                 //cfilter.setClusterer(clusterer); 
    163                                 //cfilter.setInputFormat(train); 
    164                                 //Instances ctrain = Filter.useFilter(train, cfilter); 
    165                                  
    166                                 ctrain = new Instances(train); 
    167                                 ctraindata = new HashMap<>(); 
    168                                  
    169                                 // get traindata per cluster 
    170                                 for ( int j=0; j < ctrain.numInstances(); j++ ) { 
    171                                         // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 
    172                                         //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 
    173                                          
    174                                         int cnumber = clusterer.clusterInstance(ctrain.get(j)); 
    175                                         // add training data to list of instances for this cluster number 
    176                                         if ( !ctraindata.containsKey(cnumber) ) { 
    177                                                 ctraindata.put(cnumber, new Instances(traindata)); 
    178                                                 ctraindata.get(cnumber).delete(); 
    179                                         } 
    180                                         ctraindata.get(cnumber).add(traindata.get(j)); 
    181                                 } 
    182                                  
    183                                 for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) { 
    184                                         Instances instances = entry.getValue(); 
    185                                         int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 
    186                                         for( int count : counts ) { 
    187                                                 sufficientInstancesInEachCluster &= count>0; 
    188                                         } 
    189                                         sufficientInstancesInEachCluster &= instances.numInstances()>=5; 
    190                                 } 
    191                                 maxNumClusters = clusterer.numberOfClusters()-1; 
    192                         } while(!sufficientInstancesInEachCluster); 
    193                          
    194                         // train one classifier per cluster, we get the cluster number from the training data 
    195                         Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
    196                         while ( clusternumber.hasNext() ) { 
    197                                 int cnumber = clusternumber.next();                      
    198                                 cclassifier.put(cnumber,setupClassifier()); 
    199                                 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
    200                                  
    201                                 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
    202                         } 
    203                 } 
    204         } 
     52    private final TraindatasetCluster classifier = new TraindatasetCluster(); 
     53 
     54    @Override 
     55    public Classifier getClassifier() { 
     56        return classifier; 
     57    } 
     58 
     59    @Override 
     60    public void apply(Instances traindata) { 
     61        PrintStream errStr = System.err; 
     62        System.setErr(new PrintStream(new NullOutputStream())); 
     63        try { 
     64            classifier.buildClassifier(traindata); 
     65        } 
     66        catch (Exception e) { 
     67            throw new RuntimeException(e); 
     68        } 
     69        finally { 
     70            System.setErr(errStr); 
     71        } 
     72    } 
     73 
     74    public class TraindatasetCluster extends AbstractClassifier { 
     75 
     76        private static final long serialVersionUID = 1L; 
     77 
     78        private EM clusterer = null; 
     79 
     80        private HashMap<Integer, Classifier> cclassifier; 
     81        private HashMap<Integer, Instances> ctraindata; 
     82 
     83        /** 
     84         * Helper method that gives us a clean instance copy with the values of the instancelist of 
     85         * the first parameter. 
     86         *  
     87         * @param instancelist 
     88         *            with attributes 
     89         * @param instance 
     90         *            with only values 
     91         * @return copy of the instance 
     92         */ 
     93        private Instance createInstance(Instances instances, Instance instance) { 
     94            // attributes for feeding instance to classifier 
     95            Set<String> attributeNames = new HashSet<>(); 
     96            for (int j = 0; j < instances.numAttributes(); j++) { 
     97                attributeNames.add(instances.attribute(j).name()); 
     98            } 
     99 
     100            double[] values = new double[instances.numAttributes()]; 
     101            int index = 0; 
     102            for (int j = 0; j < instance.numAttributes(); j++) { 
     103                if (attributeNames.contains(instance.attribute(j).name())) { 
     104                    values[index] = instance.value(j); 
     105                    index++; 
     106                } 
     107            } 
     108 
     109            Instances tmp = new Instances(instances); 
     110            tmp.clear(); 
     111            Instance instCopy = new DenseInstance(instance.weight(), values); 
     112            instCopy.setDataset(tmp); 
     113 
     114            return instCopy; 
     115        } 
     116 
     117        @Override 
     118        public double classifyInstance(Instance instance) { 
     119            double ret = 0; 
     120            try { 
     121                // 1. copy the instance (keep the class attribute) 
     122                Instances traindata = ctraindata.get(0); 
     123                Instance classInstance = createInstance(traindata, instance); 
     124 
     125                // 2. remove class attribute before clustering 
     126                Remove filter = new Remove(); 
     127                filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
     128                filter.setInputFormat(traindata); 
     129                traindata = Filter.useFilter(traindata, filter); 
     130 
     131                // 3. copy the instance (without the class attribute) for clustering 
     132                Instance clusterInstance = createInstance(traindata, instance); 
     133 
     134                // 4. match instance without class attribute to a cluster number 
     135                int cnum = clusterer.clusterInstance(clusterInstance); 
     136 
     137                // 5. classify instance with class attribute to the classifier of that cluster 
     138                // number 
     139                ret = cclassifier.get(cnum).classifyInstance(classInstance); 
     140 
     141            } 
     142            catch (Exception e) { 
     143                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
     144                throw new RuntimeException(e); 
     145            } 
     146            return ret; 
     147        } 
     148 
     149        @Override 
     150        public void buildClassifier(Instances traindata) throws Exception { 
     151 
     152            // 1. copy training data 
     153            Instances train = new Instances(traindata); 
     154 
     155            // 2. remove class attribute for clustering 
     156            Remove filter = new Remove(); 
     157            filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     158            filter.setInputFormat(train); 
     159            train = Filter.useFilter(train, filter); 
     160 
     161            // new objects 
     162            cclassifier = new HashMap<Integer, Classifier>(); 
     163            ctraindata = new HashMap<Integer, Instances>(); 
     164 
     165            Instances ctrain; 
     166            int maxNumClusters = train.size(); 
     167            boolean sufficientInstancesInEachCluster; 
     168            do { // while(onlyTarget) 
     169                sufficientInstancesInEachCluster = true; 
     170                clusterer = new EM(); 
     171                clusterer.setMaximumNumberOfClusters(maxNumClusters); 
     172                clusterer.buildClusterer(train); 
     173 
     174                // 4. get cluster membership of our traindata 
     175                // AddCluster cfilter = new AddCluster(); 
     176                // cfilter.setClusterer(clusterer); 
     177                // cfilter.setInputFormat(train); 
     178                // Instances ctrain = Filter.useFilter(train, cfilter); 
     179 
     180                ctrain = new Instances(train); 
     181                ctraindata = new HashMap<>(); 
     182 
     183                // get traindata per cluster 
     184                for (int j = 0; j < ctrain.numInstances(); j++) { 
     185                    // get the cluster number from the attributes, subract 1 because if we 
     186                    // clusterInstance we get 0-n, and this is 1-n 
     187                    // cnumber = 
     188                    // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", 
     189                    // "")) - 1; 
     190 
     191                    int cnumber = clusterer.clusterInstance(ctrain.get(j)); 
     192                    // add training data to list of instances for this cluster number 
     193                    if (!ctraindata.containsKey(cnumber)) { 
     194                        ctraindata.put(cnumber, new Instances(traindata)); 
     195                        ctraindata.get(cnumber).delete(); 
     196                    } 
     197                    ctraindata.get(cnumber).add(traindata.get(j)); 
     198                } 
     199 
     200                for (Entry<Integer, Instances> entry : ctraindata.entrySet()) { 
     201                    Instances instances = entry.getValue(); 
     202                    int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 
     203                    for (int count : counts) { 
     204                        sufficientInstancesInEachCluster &= count > 0; 
     205                    } 
     206                    sufficientInstancesInEachCluster &= instances.numInstances() >= 5; 
     207                } 
     208                maxNumClusters = clusterer.numberOfClusters() - 1; 
     209            } 
     210            while (!sufficientInstancesInEachCluster); 
     211 
     212            // train one classifier per cluster, we get the cluster number from the training data 
     213            Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
     214            while (clusternumber.hasNext()) { 
     215                int cnumber = clusternumber.next(); 
     216                cclassifier.put(cnumber, setupClassifier()); 
     217                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
     218 
     219                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
     220            } 
     221        } 
     222    } 
    205223} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    2438 
    2539/** 
    26  * Trainer with reimplementation of WHERE clustering algorithm from: 
    27  * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman,  
    28  * Forrest Shull, Burak Turhan, Thomas Zimmermann,  
    29  * "Local versus Global Lessons for Defect Prediction and Effort Estimation,"  
    30  * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013   
     40 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher, 
     41 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann, 
     42 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on 
     43 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 
    3144 *  
    32  * With WekaLocalFQTraining we do the following: 
    33  * 1) Run the Fastmap algorithm on all training data, let it calculate the 2 most significant  
    34  *    dimensions and projections of each instance to these dimensions 
    35  * 2) With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and median(y) values. 
    36  * 3) We cluster the QuadTree nodes together if they have similar density (50%) 
    37  * 4) We save the clusters and their training data 
    38  * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster 
    39  * 6) We train a Weka classifier for each cluster with the clusters training data 
    40  * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance. 
    41  * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster. 
    42  * 8) We classify the Instance with the classifier and traindata from the Cluster we found in 7. 
     45 * With WekaLocalFQTraining we do the following: 1) Run the Fastmap algorithm on all training data, 
     46 * let it calculate the 2 most significant dimensions and projections of each instance to these 
     47 * dimensions 2) With these 2 dimensions we span a QuadTree which gets recursively split on 
     48 * median(x) and median(y) values. 3) We cluster the QuadTree nodes together if they have similar 
     49 * density (50%) 4) We save the clusters and their training data 5) We only use clusters with > 
     50 * ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this 
     51 * cluster 6) We train a Weka classifier for each cluster with the clusters training data 7) We 
     52 * recalculate Fastmap distances for a single instance with the old pivots and then try to find a 
     53 * cluster containing the coords of the instance. 7.1.) If we can not find a cluster (due to coords 
     54 * outside of all clusters) we find the nearest cluster. 8) We classify the Instance with the 
     55 * classifier and traindata from the Cluster we found in 7. 
    4356 */ 
    4457public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy { 
    45          
    46         private final TraindatasetCluster classifier = new TraindatasetCluster(); 
    47          
    48         @Override 
    49         public Classifier getClassifier() { 
    50                 return classifier; 
    51         } 
    52          
    53         @Override 
    54         public void apply(Instances traindata) { 
    55                 PrintStream errStr      = System.err; 
    56                 System.setErr(new PrintStream(new NullOutputStream())); 
    57                 try { 
    58                         classifier.buildClassifier(traindata); 
    59                 } catch (Exception e) { 
    60                         throw new RuntimeException(e); 
    61                 } finally { 
    62                         System.setErr(errStr); 
    63                 } 
    64         } 
    65          
    66          
    67         public class TraindatasetCluster extends AbstractClassifier { 
    68                  
    69                 private static final long serialVersionUID = 1L; 
    70                  
    71                 /* classifier per cluster */ 
    72                 private HashMap<Integer, Classifier> cclassifier; 
    73                  
    74                 /* instances per cluster */ 
    75                 private HashMap<Integer, Instances> ctraindata;  
    76                  
    77                 /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/ 
    78                 private HashMap<Integer, Instance> cpivots; 
    79                  
    80                 /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/ 
    81                 private int[][] cpivotindices; 
    82  
    83                 /* holds the sizes of the cluster multiple "boxes" per cluster */ 
    84                 private HashMap<Integer, ArrayList<Double[][]>> csize; 
    85                  
    86                 /* debug vars */ 
    87                 @SuppressWarnings("unused") 
    88                 private boolean show_biggest = true; 
    89                  
    90                 @SuppressWarnings("unused") 
    91                 private int CFOUND = 0; 
    92                 @SuppressWarnings("unused") 
    93                 private int CNOTFOUND = 0; 
    94                  
    95                  
    96                 private Instance createInstance(Instances instances, Instance instance) { 
    97                         // attributes for feeding instance to classifier 
    98                         Set<String> attributeNames = new HashSet<>(); 
    99                         for( int j=0; j<instances.numAttributes(); j++ ) { 
    100                                 attributeNames.add(instances.attribute(j).name()); 
    101                         } 
    102                          
    103                         double[] values = new double[instances.numAttributes()]; 
    104                         int index = 0; 
    105                         for( int j=0; j<instance.numAttributes(); j++ ) { 
    106                                 if( attributeNames.contains(instance.attribute(j).name())) { 
    107                                         values[index] = instance.value(j); 
    108                                         index++; 
    109                                 } 
    110                         } 
    111                          
    112                         Instances tmp = new Instances(instances); 
    113                         tmp.clear(); 
    114                         Instance instCopy = new DenseInstance(instance.weight(), values); 
    115                         instCopy.setDataset(tmp); 
    116                          
    117                         return instCopy; 
    118                 } 
    119                  
    120                 /** 
    121                  * Because Fastmap saves only the image not the values of the attributes it used 
    122                  * we can not use the old data directly to classify single instances to clusters. 
    123                  *  
    124                  * To classify a single instance we do a new fastmap computation with only the instance and 
    125                  * the old pivot elements. 
    126                  *  
    127                  * After that we find the cluster with our fastmap result for x and y. 
    128                  */ 
    129                 @Override 
    130                 public double classifyInstance(Instance instance) { 
    131                          
    132                         double ret = 0; 
    133                         try { 
    134                                 // classinstance gets passed to classifier 
    135                                 Instances traindata = ctraindata.get(0); 
    136                                 Instance classInstance = createInstance(traindata, instance); 
    137  
    138                                 // this one keeps the class attribute 
    139                                 Instances traindata2 = ctraindata.get(1);   
    140                                  
    141                                 // remove class attribute before clustering 
    142                                 Remove filter = new Remove(); 
    143                                 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
    144                                 filter.setInputFormat(traindata); 
    145                                 traindata = Filter.useFilter(traindata, filter); 
    146                                 Instance clusterInstance = createInstance(traindata, instance); 
    147                                  
    148                                 Fastmap FMAP = new Fastmap(2); 
    149                                 EuclideanDistance dist = new EuclideanDistance(traindata); 
    150                                  
    151                                 // we set our pivot indices [x=0,y=1][dimension] 
    152                                 int[][] npivotindices = new int[2][2]; 
    153                                 npivotindices[0][0] = 1; 
    154                                 npivotindices[1][0] = 2; 
    155                                 npivotindices[0][1] = 3; 
    156                                 npivotindices[1][1] = 4; 
    157                                  
    158                                 // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 
    159                                 // the instance we want to classify comes first after that the pivot elements in the order defined above 
    160                                 double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1]; 
    161                                 distmat[0][0] = 0; 
    162                                 distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    163                                 distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    164                                 distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    165                                 distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    166                                  
    167                                 distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance); 
    168                                 distmat[1][1] = 0; 
    169                                 distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    170                                 distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    171                                 distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    172                                  
    173                                 distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance); 
    174                                 distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    175                                 distmat[2][2] = 0; 
    176                                 distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    177                                 distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    178                                  
    179                                 distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance); 
    180                                 distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    181                                 distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    182                                 distmat[3][3] = 0; 
    183                                 distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1])); 
    184  
    185                                 distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance); 
    186                                 distmat[4][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0])); 
    187                                 distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0])); 
    188                                 distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1])); 
    189                                 distmat[4][4] = 0; 
    190                                  
    191                                  
    192                                 /* debug output: show biggest distance found within the new distance matrix 
    193                                 double biggest = 0; 
    194                                 for(int i=0; i < distmat.length; i++) { 
    195                                         for(int j=0; j < distmat[0].length; j++) { 
    196                                                 if(biggest < distmat[i][j]) { 
    197                                                         biggest = distmat[i][j]; 
    198                                                 } 
    199                                         } 
    200                                 } 
    201                                 if(this.show_biggest) { 
    202                                         Console.traceln(Level.INFO, String.format(""+clusterInstance)); 
    203                                         Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
    204                                         this.show_biggest = false; 
    205                                 } 
    206                                 */ 
    207  
    208                                 FMAP.setDistmat(distmat); 
    209                                 FMAP.setPivots(npivotindices); 
    210                                 FMAP.calculate(); 
    211                                 double[][] x = FMAP.getX(); 
    212                                 double[] proj = x[0]; 
    213  
    214                                 // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix 
    215                                 /* 
    216                                 Console.traceln(Level.INFO, "distmat:"); 
    217                             for(int i=0; i<distmat.length; i++){ 
    218                                 for(int j=0; j<distmat[0].length; j++){ 
    219                                         Console.trace(Level.INFO, String.format("%20s", distmat[i][j])); 
    220                                 } 
    221                                 Console.traceln(Level.INFO, ""); 
    222                             } 
    223                              
    224                             Console.traceln(Level.INFO, "vector:"); 
    225                             for(int i=0; i < proj.length; i++) { 
    226                                 Console.trace(Level.INFO, String.format("%20s", proj[i])); 
    227                             } 
    228                             Console.traceln(Level.INFO, ""); 
    229                              
    230                                 Console.traceln(Level.INFO, "resultmat:"); 
    231                             for(int i=0; i<x.length; i++){ 
    232                                 for(int j=0; j<x[0].length; j++){ 
    233                                         Console.trace(Level.INFO, String.format("%20s", x[i][j])); 
    234                                 } 
    235                                 Console.traceln(Level.INFO, ""); 
    236                             } 
    237                             */ 
    238                                  
    239                                 // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the  
    240                                 // cluster in which we are 
    241                                 int cnumber; 
    242                                 int found_cnumber = -1; 
    243                                 Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 
    244                                 while ( clusternumber.hasNext() && found_cnumber == -1) { 
    245                                         cnumber = clusternumber.next(); 
    246                                          
    247                                         // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed) 
    248                                         // or we are too far away from any cluster because of the fastmap calculation with the initial pivot objects 
    249                                         for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) {  
    250                                                 Double[][] current = this.csize.get(cnumber).get(box); 
    251                                                  
    252                                                 if(proj[0] >= current[0][0] && proj[0] <= current[0][1] &&  // x  
    253                                                    proj[1] >= current[1][0] && proj[1] <= current[1][1]) {  // y 
    254                                                         found_cnumber = cnumber; 
    255                                                 } 
    256                                         } 
    257                                 } 
    258                                  
    259                                 // we want to count how often we are really inside a cluster 
    260                                 //if ( found_cnumber == -1 ) { 
    261                                 //      CNOTFOUND += 1; 
    262                                 //}else { 
    263                                 //      CFOUND += 1; 
    264                                 //} 
    265  
    266                                 // now it can happen that we do not find a cluster because we deleted it previously (too few instances) 
    267                                 // or we get bigger distance measures from weka so that we are completely outside of our clusters. 
    268                                 // in these cases we just find the nearest cluster to our instance and use it for classification. 
    269                                 // to do that we use the EuclideanDistance again to compare our distance to all other Instances 
    270                                 // then we take the cluster of the closest weka instance 
    271                                 dist = new EuclideanDistance(traindata2); 
    272                                 if( !this.ctraindata.containsKey(found_cnumber) ) {  
    273                                         double min_distance = Double.MAX_VALUE; 
    274                                         clusternumber = ctraindata.keySet().iterator(); 
    275                                         while ( clusternumber.hasNext() ) { 
    276                                                 cnumber = clusternumber.next(); 
    277                                                 for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 
    278                                                         if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) { 
    279                                                                 found_cnumber = cnumber; 
    280                                                                 min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i)); 
    281                                                         } 
    282                                                 } 
    283                                         } 
    284                                 } 
    285                                  
    286                                 // here we have the cluster where an instance has the minimum distance between itself and the 
    287                                 // instance we want to classify 
    288                                 // if we still have not found a cluster we exit because something is really wrong 
    289                                 if( found_cnumber == -1 ) { 
    290                                         Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!")); 
    291                                         throw new RuntimeException("cluster not found with full search"); 
    292                                 } 
    293                                  
    294                                 // classify the passed instance with the cluster we found and its training data 
    295                                 ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 
    296                                  
    297                         }catch( Exception e ) { 
    298                                 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
    299                                 throw new RuntimeException(e); 
    300                         } 
    301                         return ret; 
    302                 } 
    303                  
    304                 @Override 
    305                 public void buildClassifier(Instances traindata) throws Exception { 
    306                          
    307                         //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND)); 
    308                         this.show_biggest = true; 
    309                          
    310                         cclassifier = new HashMap<Integer, Classifier>(); 
    311                         ctraindata = new HashMap<Integer, Instances>(); 
    312                         cpivots = new HashMap<Integer, Instance>(); 
    313                         cpivotindices = new int[2][2]; 
    314                          
    315                         // 1. copy traindata 
    316                         Instances train = new Instances(traindata); 
    317                         Instances train2 = new Instances(traindata);  // this one keeps the class attribute 
    318                          
    319                         // 2. remove class attribute for clustering 
    320                         Remove filter = new Remove(); 
    321                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    322                         filter.setInputFormat(train); 
    323                         train = Filter.useFilter(train, filter); 
    324                          
    325                         // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 
    326                         double biggest = 0; 
    327                         EuclideanDistance dist = new EuclideanDistance(train); 
    328                         double[][] distmat = new double[train.size()][train.size()]; 
    329                         for( int i=0; i < train.size(); i++ ) { 
    330                                 for( int j=0; j < train.size(); j++ ) { 
    331                                         distmat[i][j] = dist.distance(train.get(i), train.get(j)); 
    332                                         if( distmat[i][j] > biggest ) { 
    333                                                 biggest = distmat[i][j]; 
    334                                         } 
    335                                 } 
    336                         } 
    337                         //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
    338                          
    339                         // 4. run fastmap for 2 dimensions on the distance matrix 
    340                         Fastmap FMAP = new Fastmap(2); 
    341                         FMAP.setDistmat(distmat); 
    342                         FMAP.calculate(); 
    343                          
    344                         cpivotindices = FMAP.getPivots(); 
    345                          
    346                         double[][] X = FMAP.getX(); 
    347                         distmat = new double[0][0]; 
    348                         System.gc(); 
    349                          
    350                         // quadtree payload generation 
    351                         ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 
    352                      
    353                         // we need these for the sizes of the quadrants 
    354                         double[] big = {0,0}; 
    355                         double[] small = {Double.MAX_VALUE,Double.MAX_VALUE}; 
    356                          
    357                         // set quadtree payload values and get max and min x and y values for size 
    358                     for( int i=0; i<X.length; i++ ){ 
    359                         if(X[i][0] >= big[0]) { 
    360                                 big[0] = X[i][0]; 
    361                         } 
    362                         if(X[i][1] >= big[1]) { 
    363                                 big[1] = X[i][1]; 
    364                         } 
    365                         if(X[i][0] <= small[0]) { 
    366                                 small[0] = X[i][0]; 
    367                         } 
    368                         if(X[i][1] <= small[1]) { 
    369                                 small[1] = X[i][1]; 
    370                         } 
    371                         QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 
    372                         qtp.add(tmp); 
    373                     } 
    374                      
    375                     //Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 
    376                      
    377                     // 5. generate quadtree 
    378                     QuadTree TREE = new QuadTree(null, qtp); 
    379                     QuadTree.size = train.size(); 
    380                     QuadTree.alpha = Math.sqrt(train.size()); 
    381                     QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
    382                     QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 
    383                      
    384                     //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ "")); 
    385                      
    386                     // set the size and then split the tree recursively at the median value for x, y 
    387                     TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 
    388                      
    389                     // recursive split und grid clustering eher static 
    390                     TREE.recursiveSplit(TREE); 
    391                      
    392                     // generate list of nodes sorted by density (childs only) 
    393                     ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 
    394                      
    395                     // recursive grid clustering (tree pruning), the values are stored in ccluster 
    396                     TREE.gridClustering(l); 
    397                      
    398                     // wir iterieren durch die cluster und sammeln uns die instanzen daraus 
    399                     //ctraindata.clear(); 
    400                     for( int i=0; i < QuadTree.ccluster.size(); i++ ) { 
    401                         ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 
    402                          
    403                         // i is the clusternumber 
    404                         // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
    405                         //if(current.size() > QuadTree.alpha) { 
    406                         if( current.size() > 4 ) { 
    407                                 for( int j=0; j < current.size(); j++ ) { 
    408                                         if( !ctraindata.containsKey(i) ) { 
    409                                                 ctraindata.put(i, new Instances(train2)); 
    410                                                 ctraindata.get(i).delete(); 
    411                                         } 
    412                                         ctraindata.get(i).add(current.get(j).getInst()); 
    413                                 } 
    414                         }else{ 
    415                                 Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances")); 
    416                         } 
    417                     } 
    418                          
    419                         // here we keep things we need later on 
    420                         // QuadTree sizes for later use (matching new instances) 
    421                         this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 
    422                  
    423                         // pivot elements 
    424                         //this.cpivots.clear(); 
    425                         for( int i=0; i < FMAP.PA[0].length; i++ ) { 
    426                                 this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy()); 
    427                         } 
    428                         for( int j=0; j < FMAP.PA[0].length; j++ ) { 
    429                                 this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy()); 
    430                         } 
    431                          
    432                          
    433                         /* debug output 
    434                         int pnumber; 
    435                         Iterator<Integer> pivotnumber = cpivots.keySet().iterator(); 
    436                         while ( pivotnumber.hasNext() ) { 
    437                                 pnumber = pivotnumber.next(); 
    438                                 Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber))); 
    439                         } 
    440                         */ 
    441                          
    442                     // train one classifier per cluster, we get the cluster number from the traindata 
    443                     int cnumber; 
    444                         Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
    445                         //cclassifier.clear(); 
    446                          
    447                         //int traindata_count = 0; 
    448                         while ( clusternumber.hasNext() ) { 
    449                                 cnumber = clusternumber.next(); 
    450                                 cclassifier.put(cnumber,setupClassifier());  // this is the classifier used for the cluster  
    451                                 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
    452                                 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
    453                                 //traindata_count += ctraindata.get(cnumber).size(); 
    454                                 //Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber +"  with "+ ctraindata.get(cnumber).size() +" traindata instances")); 
    455                         } 
    456                          
    457                         // add all traindata 
    458                         //Console.traceln(Level.INFO, String.format("traindata in all clusters: " + traindata_count)); 
    459                 } 
    460         } 
    461          
    462  
    463         /** 
    464          * Payload for the QuadTree. 
    465          * x and y are the calculated Fastmap values. 
    466          * T is a weka instance. 
    467          */ 
    468         public class QuadTreePayload<T> { 
    469  
    470                 public double x; 
    471                 public double y; 
    472                 private T inst; 
    473                  
    474                 public QuadTreePayload(double x, double y, T value) { 
    475                         this.x = x; 
    476                         this.y = y; 
    477                         this.inst = value; 
    478                 } 
    479                  
    480                 public T getInst() { 
    481                         return this.inst; 
    482                 } 
    483         } 
    484          
    485          
    486         /** 
    487          * Fastmap implementation 
    488          *  
    489          * Faloutsos, C., & Lin, K. I. (1995).  
    490          * FastMap: A fast algorithm for indexing, data-mining and visualization of traditional and multimedia datasets  
    491          * (Vol. 24, No. 2, pp. 163-174). ACM. 
    492          */ 
    493         public class Fastmap { 
    494                  
    495                 /*N x k Array, at the end, the i-th row will be the image of the i-th object*/ 
    496                 private double[][] X; 
    497                  
    498                 /*2 x k pivot Array one pair per recursive call*/ 
    499                 private int[][] PA; 
    500                  
    501                 /*Objects we got (distance matrix)*/ 
    502                 private double[][] O; 
    503                  
    504                 /*column of X currently updated (also the dimension)*/ 
    505                 private int col = 0; 
    506                  
    507                 /*number of dimensions we want*/ 
    508                 private int target_dims = 0; 
    509                  
    510                 // if we already have the pivot elements 
    511                 private boolean pivot_set = false; 
    512                  
    513  
    514                 public Fastmap(int k) { 
    515                         this.target_dims = k; 
    516                 } 
    517                  
    518                 /** 
    519                  * Sets the distance matrix 
    520                  * and params that depend on this 
    521                  * @param O 
    522                  */ 
    523                 public void setDistmat(double[][] O) { 
    524                         this.O = O; 
    525                         int N = O.length; 
    526                         this.X = new double[N][this.target_dims]; 
    527                         this.PA = new int[2][this.target_dims]; 
    528                 } 
    529                  
    530                 /** 
    531                  * Set pivot elements, we need that to classify instances 
    532                  * after the calculation is complete (because we then want to reuse 
    533                  * only the pivot elements). 
    534                  *  
    535                  * @param pi 
    536                  */ 
    537                 public void setPivots(int[][] pi) { 
    538                         this.pivot_set = true; 
    539                         this.PA = pi; 
    540                 } 
    541                  
    542                 /** 
    543                  * Return the pivot elements that were chosen during the calculation 
    544                  *  
    545                  * @return 
    546                  */ 
    547                 public int[][] getPivots() { 
    548                         return this.PA; 
    549                 } 
    550                  
    551                 /** 
    552                  * The distance function for euclidean distance 
    553                  *  
    554                  * Acts according to equation 4 of the fastmap paper 
    555                  *   
    556                  * @param x x index of x image (if k==0 x object) 
    557                  * @param y y index of y image (if k==0 y object) 
    558                  * @param kdimensionality 
    559                  * @return distance 
    560                  */ 
    561                 private double dist(int x, int y, int k) { 
    562                          
    563                         // basis is object distance, we get this from our distance matrix 
    564                         double tmp = this.O[x][y] * this.O[x][y];  
    565                          
    566                         // decrease by projections 
    567                         for( int i=0; i < k; i++ ) { 
    568                                 double tmp2 = (this.X[x][i] - this.X[y][i]); 
    569                                 tmp -= tmp2 * tmp2; 
    570                         } 
    571                          
    572                         return Math.abs(tmp); 
    573                 } 
    574  
    575                 /** 
    576                  * Find the object farthest from the given index 
    577                  * This method is a helper Method for findDistandObjects 
    578                  *  
    579                  * @param index of the object  
    580                  * @return index of the farthest object from the given index 
    581                  */ 
    582                 private int findFarthest(int index) { 
    583                         double furthest = Double.MIN_VALUE; 
    584                         int ret = 0; 
    585                          
    586                         for( int i=0; i < O.length; i++ ) { 
    587                                 double dist = this.dist(i, index, this.col); 
    588                                 if( i != index && dist > furthest ) { 
    589                                         furthest = dist; 
    590                                         ret = i; 
    591                                 } 
    592                         } 
    593                         return ret; 
    594                 } 
    595                  
    596                 /** 
    597                  * Finds the pivot objects  
    598                  *  
    599                  * This method is basically algorithm 1 of the fastmap paper. 
    600                  *  
    601                  * @return 2 indexes of the choosen pivot objects 
    602                  */ 
    603                 private int[] findDistantObjects() { 
    604                         // 1. choose object randomly 
    605                         Random r = new Random(); 
    606                         int obj = r.nextInt(this.O.length); 
    607                          
    608                         // 2. find farthest object from randomly chosen object 
    609                         int idx1 = this.findFarthest(obj); 
    610                          
    611                         // 3. find farthest object from previously farthest object 
    612                         int idx2 = this.findFarthest(idx1); 
    613  
    614                         return new int[] {idx1, idx2}; 
    615                 } 
    616          
    617                 /** 
    618                  * Calculates the new k-vector values (projections) 
    619                  *  
    620                  * This is basically algorithm 2 of the fastmap paper. 
    621                  * We just added the possibility to pre-set the pivot elements because 
    622                  * we need to classify single instances after the computation is already done. 
    623                  *  
    624                  * @param dims dimensionality 
    625                  */ 
    626                 public void calculate() { 
    627                          
    628                         for( int k=0; k < this.target_dims; k++ ) { 
    629                                 // 2) choose pivot objects 
    630                                 if ( !this.pivot_set ) { 
    631                                         int[] pivots = this.findDistantObjects(); 
    632                  
    633                                         // 3) record ids of pivot objects  
    634                                         this.PA[0][this.col] = pivots[0]; 
    635                                         this.PA[1][this.col] = pivots[1]; 
    636                                 } 
    637                                  
    638                                 // 4) inter object distances are zero (this.X is initialized with 0 so we just continue) 
    639                                 if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) { 
    640                                         continue; 
    641                                 } 
    642                                  
    643                                 // 5) project the objects on the line between the pivots 
    644                                 double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 
    645                                 for( int i=0; i < this.O.length; i++ ) { 
    646                                          
    647                                         double dix = this.dist(i, this.PA[0][this.col], this.col); 
    648                                         double diy = this.dist(i, this.PA[1][this.col], this.col); 
    649                                          
    650                                         double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 
    651                                          
    652                                         // save the projection 
    653                                         this.X[i][this.col] = tmp; 
    654                                 } 
    655                                  
    656                                 this.col += 1; 
    657                         } 
    658                 } 
    659                  
    660                 /** 
    661                  * returns the result matrix of the projections 
    662                  *  
    663                  * @return calculated result 
    664                  */ 
    665                 public double[][] getX() { 
    666                         return this.X; 
    667                 } 
    668         } 
     58 
     59    private final TraindatasetCluster classifier = new TraindatasetCluster(); 
     60 
     61    @Override 
     62    public Classifier getClassifier() { 
     63        return classifier; 
     64    } 
     65 
     66    @Override 
     67    public void apply(Instances traindata) { 
     68        PrintStream errStr = System.err; 
     69        System.setErr(new PrintStream(new NullOutputStream())); 
     70        try { 
     71            classifier.buildClassifier(traindata); 
     72        } 
     73        catch (Exception e) { 
     74            throw new RuntimeException(e); 
     75        } 
     76        finally { 
     77            System.setErr(errStr); 
     78        } 
     79    } 
     80 
     81    public class TraindatasetCluster extends AbstractClassifier { 
     82 
     83        private static final long serialVersionUID = 1L; 
     84 
     85        /* classifier per cluster */ 
     86        private HashMap<Integer, Classifier> cclassifier; 
     87 
     88        /* instances per cluster */ 
     89        private HashMap<Integer, Instances> ctraindata; 
     90 
     91        /* 
     92         * holds the instances and indices of the pivot objects of the Fastmap calculation in 
     93         * buildClassifier 
     94         */ 
     95        private HashMap<Integer, Instance> cpivots; 
     96 
     97        /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */ 
     98        private int[][] cpivotindices; 
     99 
     100        /* holds the sizes of the cluster multiple "boxes" per cluster */ 
     101        private HashMap<Integer, ArrayList<Double[][]>> csize; 
     102 
     103        /* debug vars */ 
     104        @SuppressWarnings("unused") 
     105        private boolean show_biggest = true; 
     106 
     107        @SuppressWarnings("unused") 
     108        private int CFOUND = 0; 
     109        @SuppressWarnings("unused") 
     110        private int CNOTFOUND = 0; 
     111 
     112        private Instance createInstance(Instances instances, Instance instance) { 
     113            // attributes for feeding instance to classifier 
     114            Set<String> attributeNames = new HashSet<>(); 
     115            for (int j = 0; j < instances.numAttributes(); j++) { 
     116                attributeNames.add(instances.attribute(j).name()); 
     117            } 
     118 
     119            double[] values = new double[instances.numAttributes()]; 
     120            int index = 0; 
     121            for (int j = 0; j < instance.numAttributes(); j++) { 
     122                if (attributeNames.contains(instance.attribute(j).name())) { 
     123                    values[index] = instance.value(j); 
     124                    index++; 
     125                } 
     126            } 
     127 
     128            Instances tmp = new Instances(instances); 
     129            tmp.clear(); 
     130            Instance instCopy = new DenseInstance(instance.weight(), values); 
     131            instCopy.setDataset(tmp); 
     132 
     133            return instCopy; 
     134        } 
     135 
     136        /** 
     137         * Because Fastmap saves only the image not the values of the attributes it used we can not 
     138         * use the old data directly to classify single instances to clusters. 
     139         *  
     140         * To classify a single instance we do a new fastmap computation with only the instance and 
     141         * the old pivot elements. 
     142         *  
     143         * After that we find the cluster with our fastmap result for x and y. 
     144         */ 
     145        @Override 
     146        public double classifyInstance(Instance instance) { 
     147 
     148            double ret = 0; 
     149            try { 
     150                // classinstance gets passed to classifier 
     151                Instances traindata = ctraindata.get(0); 
     152                Instance classInstance = createInstance(traindata, instance); 
     153 
     154                // this one keeps the class attribute 
     155                Instances traindata2 = ctraindata.get(1); 
     156 
     157                // remove class attribute before clustering 
     158                Remove filter = new Remove(); 
     159                filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
     160                filter.setInputFormat(traindata); 
     161                traindata = Filter.useFilter(traindata, filter); 
     162                Instance clusterInstance = createInstance(traindata, instance); 
     163 
     164                Fastmap FMAP = new Fastmap(2); 
     165                EuclideanDistance dist = new EuclideanDistance(traindata); 
     166 
     167                // we set our pivot indices [x=0,y=1][dimension] 
     168                int[][] npivotindices = new int[2][2]; 
     169                npivotindices[0][0] = 1; 
     170                npivotindices[1][0] = 2; 
     171                npivotindices[0][1] = 3; 
     172                npivotindices[1][1] = 4; 
     173 
     174                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify) 
     175                // the instance we want to classify comes first after that the pivot elements in the 
     176                // order defined above 
     177                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1]; 
     178                distmat[0][0] = 0; 
     179                distmat[0][1] = 
     180                    dist.distance(clusterInstance, 
     181                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     182                distmat[0][2] = 
     183                    dist.distance(clusterInstance, 
     184                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     185                distmat[0][3] = 
     186                    dist.distance(clusterInstance, 
     187                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     188                distmat[0][4] = 
     189                    dist.distance(clusterInstance, 
     190                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     191 
     192                distmat[1][0] = 
     193                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     194                                  clusterInstance); 
     195                distmat[1][1] = 0; 
     196                distmat[1][2] = 
     197                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     198                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     199                distmat[1][3] = 
     200                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     201                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     202                distmat[1][4] = 
     203                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]), 
     204                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     205 
     206                distmat[2][0] = 
     207                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     208                                  clusterInstance); 
     209                distmat[2][1] = 
     210                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     211                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     212                distmat[2][2] = 0; 
     213                distmat[2][3] = 
     214                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     215                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     216                distmat[2][4] = 
     217                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]), 
     218                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     219 
     220                distmat[3][0] = 
     221                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     222                                  clusterInstance); 
     223                distmat[3][1] = 
     224                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     225                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     226                distmat[3][2] = 
     227                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     228                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     229                distmat[3][3] = 0; 
     230                distmat[3][4] = 
     231                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]), 
     232                                  this.cpivots.get((Integer) this.cpivotindices[1][1])); 
     233 
     234                distmat[4][0] = 
     235                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     236                                  clusterInstance); 
     237                distmat[4][1] = 
     238                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     239                                  this.cpivots.get((Integer) this.cpivotindices[0][0])); 
     240                distmat[4][2] = 
     241                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     242                                  this.cpivots.get((Integer) this.cpivotindices[1][0])); 
     243                distmat[4][3] = 
     244                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]), 
     245                                  this.cpivots.get((Integer) this.cpivotindices[0][1])); 
     246                distmat[4][4] = 0; 
     247 
     248                /* 
     249                 * debug output: show biggest distance found within the new distance matrix double 
     250                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j < 
     251                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j]; 
     252                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO, 
     253                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO, 
     254                 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; } 
     255                 */ 
     256 
     257                FMAP.setDistmat(distmat); 
     258                FMAP.setPivots(npivotindices); 
     259                FMAP.calculate(); 
     260                double[][] x = FMAP.getX(); 
     261                double[] proj = x[0]; 
     262 
     263                // debug output: show the calculated distance matrix, our result vektor for the 
     264                // instance and the complete result matrix 
     265                /* 
     266                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){ 
     267                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO, 
     268                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); } 
     269                 *  
     270                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) { 
     271                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); } 
     272                 * Console.traceln(Level.INFO, ""); 
     273                 *  
     274                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int 
     275                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s", 
     276                 * x[i][j])); } Console.traceln(Level.INFO, ""); } 
     277                 */ 
     278 
     279                // now we iterate over all clusters (well, boxes of sizes per cluster really) and 
     280                // save the number of the 
     281                // cluster in which we are 
     282                int cnumber; 
     283                int found_cnumber = -1; 
     284                Iterator<Integer> clusternumber = this.csize.keySet().iterator(); 
     285                while (clusternumber.hasNext() && found_cnumber == -1) { 
     286                    cnumber = clusternumber.next(); 
     287 
     288                    // now iterate over the boxes of the cluster and hope we find one (cluster could 
     289                    // have been removed) 
     290                    // or we are too far away from any cluster because of the fastmap calculation 
     291                    // with the initial pivot objects 
     292                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) { 
     293                        Double[][] current = this.csize.get(cnumber).get(box); 
     294 
     295                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x 
     296                            proj[1] >= current[1][0] && proj[1] <= current[1][1]) 
     297                        { // y 
     298                            found_cnumber = cnumber; 
     299                        } 
     300                    } 
     301                } 
     302 
     303                // we want to count how often we are really inside a cluster 
     304                // if ( found_cnumber == -1 ) { 
     305                // CNOTFOUND += 1; 
     306                // }else { 
     307                // CFOUND += 1; 
     308                // } 
     309 
     310                // now it can happen that we do not find a cluster because we deleted it previously 
     311                // (too few instances) 
     312                // or we get bigger distance measures from weka so that we are completely outside of 
     313                // our clusters. 
     314                // in these cases we just find the nearest cluster to our instance and use it for 
     315                // classification. 
     316                // to do that we use the EuclideanDistance again to compare our distance to all 
     317                // other Instances 
     318                // then we take the cluster of the closest weka instance 
     319                dist = new EuclideanDistance(traindata2); 
     320                if (!this.ctraindata.containsKey(found_cnumber)) { 
     321                    double min_distance = Double.MAX_VALUE; 
     322                    clusternumber = ctraindata.keySet().iterator(); 
     323                    while (clusternumber.hasNext()) { 
     324                        cnumber = clusternumber.next(); 
     325                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) { 
     326                            if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) 
     327                            { 
     328                                found_cnumber = cnumber; 
     329                                min_distance = 
     330                                    dist.distance(instance, ctraindata.get(cnumber).get(i)); 
     331                            } 
     332                        } 
     333                    } 
     334                } 
     335 
     336                // here we have the cluster where an instance has the minimum distance between 
     337                // itself and the 
     338                // instance we want to classify 
     339                // if we still have not found a cluster we exit because something is really wrong 
     340                if (found_cnumber == -1) { 
     341                    Console.traceln(Level.INFO, String 
     342                        .format("ERROR matching instance to cluster with full search!")); 
     343                    throw new RuntimeException("cluster not found with full search"); 
     344                } 
     345 
     346                // classify the passed instance with the cluster we found and its training data 
     347                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance); 
     348 
     349            } 
     350            catch (Exception e) { 
     351                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
     352                throw new RuntimeException(e); 
     353            } 
     354            return ret; 
     355        } 
     356 
     357        @Override 
     358        public void buildClassifier(Instances traindata) throws Exception { 
     359 
     360            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + 
     361            // CNOTFOUND)); 
     362            this.show_biggest = true; 
     363 
     364            cclassifier = new HashMap<Integer, Classifier>(); 
     365            ctraindata = new HashMap<Integer, Instances>(); 
     366            cpivots = new HashMap<Integer, Instance>(); 
     367            cpivotindices = new int[2][2]; 
     368 
     369            // 1. copy traindata 
     370            Instances train = new Instances(traindata); 
     371            Instances train2 = new Instances(traindata); // this one keeps the class attribute 
     372 
     373            // 2. remove class attribute for clustering 
     374            Remove filter = new Remove(); 
     375            filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     376            filter.setInputFormat(train); 
     377            train = Filter.useFilter(train, filter); 
     378 
     379            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 
     380            double biggest = 0; 
     381            EuclideanDistance dist = new EuclideanDistance(train); 
     382            double[][] distmat = new double[train.size()][train.size()]; 
     383            for (int i = 0; i < train.size(); i++) { 
     384                for (int j = 0; j < train.size(); j++) { 
     385                    distmat[i][j] = dist.distance(train.get(i), train.get(j)); 
     386                    if (distmat[i][j] > biggest) { 
     387                        biggest = distmat[i][j]; 
     388                    } 
     389                } 
     390            } 
     391            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest)); 
     392 
     393            // 4. run fastmap for 2 dimensions on the distance matrix 
     394            Fastmap FMAP = new Fastmap(2); 
     395            FMAP.setDistmat(distmat); 
     396            FMAP.calculate(); 
     397 
     398            cpivotindices = FMAP.getPivots(); 
     399 
     400            double[][] X = FMAP.getX(); 
     401            distmat = new double[0][0]; 
     402            System.gc(); 
     403 
     404            // quadtree payload generation 
     405            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>(); 
     406 
     407            // we need these for the sizes of the quadrants 
     408            double[] big = 
     409                { 0, 0 }; 
     410            double[] small = 
     411                { Double.MAX_VALUE, Double.MAX_VALUE }; 
     412 
     413            // set quadtree payload values and get max and min x and y values for size 
     414            for (int i = 0; i < X.length; i++) { 
     415                if (X[i][0] >= big[0]) { 
     416                    big[0] = X[i][0]; 
     417                } 
     418                if (X[i][1] >= big[1]) { 
     419                    big[1] = X[i][1]; 
     420                } 
     421                if (X[i][0] <= small[0]) { 
     422                    small[0] = X[i][0]; 
     423                } 
     424                if (X[i][1] <= small[1]) { 
     425                    small[1] = X[i][1]; 
     426                } 
     427                QuadTreePayload<Instance> tmp = 
     428                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i)); 
     429                qtp.add(tmp); 
     430            } 
     431 
     432            // Console.traceln(Level.INFO, 
     433            // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")")); 
     434 
     435            // 5. generate quadtree 
     436            QuadTree TREE = new QuadTree(null, qtp); 
     437            QuadTree.size = train.size(); 
     438            QuadTree.alpha = Math.sqrt(train.size()); 
     439            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>(); 
     440            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>(); 
     441 
     442            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + 
     443            // " size, Alpha: "+ QuadTree.alpha+ "")); 
     444 
     445            // set the size and then split the tree recursively at the median value for x, y 
     446            TREE.setSize(new double[] 
     447                { small[0], big[0] }, new double[] 
     448                { small[1], big[1] }); 
     449 
     450            // recursive split und grid clustering eher static 
     451            TREE.recursiveSplit(TREE); 
     452 
     453            // generate list of nodes sorted by density (childs only) 
     454            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE)); 
     455 
     456            // recursive grid clustering (tree pruning), the values are stored in ccluster 
     457            TREE.gridClustering(l); 
     458 
     459            // wir iterieren durch die cluster und sammeln uns die instanzen daraus 
     460            // ctraindata.clear(); 
     461            for (int i = 0; i < QuadTree.ccluster.size(); i++) { 
     462                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i); 
     463 
     464                // i is the clusternumber 
     465                // we only allow clusters with Instances > ALPHA, other clusters are not considered! 
     466                // if(current.size() > QuadTree.alpha) { 
     467                if (current.size() > 4) { 
     468                    for (int j = 0; j < current.size(); j++) { 
     469                        if (!ctraindata.containsKey(i)) { 
     470                            ctraindata.put(i, new Instances(train2)); 
     471                            ctraindata.get(i).delete(); 
     472                        } 
     473                        ctraindata.get(i).add(current.get(j).getInst()); 
     474                    } 
     475                } 
     476                else { 
     477                    Console.traceln(Level.INFO, 
     478                                    String.format("drop cluster, only: " + current.size() + 
     479                                        " instances")); 
     480                } 
     481            } 
     482 
     483            // here we keep things we need later on 
     484            // QuadTree sizes for later use (matching new instances) 
     485            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize); 
     486 
     487            // pivot elements 
     488            // this.cpivots.clear(); 
     489            for (int i = 0; i < FMAP.PA[0].length; i++) { 
     490                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy()); 
     491            } 
     492            for (int j = 0; j < FMAP.PA[0].length; j++) { 
     493                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy()); 
     494            } 
     495 
     496            /* 
     497             * debug output int pnumber; Iterator<Integer> pivotnumber = 
     498             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber = 
     499             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ 
     500             * " inst: "+cpivots.get(pnumber))); } 
     501             */ 
     502 
     503            // train one classifier per cluster, we get the cluster number from the traindata 
     504            int cnumber; 
     505            Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
     506            // cclassifier.clear(); 
     507 
     508            // int traindata_count = 0; 
     509            while (clusternumber.hasNext()) { 
     510                cnumber = clusternumber.next(); 
     511                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the 
     512                                                             // cluster 
     513                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
     514                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
     515                // traindata_count += ctraindata.get(cnumber).size(); 
     516                // Console.traceln(Level.INFO, 
     517                // String.format("building classifier in cluster "+cnumber +"  with "+ 
     518                // ctraindata.get(cnumber).size() +" traindata instances")); 
     519            } 
     520 
     521            // add all traindata 
     522            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " + 
     523            // traindata_count)); 
     524        } 
     525    } 
     526 
     527    /** 
     528     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance. 
     529     */ 
     530    public class QuadTreePayload<T> { 
     531 
     532        public double x; 
     533        public double y; 
     534        private T inst; 
     535 
     536        public QuadTreePayload(double x, double y, T value) { 
     537            this.x = x; 
     538            this.y = y; 
     539            this.inst = value; 
     540        } 
     541 
     542        public T getInst() { 
     543            return this.inst; 
     544        } 
     545    } 
     546 
     547    /** 
     548     * Fastmap implementation 
     549     *  
     550     * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and 
     551     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM. 
     552     */ 
     553    public class Fastmap { 
     554 
     555        /* N x k Array, at the end, the i-th row will be the image of the i-th object */ 
     556        private double[][] X; 
     557 
     558        /* 2 x k pivot Array one pair per recursive call */ 
     559        private int[][] PA; 
     560 
     561        /* Objects we got (distance matrix) */ 
     562        private double[][] O; 
     563 
     564        /* column of X currently updated (also the dimension) */ 
     565        private int col = 0; 
     566 
     567        /* number of dimensions we want */ 
     568        private int target_dims = 0; 
     569 
     570        // if we already have the pivot elements 
     571        private boolean pivot_set = false; 
     572 
     573        public Fastmap(int k) { 
     574            this.target_dims = k; 
     575        } 
     576 
     577        /** 
     578         * Sets the distance matrix and params that depend on this 
     579         *  
     580         * @param O 
     581         */ 
     582        public void setDistmat(double[][] O) { 
     583            this.O = O; 
     584            int N = O.length; 
     585            this.X = new double[N][this.target_dims]; 
     586            this.PA = new int[2][this.target_dims]; 
     587        } 
     588 
     589        /** 
     590         * Set pivot elements, we need that to classify instances after the calculation is complete 
     591         * (because we then want to reuse only the pivot elements). 
     592         *  
     593         * @param pi 
     594         */ 
     595        public void setPivots(int[][] pi) { 
     596            this.pivot_set = true; 
     597            this.PA = pi; 
     598        } 
     599 
     600        /** 
     601         * Return the pivot elements that were chosen during the calculation 
     602         *  
     603         * @return 
     604         */ 
     605        public int[][] getPivots() { 
     606            return this.PA; 
     607        } 
     608 
     609        /** 
     610         * The distance function for euclidean distance 
     611         *  
     612         * Acts according to equation 4 of the fastmap paper 
     613         *  
     614         * @param x 
     615         *            x index of x image (if k==0 x object) 
     616         * @param y 
     617         *            y index of y image (if k==0 y object) 
     618         * @param kdimensionality 
     619         * @return distance 
     620         */ 
     621        private double dist(int x, int y, int k) { 
     622 
     623            // basis is object distance, we get this from our distance matrix 
     624            double tmp = this.O[x][y] * this.O[x][y]; 
     625 
     626            // decrease by projections 
     627            for (int i = 0; i < k; i++) { 
     628                double tmp2 = (this.X[x][i] - this.X[y][i]); 
     629                tmp -= tmp2 * tmp2; 
     630            } 
     631 
     632            return Math.abs(tmp); 
     633        } 
     634 
     635        /** 
     636         * Find the object farthest from the given index This method is a helper Method for 
     637         * findDistandObjects 
     638         *  
     639         * @param index 
     640         *            of the object 
     641         * @return index of the farthest object from the given index 
     642         */ 
     643        private int findFarthest(int index) { 
     644            double furthest = Double.MIN_VALUE; 
     645            int ret = 0; 
     646 
     647            for (int i = 0; i < O.length; i++) { 
     648                double dist = this.dist(i, index, this.col); 
     649                if (i != index && dist > furthest) { 
     650                    furthest = dist; 
     651                    ret = i; 
     652                } 
     653            } 
     654            return ret; 
     655        } 
     656 
     657        /** 
     658         * Finds the pivot objects 
     659         *  
     660         * This method is basically algorithm 1 of the fastmap paper. 
     661         *  
     662         * @return 2 indexes of the choosen pivot objects 
     663         */ 
     664        private int[] findDistantObjects() { 
     665            // 1. choose object randomly 
     666            Random r = new Random(); 
     667            int obj = r.nextInt(this.O.length); 
     668 
     669            // 2. find farthest object from randomly chosen object 
     670            int idx1 = this.findFarthest(obj); 
     671 
     672            // 3. find farthest object from previously farthest object 
     673            int idx2 = this.findFarthest(idx1); 
     674 
     675            return new int[] 
     676                { idx1, idx2 }; 
     677        } 
     678 
     679        /** 
     680         * Calculates the new k-vector values (projections) 
     681         *  
     682         * This is basically algorithm 2 of the fastmap paper. We just added the possibility to 
     683         * pre-set the pivot elements because we need to classify single instances after the 
     684         * computation is already done. 
     685         *  
     686         * @param dims 
     687         *            dimensionality 
     688         */ 
     689        public void calculate() { 
     690 
     691            for (int k = 0; k < this.target_dims; k++) { 
     692                // 2) choose pivot objects 
     693                if (!this.pivot_set) { 
     694                    int[] pivots = this.findDistantObjects(); 
     695 
     696                    // 3) record ids of pivot objects 
     697                    this.PA[0][this.col] = pivots[0]; 
     698                    this.PA[1][this.col] = pivots[1]; 
     699                } 
     700 
     701                // 4) inter object distances are zero (this.X is initialized with 0 so we just 
     702                // continue) 
     703                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) { 
     704                    continue; 
     705                } 
     706 
     707                // 5) project the objects on the line between the pivots 
     708                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col); 
     709                for (int i = 0; i < this.O.length; i++) { 
     710 
     711                    double dix = this.dist(i, this.PA[0][this.col], this.col); 
     712                    double diy = this.dist(i, this.PA[1][this.col], this.col); 
     713 
     714                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy)); 
     715 
     716                    // save the projection 
     717                    this.X[i][this.col] = tmp; 
     718                } 
     719 
     720                this.col += 1; 
     721            } 
     722        } 
     723 
     724        /** 
     725         * returns the result matrix of the projections 
     726         *  
     727         * @return calculated result 
     728         */ 
     729        public double[][] getX() { 
     730            return this.X; 
     731        } 
     732    } 
    669733} 
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1125/** 
    1226 * Programmatic WekaTraining 
    13  * 
    14  * first parameter is Trainer Name. 
    15  * second parameter is class name 
    1627 *  
    17  * all subsequent parameters are configuration params (for example for trees) 
    18  * Cross Validation params always come last and are prepended with -CVPARAM 
     28 * first parameter is Trainer Name. second parameter is class name 
     29 *  
     30 * all subsequent parameters are configuration params (for example for trees) Cross Validation 
     31 * params always come last and are prepended with -CVPARAM 
    1932 *  
    2033 * XML Configurations for Weka Classifiers: 
     34 *  
    2135 * <pre> 
    2236 * {@code 
     
    3044public class WekaTraining extends WekaBaseTraining implements ITrainingStrategy { 
    3145 
    32         @Override 
    33         public void apply(Instances traindata) { 
    34                 PrintStream errStr      = System.err; 
    35                 System.setErr(new PrintStream(new NullOutputStream())); 
    36                 try { 
    37                         if(classifier == null) { 
    38                                 Console.traceln(Level.WARNING, String.format("classifier null!")); 
    39                         } 
    40                         classifier.buildClassifier(traindata); 
    41                 } catch (Exception e) { 
    42                         throw new RuntimeException(e); 
    43                 } finally { 
    44                         System.setErr(errStr); 
    45                 } 
    46         } 
     46    @Override 
     47    public void apply(Instances traindata) { 
     48        PrintStream errStr = System.err; 
     49        System.setErr(new PrintStream(new NullOutputStream())); 
     50        try { 
     51            if (classifier == null) { 
     52                Console.traceln(Level.WARNING, String.format("classifier null!")); 
     53            } 
     54            classifier.buildClassifier(traindata); 
     55        } 
     56        catch (Exception e) { 
     57            throw new RuntimeException(e); 
     58        } 
     59        finally { 
     60            System.setErr(errStr); 
     61        } 
     62    } 
    4763} 
Note: See TracChangeset for help on using the changeset viewer.