Ignore:
Timestamp:
09/24/15 10:59:05 (9 years ago)
Author:
sherbold
Message:
  • formatted code and added copyrights
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    2438 * WekaLocalEMTraining 
    2539 *  
    26  * Local Trainer with EM Clustering for data partitioning. 
    27  * Currently supports only EM Clustering. 
    28  *  
    29  * 1. Cluster training data 
    30  * 2. for each cluster train a classifier with training data from cluster 
     40 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering. 
     41 *  
     42 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster 
    3143 * 3. match test data instance to a cluster, then classify with classifier from the cluster 
    3244 *  
    33  * XML configuration: 
    34  * <!-- because of clustering --> 
    35  * <preprocessor name="Normalization" param=""/> 
    36  *  
    37  * <!-- cluster trainer --> 
    38  * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 
     45 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/> 
     46 *  
     47 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining" 
     48 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" /> 
    3949 */ 
    4050public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy { 
    4151 
    42         private final TraindatasetCluster classifier = new TraindatasetCluster(); 
    43          
    44         @Override 
    45         public Classifier getClassifier() { 
    46                 return classifier; 
    47         } 
    48          
    49         @Override 
    50         public void apply(Instances traindata) { 
    51                 PrintStream errStr      = System.err; 
    52                 System.setErr(new PrintStream(new NullOutputStream())); 
    53                 try { 
    54                         classifier.buildClassifier(traindata); 
    55                 } catch (Exception e) { 
    56                         throw new RuntimeException(e); 
    57                 } finally { 
    58                         System.setErr(errStr); 
    59                 } 
    60         } 
    61          
    62  
    63         public class TraindatasetCluster extends AbstractClassifier { 
    64                  
    65                 private static final long serialVersionUID = 1L; 
    66  
    67                 private EM clusterer = null; 
    68  
    69                 private HashMap<Integer, Classifier> cclassifier; 
    70                 private HashMap<Integer, Instances> ctraindata;  
    71                  
    72                  
    73                 /** 
    74                  * Helper method that gives us a clean instance copy with  
    75                  * the values of the instancelist of the first parameter.  
    76                  *  
    77                  * @param instancelist with attributes 
    78                  * @param instance with only values 
    79                  * @return copy of the instance 
    80                  */ 
    81                 private Instance createInstance(Instances instances, Instance instance) { 
    82                         // attributes for feeding instance to classifier 
    83                         Set<String> attributeNames = new HashSet<>(); 
    84                         for( int j=0; j<instances.numAttributes(); j++ ) { 
    85                                 attributeNames.add(instances.attribute(j).name()); 
    86                         } 
    87                          
    88                         double[] values = new double[instances.numAttributes()]; 
    89                         int index = 0; 
    90                         for( int j=0; j<instance.numAttributes(); j++ ) { 
    91                                 if( attributeNames.contains(instance.attribute(j).name())) { 
    92                                         values[index] = instance.value(j); 
    93                                         index++; 
    94                                 } 
    95                         } 
    96                          
    97                         Instances tmp = new Instances(instances); 
    98                         tmp.clear(); 
    99                         Instance instCopy = new DenseInstance(instance.weight(), values); 
    100                         instCopy.setDataset(tmp); 
    101                          
    102                         return instCopy; 
    103                 } 
    104                  
    105                 @Override 
    106                 public double classifyInstance(Instance instance) { 
    107                         double ret = 0; 
    108                         try { 
    109                                 // 1. copy the instance (keep the class attribute) 
    110                                 Instances traindata = ctraindata.get(0); 
    111                                 Instance classInstance = createInstance(traindata, instance); 
    112                                  
    113                                 // 2. remove class attribute before clustering 
    114                                 Remove filter = new Remove(); 
    115                                 filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
    116                                 filter.setInputFormat(traindata); 
    117                                 traindata = Filter.useFilter(traindata, filter); 
    118                                  
    119                                 // 3. copy the instance (without the class attribute) for clustering 
    120                                 Instance clusterInstance = createInstance(traindata, instance); 
    121                                  
    122                                 // 4. match instance without class attribute to a cluster number 
    123                                 int cnum = clusterer.clusterInstance(clusterInstance); 
    124                                  
    125                                 // 5. classify instance with class attribute to the classifier of that cluster number 
    126                                 ret = cclassifier.get(cnum).classifyInstance(classInstance); 
    127                                  
    128                         }catch( Exception e ) { 
    129                                 Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
    130                                 throw new RuntimeException(e); 
    131                         } 
    132                         return ret; 
    133                 } 
    134  
    135                 @Override 
    136                 public void buildClassifier(Instances traindata) throws Exception { 
    137                          
    138                         // 1. copy training data 
    139                         Instances train = new Instances(traindata); 
    140                          
    141                         // 2. remove class attribute for clustering 
    142                         Remove filter = new Remove(); 
    143                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    144                         filter.setInputFormat(train); 
    145                         train = Filter.useFilter(train, filter); 
    146                          
    147                         // new objects 
    148                         cclassifier = new HashMap<Integer, Classifier>(); 
    149                         ctraindata = new HashMap<Integer, Instances>(); 
    150                                                  
    151                         Instances ctrain; 
    152                         int maxNumClusters = train.size(); 
    153                         boolean sufficientInstancesInEachCluster; 
    154                         do { // while(onlyTarget) 
    155                                 sufficientInstancesInEachCluster = true; 
    156                                 clusterer = new EM(); 
    157                                 clusterer.setMaximumNumberOfClusters(maxNumClusters); 
    158                                 clusterer.buildClusterer(train); 
    159                                  
    160                                 // 4. get cluster membership of our traindata 
    161                                 //AddCluster cfilter = new AddCluster(); 
    162                                 //cfilter.setClusterer(clusterer); 
    163                                 //cfilter.setInputFormat(train); 
    164                                 //Instances ctrain = Filter.useFilter(train, cfilter); 
    165                                  
    166                                 ctrain = new Instances(train); 
    167                                 ctraindata = new HashMap<>(); 
    168                                  
    169                                 // get traindata per cluster 
    170                                 for ( int j=0; j < ctrain.numInstances(); j++ ) { 
    171                                         // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 
    172                                         //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 
    173                                          
    174                                         int cnumber = clusterer.clusterInstance(ctrain.get(j)); 
    175                                         // add training data to list of instances for this cluster number 
    176                                         if ( !ctraindata.containsKey(cnumber) ) { 
    177                                                 ctraindata.put(cnumber, new Instances(traindata)); 
    178                                                 ctraindata.get(cnumber).delete(); 
    179                                         } 
    180                                         ctraindata.get(cnumber).add(traindata.get(j)); 
    181                                 } 
    182                                  
    183                                 for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) { 
    184                                         Instances instances = entry.getValue(); 
    185                                         int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 
    186                                         for( int count : counts ) { 
    187                                                 sufficientInstancesInEachCluster &= count>0; 
    188                                         } 
    189                                         sufficientInstancesInEachCluster &= instances.numInstances()>=5; 
    190                                 } 
    191                                 maxNumClusters = clusterer.numberOfClusters()-1; 
    192                         } while(!sufficientInstancesInEachCluster); 
    193                          
    194                         // train one classifier per cluster, we get the cluster number from the training data 
    195                         Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
    196                         while ( clusternumber.hasNext() ) { 
    197                                 int cnumber = clusternumber.next();                      
    198                                 cclassifier.put(cnumber,setupClassifier()); 
    199                                 cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
    200                                  
    201                                 //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
    202                         } 
    203                 } 
    204         } 
     52    private final TraindatasetCluster classifier = new TraindatasetCluster(); 
     53 
     54    @Override 
     55    public Classifier getClassifier() { 
     56        return classifier; 
     57    } 
     58 
     59    @Override 
     60    public void apply(Instances traindata) { 
     61        PrintStream errStr = System.err; 
     62        System.setErr(new PrintStream(new NullOutputStream())); 
     63        try { 
     64            classifier.buildClassifier(traindata); 
     65        } 
     66        catch (Exception e) { 
     67            throw new RuntimeException(e); 
     68        } 
     69        finally { 
     70            System.setErr(errStr); 
     71        } 
     72    } 
     73 
     74    public class TraindatasetCluster extends AbstractClassifier { 
     75 
     76        private static final long serialVersionUID = 1L; 
     77 
     78        private EM clusterer = null; 
     79 
     80        private HashMap<Integer, Classifier> cclassifier; 
     81        private HashMap<Integer, Instances> ctraindata; 
     82 
     83        /** 
     84         * Helper method that gives us a clean instance copy with the values of the instancelist of 
     85         * the first parameter. 
     86         *  
     87         * @param instancelist 
     88         *            with attributes 
     89         * @param instance 
     90         *            with only values 
     91         * @return copy of the instance 
     92         */ 
     93        private Instance createInstance(Instances instances, Instance instance) { 
     94            // attributes for feeding instance to classifier 
     95            Set<String> attributeNames = new HashSet<>(); 
     96            for (int j = 0; j < instances.numAttributes(); j++) { 
     97                attributeNames.add(instances.attribute(j).name()); 
     98            } 
     99 
     100            double[] values = new double[instances.numAttributes()]; 
     101            int index = 0; 
     102            for (int j = 0; j < instance.numAttributes(); j++) { 
     103                if (attributeNames.contains(instance.attribute(j).name())) { 
     104                    values[index] = instance.value(j); 
     105                    index++; 
     106                } 
     107            } 
     108 
     109            Instances tmp = new Instances(instances); 
     110            tmp.clear(); 
     111            Instance instCopy = new DenseInstance(instance.weight(), values); 
     112            instCopy.setDataset(tmp); 
     113 
     114            return instCopy; 
     115        } 
     116 
     117        @Override 
     118        public double classifyInstance(Instance instance) { 
     119            double ret = 0; 
     120            try { 
     121                // 1. copy the instance (keep the class attribute) 
     122                Instances traindata = ctraindata.get(0); 
     123                Instance classInstance = createInstance(traindata, instance); 
     124 
     125                // 2. remove class attribute before clustering 
     126                Remove filter = new Remove(); 
     127                filter.setAttributeIndices("" + (traindata.classIndex() + 1)); 
     128                filter.setInputFormat(traindata); 
     129                traindata = Filter.useFilter(traindata, filter); 
     130 
     131                // 3. copy the instance (without the class attribute) for clustering 
     132                Instance clusterInstance = createInstance(traindata, instance); 
     133 
     134                // 4. match instance without class attribute to a cluster number 
     135                int cnum = clusterer.clusterInstance(clusterInstance); 
     136 
     137                // 5. classify instance with class attribute to the classifier of that cluster 
     138                // number 
     139                ret = cclassifier.get(cnum).classifyInstance(classInstance); 
     140 
     141            } 
     142            catch (Exception e) { 
     143                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); 
     144                throw new RuntimeException(e); 
     145            } 
     146            return ret; 
     147        } 
     148 
     149        @Override 
     150        public void buildClassifier(Instances traindata) throws Exception { 
     151 
     152            // 1. copy training data 
     153            Instances train = new Instances(traindata); 
     154 
     155            // 2. remove class attribute for clustering 
     156            Remove filter = new Remove(); 
     157            filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     158            filter.setInputFormat(train); 
     159            train = Filter.useFilter(train, filter); 
     160 
     161            // new objects 
     162            cclassifier = new HashMap<Integer, Classifier>(); 
     163            ctraindata = new HashMap<Integer, Instances>(); 
     164 
     165            Instances ctrain; 
     166            int maxNumClusters = train.size(); 
     167            boolean sufficientInstancesInEachCluster; 
     168            do { // while(onlyTarget) 
     169                sufficientInstancesInEachCluster = true; 
     170                clusterer = new EM(); 
     171                clusterer.setMaximumNumberOfClusters(maxNumClusters); 
     172                clusterer.buildClusterer(train); 
     173 
     174                // 4. get cluster membership of our traindata 
     175                // AddCluster cfilter = new AddCluster(); 
     176                // cfilter.setClusterer(clusterer); 
     177                // cfilter.setInputFormat(train); 
     178                // Instances ctrain = Filter.useFilter(train, cfilter); 
     179 
     180                ctrain = new Instances(train); 
     181                ctraindata = new HashMap<>(); 
     182 
     183                // get traindata per cluster 
     184                for (int j = 0; j < ctrain.numInstances(); j++) { 
     185                    // get the cluster number from the attributes, subract 1 because if we 
     186                    // clusterInstance we get 0-n, and this is 1-n 
     187                    // cnumber = 
     188                    // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", 
     189                    // "")) - 1; 
     190 
     191                    int cnumber = clusterer.clusterInstance(ctrain.get(j)); 
     192                    // add training data to list of instances for this cluster number 
     193                    if (!ctraindata.containsKey(cnumber)) { 
     194                        ctraindata.put(cnumber, new Instances(traindata)); 
     195                        ctraindata.get(cnumber).delete(); 
     196                    } 
     197                    ctraindata.get(cnumber).add(traindata.get(j)); 
     198                } 
     199 
     200                for (Entry<Integer, Instances> entry : ctraindata.entrySet()) { 
     201                    Instances instances = entry.getValue(); 
     202                    int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; 
     203                    for (int count : counts) { 
     204                        sufficientInstancesInEachCluster &= count > 0; 
     205                    } 
     206                    sufficientInstancesInEachCluster &= instances.numInstances() >= 5; 
     207                } 
     208                maxNumClusters = clusterer.numberOfClusters() - 1; 
     209            } 
     210            while (!sufficientInstancesInEachCluster); 
     211 
     212            // train one classifier per cluster, we get the cluster number from the training data 
     213            Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
     214            while (clusternumber.hasNext()) { 
     215                int cnumber = clusternumber.next(); 
     216                cclassifier.put(cnumber, setupClassifier()); 
     217                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
     218 
     219                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
     220            } 
     221        } 
     222    } 
    205223} 
Note: See TracChangeset for help on using the changeset viewer.