Ignore:
Timestamp:
09/24/15 10:59:05 (9 years ago)
Author:
sherbold
Message:
  • formatted code and added copyrights
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/PointWiseEMClusterSelection.java

    r2 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.dataselection; 
    216 
     
    1428import de.ugoe.cs.util.console.Console; 
    1529 
    16  
    1730/** 
    1831 * Use in Config: 
    1932 *  
    20  * Specify number of clusters 
    21  * -N = Num Clusters 
    22  * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/> 
    23  * 
    24  * Try to determine the number of clusters: 
    25  * -I 10 = max iterations 
    26  * -X 5 = 5 folds for cross evaluation 
    27  * -max = max number of clusters 
    28  * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/> 
     33 * Specify number of clusters -N = Num Clusters <pointwiseselector 
     34 * name="PointWiseEMClusterSelection" param="-N 10"/> 
    2935 *  
    30  * Don't forget to add: 
    31  * <preprocessor name="Normalization" param=""/> 
     36 * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross 
     37 * evaluation -max = max number of clusters <pointwiseselector name="PointWiseEMClusterSelection" 
     38 * param="-I 10 -X 5 -max 300"/> 
     39 *  
     40 * Don't forget to add: <preprocessor name="Normalization" param=""/> 
    3241 */ 
    3342public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { 
    34          
    35         private String[] params;  
    36          
    37         @Override 
    38         public void setParameter(String parameters) { 
    39                 params = parameters.split(" "); 
    40         } 
    4143 
    42          
    43         /** 
    44          * 1. Cluster the traindata 
    45          * 2. for each instance in the testdata find the assigned cluster 
    46          * 3. select only traindata from the clusters we found in our testdata 
    47          *  
    48          * @returns the selected training data 
    49          */ 
    50         @Override 
    51         public Instances apply(Instances testdata, Instances traindata) { 
    52                 //final Attribute classAttribute = testdata.classAttribute(); 
    53                  
    54                 final List<Integer> selectedCluster = SetUniqueList.setUniqueList(new LinkedList<Integer>()); 
     44    private String[] params; 
    5545 
    56                 // 1. copy train- and testdata 
    57                 Instances train = new Instances(traindata); 
    58                 Instances test = new Instances(testdata); 
    59                  
    60                 Instances selected = null; 
    61                  
    62                 try { 
    63                         // remove class attribute from traindata 
    64                         Remove filter = new Remove(); 
    65                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    66                         filter.setInputFormat(train); 
    67                         train = Filter.useFilter(train, filter); 
    68                          
    69                         Console.traceln(Level.INFO, String.format("starting clustering")); 
    70                          
    71                         // 3. cluster data 
    72                         EM clusterer = new EM(); 
    73                         clusterer.setOptions(params); 
    74                         clusterer.buildClusterer(train); 
    75                         int numClusters = clusterer.getNumClusters(); 
    76                         if ( numClusters == -1) { 
    77                                 Console.traceln(Level.INFO, String.format("we have unlimited clusters")); 
    78                         }else { 
    79                                 Console.traceln(Level.INFO, String.format("we have: "+numClusters+" clusters")); 
    80                         } 
    81                          
    82                          
    83                         // 4. classify testdata, save cluster int 
    84                          
    85                         // remove class attribute from testdata? 
    86                         Remove filter2 = new Remove(); 
    87                         filter2.setAttributeIndices("" + (test.classIndex() + 1)); 
    88                         filter2.setInputFormat(test); 
    89                         test = Filter.useFilter(test, filter2); 
    90                          
    91                         int cnum; 
    92                         for( int i=0; i < test.numInstances(); i++ ) { 
    93                                 cnum = ((EM)clusterer).clusterInstance(test.get(i)); 
     46    @Override 
     47    public void setParameter(String parameters) { 
     48        params = parameters.split(" "); 
     49    } 
    9450 
    95                                 // we dont want doubles (maybe use a hashset instead of list?) 
    96                                 if ( !selectedCluster.contains(cnum) ) { 
    97                                         selectedCluster.add(cnum); 
    98                                         //Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); 
    99                                 } 
    100                         } 
    101                          
    102                         Console.traceln(Level.INFO, String.format("our testdata is in: "+selectedCluster.size()+" different clusters")); 
    103                          
    104                         // 5. get cluster membership of our traindata 
    105                         AddCluster cfilter = new AddCluster(); 
    106                         cfilter.setClusterer(clusterer); 
    107                         cfilter.setInputFormat(train); 
    108                         Instances ctrain = Filter.useFilter(train, cfilter); 
    109                          
    110                          
    111                         // 6. for all traindata get the cluster int, if it is in our list of testdata cluster int add the traindata 
    112                         // of this cluster to our returned traindata 
    113                         int cnumber; 
    114                         selected = new Instances(traindata); 
    115                         selected.delete(); 
    116                          
    117                         for ( int j=0; j < ctrain.numInstances(); j++ ) { 
    118                                 // get the cluster number from the attributes 
    119                                 cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")); 
    120                                  
    121                                 //Console.traceln(Level.INFO, String.format("instance "+j+" is in cluster: "+cnumber)); 
    122                                 if ( selectedCluster.contains(cnumber) ) { 
    123                                         // this only works if the index does not change 
    124                                         selected.add(traindata.get(j)); 
    125                                         // check for differences, just one attribute, we are pretty sure the index does not change 
    126                                         if ( traindata.get(j).value(3) != ctrain.get(j).value(3) ) { 
    127                                                 Console.traceln(Level.WARNING, String.format("we have a difference between train an ctrain!")); 
    128                                         } 
    129                                 } 
    130                         } 
    131                          
    132                         Console.traceln(Level.INFO, String.format("that leaves us with: "+selected.numInstances()+" traindata instances from "+traindata.numInstances())); 
    133                 }catch( Exception e ) { 
    134                         Console.traceln(Level.WARNING, String.format("ERROR")); 
    135                         throw new RuntimeException("error in pointwise em", e); 
    136                 } 
    137          
    138                 return selected; 
    139         } 
     51    /** 
     52     * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3. 
     53     * select only traindata from the clusters we found in our testdata 
     54     *  
     55     * @returns the selected training data 
     56     */ 
     57    @Override 
     58    public Instances apply(Instances testdata, Instances traindata) { 
     59        // final Attribute classAttribute = testdata.classAttribute(); 
     60 
     61        final List<Integer> selectedCluster = 
     62            SetUniqueList.setUniqueList(new LinkedList<Integer>()); 
     63 
     64        // 1. copy train- and testdata 
     65        Instances train = new Instances(traindata); 
     66        Instances test = new Instances(testdata); 
     67 
     68        Instances selected = null; 
     69 
     70        try { 
     71            // remove class attribute from traindata 
     72            Remove filter = new Remove(); 
     73            filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     74            filter.setInputFormat(train); 
     75            train = Filter.useFilter(train, filter); 
     76 
     77            Console.traceln(Level.INFO, String.format("starting clustering")); 
     78 
     79            // 3. cluster data 
     80            EM clusterer = new EM(); 
     81            clusterer.setOptions(params); 
     82            clusterer.buildClusterer(train); 
     83            int numClusters = clusterer.getNumClusters(); 
     84            if (numClusters == -1) { 
     85                Console.traceln(Level.INFO, String.format("we have unlimited clusters")); 
     86            } 
     87            else { 
     88                Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters")); 
     89            } 
     90 
     91            // 4. classify testdata, save cluster int 
     92 
     93            // remove class attribute from testdata? 
     94            Remove filter2 = new Remove(); 
     95            filter2.setAttributeIndices("" + (test.classIndex() + 1)); 
     96            filter2.setInputFormat(test); 
     97            test = Filter.useFilter(test, filter2); 
     98 
     99            int cnum; 
     100            for (int i = 0; i < test.numInstances(); i++) { 
     101                cnum = ((EM) clusterer).clusterInstance(test.get(i)); 
     102 
     103                // we dont want doubles (maybe use a hashset instead of list?) 
     104                if (!selectedCluster.contains(cnum)) { 
     105                    selectedCluster.add(cnum); 
     106                    // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); 
     107                } 
     108            } 
     109 
     110            Console.traceln(Level.INFO, 
     111                            String.format("our testdata is in: " + selectedCluster.size() + 
     112                                " different clusters")); 
     113 
     114            // 5. get cluster membership of our traindata 
     115            AddCluster cfilter = new AddCluster(); 
     116            cfilter.setClusterer(clusterer); 
     117            cfilter.setInputFormat(train); 
     118            Instances ctrain = Filter.useFilter(train, cfilter); 
     119 
     120            // 6. for all traindata get the cluster int, if it is in our list of testdata cluster 
     121            // int add the traindata 
     122            // of this cluster to our returned traindata 
     123            int cnumber; 
     124            selected = new Instances(traindata); 
     125            selected.delete(); 
     126 
     127            for (int j = 0; j < ctrain.numInstances(); j++) { 
     128                // get the cluster number from the attributes 
     129                cnumber = 
     130                    Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1) 
     131                        .replace("cluster", "")); 
     132 
     133                // Console.traceln(Level.INFO, 
     134                // String.format("instance "+j+" is in cluster: "+cnumber)); 
     135                if (selectedCluster.contains(cnumber)) { 
     136                    // this only works if the index does not change 
     137                    selected.add(traindata.get(j)); 
     138                    // check for differences, just one attribute, we are pretty sure the index does 
     139                    // not change 
     140                    if (traindata.get(j).value(3) != ctrain.get(j).value(3)) { 
     141                        Console.traceln(Level.WARNING, String 
     142                            .format("we have a difference between train an ctrain!")); 
     143                    } 
     144                } 
     145            } 
     146 
     147            Console.traceln(Level.INFO, 
     148                            String.format("that leaves us with: " + selected.numInstances() + 
     149                                " traindata instances from " + traindata.numInstances())); 
     150        } 
     151        catch (Exception e) { 
     152            Console.traceln(Level.WARNING, String.format("ERROR")); 
     153            throw new RuntimeException("error in pointwise em", e); 
     154        } 
     155 
     156        return selected; 
     157    } 
    140158 
    141159} 
Note: See TracChangeset for help on using the changeset viewer.