source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java @ 24

Last change on this file since 24 was 23, checked in by sherbold, 10 years ago
  • renamed new training classes
  • Property svn:mime-type set to text/plain
File size: 6.6 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashMap;
5import java.util.HashSet;
6import java.util.Iterator;
7import java.util.Map.Entry;
8import java.util.Set;
9import java.util.logging.Level;
10
11import org.apache.commons.io.output.NullOutputStream;
12
13import de.ugoe.cs.util.console.Console;
14import weka.classifiers.AbstractClassifier;
15import weka.classifiers.Classifier;
16import weka.clusterers.EM;
17import weka.core.DenseInstance;
18import weka.core.Instance;
19import weka.core.Instances;
20import weka.filters.Filter;
21import weka.filters.unsupervised.attribute.Remove;
22
23/**
24 * WekaClusterTraining2
25 *
26 * Currently supports only EM Clustering.
27 *
28 * 1. Cluster training data
29 * 2. for each cluster train a classifier with training data from cluster
30 * 3. match test data instance to a cluster, then classify with classifier from the cluster
31 *
32 * XML configuration:
33 * <!-- because of clustering -->
34 * <preprocessor name="Normalization" param=""/>
35 *
36 * <!-- cluster trainer -->
37 * <trainer name="WekaClusterTraining2" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
38 */
39public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
40
41        private final TraindatasetCluster classifier = new TraindatasetCluster();
42       
43        @Override
44        public Classifier getClassifier() {
45                return classifier;
46        }
47       
48        @Override
49        public void apply(Instances traindata) {
50                PrintStream errStr      = System.err;
51                System.setErr(new PrintStream(new NullOutputStream()));
52                try {
53                        classifier.buildClassifier(traindata);
54                } catch (Exception e) {
55                        throw new RuntimeException(e);
56                } finally {
57                        System.setErr(errStr);
58                }
59        }
60       
61
62        public class TraindatasetCluster extends AbstractClassifier {
63               
64                private static final long serialVersionUID = 1L;
65
66                private EM clusterer = null;
67
68                private HashMap<Integer, Classifier> cclassifier;
69                private HashMap<Integer, Instances> ctraindata;
70               
71               
72                /**
73                 * Helper method that gives us a clean instance copy with
74                 * the values of the instancelist of the first parameter.
75                 *
76                 * @param instancelist with attributes
77                 * @param instance with only values
78                 * @return copy of the instance
79                 */
80                private Instance createInstance(Instances instances, Instance instance) {
81                        // attributes for feeding instance to classifier
82                        Set<String> attributeNames = new HashSet<>();
83                        for( int j=0; j<instances.numAttributes(); j++ ) {
84                                attributeNames.add(instances.attribute(j).name());
85                        }
86                       
87                        double[] values = new double[instances.numAttributes()];
88                        int index = 0;
89                        for( int j=0; j<instance.numAttributes(); j++ ) {
90                                if( attributeNames.contains(instance.attribute(j).name())) {
91                                        values[index] = instance.value(j);
92                                        index++;
93                                }
94                        }
95                       
96                        Instances tmp = new Instances(instances);
97                        tmp.clear();
98                        Instance instCopy = new DenseInstance(instance.weight(), values);
99                        instCopy.setDataset(tmp);
100                       
101                        return instCopy;
102                }
103               
104                @Override
105                public double classifyInstance(Instance instance) {
106                        double ret = 0;
107                        try {
108                                // 1. copy the instance (keep the class attribute)
109                                Instances traindata = ctraindata.get(0);
110                                Instance classInstance = createInstance(traindata, instance);
111                               
112                                // 2. remove class attribute before clustering
113                                Remove filter = new Remove();
114                                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
115                                filter.setInputFormat(traindata);
116                                traindata = Filter.useFilter(traindata, filter);
117                               
118                                // 3. copy the instance (without the class attribute) for clustering
119                                Instance clusterInstance = createInstance(traindata, instance);
120                               
121                                // 4. match instance without class attribute to a cluster number
122                                int cnum = clusterer.clusterInstance(clusterInstance);
123                               
124                                // 5. classify instance with class attribute to the classifier of that cluster number
125                                ret = cclassifier.get(cnum).classifyInstance(classInstance);
126                               
127                        }catch( Exception e ) {
128                                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
129                                throw new RuntimeException(e);
130                        }
131                        return ret;
132                }
133
134                @Override
135                public void buildClassifier(Instances traindata) throws Exception {
136                       
137                        // 1. copy training data
138                        Instances train = new Instances(traindata);
139                       
140                        // 2. remove class attribute for clustering
141                        Remove filter = new Remove();
142                        filter.setAttributeIndices("" + (train.classIndex() + 1));
143                        filter.setInputFormat(train);
144                        train = Filter.useFilter(train, filter);
145                       
146                        // new objects
147                        cclassifier = new HashMap<Integer, Classifier>();
148                        ctraindata = new HashMap<Integer, Instances>();
149                                               
150                        Instances ctrain;
151                        int maxNumClusters = train.size();
152                        boolean sufficientInstancesInEachCluster;
153                        do { // while(onlyTarget)
154                                sufficientInstancesInEachCluster = true;
155                                clusterer = new EM();
156                                clusterer.setMaximumNumberOfClusters(maxNumClusters);
157                                clusterer.buildClusterer(train);
158                               
159                                // 4. get cluster membership of our traindata
160                                //AddCluster cfilter = new AddCluster();
161                                //cfilter.setClusterer(clusterer);
162                                //cfilter.setInputFormat(train);
163                                //Instances ctrain = Filter.useFilter(train, cfilter);
164                               
165                                ctrain = new Instances(train);
166                                ctraindata = new HashMap<>();
167                               
168                                // get traindata per cluster
169                                for ( int j=0; j < ctrain.numInstances(); j++ ) {
170                                        // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n
171                                        //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1;
172                                       
173                                        int cnumber = clusterer.clusterInstance(ctrain.get(j));
174                                        // add training data to list of instances for this cluster number
175                                        if ( !ctraindata.containsKey(cnumber) ) {
176                                                ctraindata.put(cnumber, new Instances(traindata));
177                                                ctraindata.get(cnumber).delete();
178                                        }
179                                        ctraindata.get(cnumber).add(traindata.get(j));
180                                }
181                               
182                                for( Entry<Integer,Instances> entry : ctraindata.entrySet() ) {
183                                        Instances instances = entry.getValue();
184                                        int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
185                                        for( int count : counts ) {
186                                                sufficientInstancesInEachCluster &= count>0;
187                                        }
188                                        sufficientInstancesInEachCluster &= instances.numInstances()>=5;
189                                }
190                                maxNumClusters = clusterer.numberOfClusters()-1;
191                        } while(!sufficientInstancesInEachCluster);
192                       
193                        // train one classifier per cluster, we get the cluster number from the training data
194                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
195                        while ( clusternumber.hasNext() ) {
196                                int cnumber = clusternumber.next();                     
197                                cclassifier.put(cnumber,setupClassifier());
198                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
199                               
200                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
201                        }
202                }
203        }
204}
Note: See TracBrowser for help on using the repository browser.