source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java @ 34

Last change on this file since 34 was 25, checked in by atrautsch, 10 years ago

comment fixes

  • Property svn:mime-type set to text/plain
File size: 6.7 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashMap;
5import java.util.HashSet;
6import java.util.Iterator;
7import java.util.Map.Entry;
8import java.util.Set;
9import java.util.logging.Level;
10
11import org.apache.commons.io.output.NullOutputStream;
12
13import de.ugoe.cs.util.console.Console;
14import weka.classifiers.AbstractClassifier;
15import weka.classifiers.Classifier;
16import weka.clusterers.EM;
17import weka.core.DenseInstance;
18import weka.core.Instance;
19import weka.core.Instances;
20import weka.filters.Filter;
21import weka.filters.unsupervised.attribute.Remove;
22
23/**
24 * WekaLocalEMTraining
25 *
26 * Local Trainer with EM Clustering for data partitioning.
27 * Currently supports only EM Clustering.
28 *
29 * 1. Cluster training data
30 * 2. for each cluster train a classifier with training data from cluster
31 * 3. match test data instance to a cluster, then classify with classifier from the cluster
32 *
33 * XML configuration:
34 * <!-- because of clustering -->
35 * <preprocessor name="Normalization" param=""/>
36 *
37 * <!-- cluster trainer -->
38 * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
39 */
40public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
41
42        private final TraindatasetCluster classifier = new TraindatasetCluster();
43       
44        @Override
45        public Classifier getClassifier() {
46                return classifier;
47        }
48       
49        @Override
50        public void apply(Instances traindata) {
51                PrintStream errStr      = System.err;
52                System.setErr(new PrintStream(new NullOutputStream()));
53                try {
54                        classifier.buildClassifier(traindata);
55                } catch (Exception e) {
56                        throw new RuntimeException(e);
57                } finally {
58                        System.setErr(errStr);
59                }
60        }
61       
62
63        public class TraindatasetCluster extends AbstractClassifier {
64               
65                private static final long serialVersionUID = 1L;
66
67                private EM clusterer = null;
68
69                private HashMap<Integer, Classifier> cclassifier;
70                private HashMap<Integer, Instances> ctraindata;
71               
72               
73                /**
74                 * Helper method that gives us a clean instance copy with
75                 * the values of the instancelist of the first parameter.
76                 *
77                 * @param instancelist with attributes
78                 * @param instance with only values
79                 * @return copy of the instance
80                 */
81                private Instance createInstance(Instances instances, Instance instance) {
82                        // attributes for feeding instance to classifier
83                        Set<String> attributeNames = new HashSet<>();
84                        for( int j=0; j<instances.numAttributes(); j++ ) {
85                                attributeNames.add(instances.attribute(j).name());
86                        }
87                       
88                        double[] values = new double[instances.numAttributes()];
89                        int index = 0;
90                        for( int j=0; j<instance.numAttributes(); j++ ) {
91                                if( attributeNames.contains(instance.attribute(j).name())) {
92                                        values[index] = instance.value(j);
93                                        index++;
94                                }
95                        }
96                       
97                        Instances tmp = new Instances(instances);
98                        tmp.clear();
99                        Instance instCopy = new DenseInstance(instance.weight(), values);
100                        instCopy.setDataset(tmp);
101                       
102                        return instCopy;
103                }
104               
105                @Override
106                public double classifyInstance(Instance instance) {
107                        double ret = 0;
108                        try {
109                                // 1. copy the instance (keep the class attribute)
110                                Instances traindata = ctraindata.get(0);
111                                Instance classInstance = createInstance(traindata, instance);
112                               
113                                // 2. remove class attribute before clustering
114                                Remove filter = new Remove();
115                                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
116                                filter.setInputFormat(traindata);
117                                traindata = Filter.useFilter(traindata, filter);
118                               
119                                // 3. copy the instance (without the class attribute) for clustering
120                                Instance clusterInstance = createInstance(traindata, instance);
121                               
122                                // 4. match instance without class attribute to a cluster number
123                                int cnum = clusterer.clusterInstance(clusterInstance);
124                               
125                                // 5. classify instance with class attribute to the classifier of that cluster number
126                                ret = cclassifier.get(cnum).classifyInstance(classInstance);
127                               
128                        }catch( Exception e ) {
129                                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
130                                throw new RuntimeException(e);
131                        }
132                        return ret;
133                }
134
135                @Override
136                public void buildClassifier(Instances traindata) throws Exception {
137                       
138                        // 1. copy training data
139                        Instances train = new Instances(traindata);
140                       
141                        // 2. remove class attribute for clustering
142                        Remove filter = new Remove();
143                        filter.setAttributeIndices("" + (train.classIndex() + 1));
144                        filter.setInputFormat(train);
145                        train = Filter.useFilter(train, filter);
146                       
147                        // new objects
148                        cclassifier = new HashMap<Integer, Classifier>();
149                        ctraindata = new HashMap<Integer, Instances>();
150                                               
151                        Instances ctrain;
152                        int maxNumClusters = train.size();
153                        boolean sufficientInstancesInEachCluster;
154                        do { // while(onlyTarget)
155                                sufficientInstancesInEachCluster = true;
156                                clusterer = new EM();
157                                clusterer.setMaximumNumberOfClusters(maxNumClusters);
158                                clusterer.buildClusterer(train);
159                               
160                                // 4. get cluster membership of our traindata
161                                //AddCluster cfilter = new AddCluster();
162                                //cfilter.setClusterer(clusterer);
163                                //cfilter.setInputFormat(train);
164                                //Instances ctrain = Filter.useFilter(train, cfilter);
165                               
166                                ctrain = new Instances(train);
167                                ctraindata = new HashMap<>();
168                               
169                                // get traindata per cluster
170                                for ( int j=0; j < ctrain.numInstances(); j++ ) {
171                                        // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n
172                                        //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1;
173                                       
174                                        int cnumber = clusterer.clusterInstance(ctrain.get(j));
175                                        // add training data to list of instances for this cluster number
176                                        if ( !ctraindata.containsKey(cnumber) ) {
177                                                ctraindata.put(cnumber, new Instances(traindata));
178                                                ctraindata.get(cnumber).delete();
179                                        }
180                                        ctraindata.get(cnumber).add(traindata.get(j));
181                                }
182                               
183                                for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) {
184                                        Instances instances = entry.getValue();
185                                        int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
186                                        for( int count : counts ) {
187                                                sufficientInstancesInEachCluster &= count>0;
188                                        }
189                                        sufficientInstancesInEachCluster &= instances.numInstances()>=5;
190                                }
191                                maxNumClusters = clusterer.numberOfClusters()-1;
192                        } while(!sufficientInstancesInEachCluster);
193                       
194                        // train one classifier per cluster, we get the cluster number from the training data
195                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
196                        while ( clusternumber.hasNext() ) {
197                                int cnumber = clusternumber.next();                     
198                                cclassifier.put(cnumber,setupClassifier());
199                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
200                               
201                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
202                        }
203                }
204        }
205}
Note: See TracBrowser for help on using the repository browser.