source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaClusterTraining2.java @ 4

Last change on this file since 4 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 5.7 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.HashMap;
5import java.util.HashSet;
6import java.util.Iterator;
7import java.util.Set;
8import java.util.logging.Level;
9
10import org.apache.commons.io.output.NullOutputStream;
11
12import de.ugoe.cs.util.console.Console;
13import weka.classifiers.AbstractClassifier;
14import weka.classifiers.Classifier;
15import weka.clusterers.EM;
16import weka.core.DenseInstance;
17import weka.core.Instance;
18import weka.core.Instances;
19import weka.filters.Filter;
20import weka.filters.unsupervised.attribute.Remove;
21
22/**
23 * WekaClusterTraining2
24 *
25 * 1. Cluster traindata
26 * 2. for each cluster train a classifier with traindata from cluster
27 * 3. match testdata instance to a cluster, then classify with classifier from the cluster
28 *
29 * XML config:
30 * <!-- because of clustering -->
31 * <preprocessor name="Normalization" param=""/>
32 *
33 * <!-- cluster trainer -->
34 * <trainer name="WekaClusterTraining2" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
35 *
36 * Questions:
37 * - how do we configure the clustering params?
38 */
39public class WekaClusterTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {
40
41        private final TraindatasetCluster classifier = new TraindatasetCluster();
42       
43        @Override
44        public Classifier getClassifier() {
45                return classifier;
46        }
47       
48       
49        @Override
50        public void apply(Instances traindata) {
51                PrintStream errStr      = System.err;
52                System.setErr(new PrintStream(new NullOutputStream()));
53                try {
54                        classifier.buildClassifier(traindata);
55                } catch (Exception e) {
56                        throw new RuntimeException(e);
57                } finally {
58                        System.setErr(errStr);
59                }
60        }
61       
62
63        public class TraindatasetCluster extends AbstractClassifier {
64               
65                private static final long serialVersionUID = 1L;
66
67                private EM clusterer = null;
68
69                private HashMap<Integer, Classifier> cclassifier = new HashMap<Integer, Classifier>();
70                private HashMap<Integer, Instances> ctraindata = new HashMap<Integer, Instances>();
71               
72               
73               
74                private Instance createInstance(Instances instances, Instance instance) {
75                        // attributes for feeding instance to classifier
76                        Set<String> attributeNames = new HashSet<>();
77                        for( int j=0; j<instances.numAttributes(); j++ ) {
78                                attributeNames.add(instances.attribute(j).name());
79                        }
80                       
81                        double[] values = new double[instances.numAttributes()];
82                        int index = 0;
83                        for( int j=0; j<instance.numAttributes(); j++ ) {
84                                if( attributeNames.contains(instance.attribute(j).name())) {
85                                        values[index] = instance.value(j);
86                                        index++;
87                                }
88                        }
89                       
90                        Instances tmp = new Instances(instances);
91                        tmp.clear();
92                        Instance instCopy = new DenseInstance(instance.weight(), values);
93                        instCopy.setDataset(tmp);
94                       
95                        return instCopy;
96                }
97               
98               
99                @Override
100                public double classifyInstance(Instance instance) {
101                        double ret = 0;
102                        try {
103                                Instances traindata = ctraindata.get(0);
104                                Instance classInstance = createInstance(traindata, instance);
105                               
106                                // remove class attribute before clustering
107                                Remove filter = new Remove();
108                                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
109                                filter.setInputFormat(traindata);
110                                traindata = Filter.useFilter(traindata, filter);
111                               
112                                Instance clusterInstance = createInstance(traindata, instance);
113                               
114                                // 1. classify testdata instance to a cluster number
115                                int cnum = clusterer.clusterInstance(clusterInstance);
116                               
117                                // 2. classify testata instance to the classifier
118                                ret = cclassifier.get(cnum).classifyInstance(classInstance);
119                               
120                        }catch( Exception e ) {
121                                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
122                                throw new RuntimeException(e);
123                        }
124                        return ret;
125                }
126
127               
128               
129                @Override
130                public void buildClassifier(Instances traindata) throws Exception {
131                       
132                        // 1. copy traindata
133                        Instances train = new Instances(traindata);
134                       
135                        // 2. remove class attribute for clustering
136                        Remove filter = new Remove();
137                        filter.setAttributeIndices("" + (train.classIndex() + 1));
138                        filter.setInputFormat(train);
139                        train = Filter.useFilter(train, filter);
140                       
141                        // 3. cluster data
142                        //Console.traceln(Level.INFO, String.format("starting clustering"));
143                       
144                        // use standard params for now
145                        clusterer = new EM();
146                        //String[] params = {"-N", "100"};
147                        //clusterer.setOptions(params);
148                        clusterer.buildClusterer(train);
149                        // set max num to traindata size
150                        clusterer.setMaximumNumberOfClusters(train.size());
151                       
152                        // 4. get cluster membership of our traindata
153                        //AddCluster cfilter = new AddCluster();
154                        //cfilter.setClusterer(clusterer);
155                        //cfilter.setInputFormat(train);
156                        //Instances ctrain = Filter.useFilter(train, cfilter);
157                       
158                        Instances ctrain = new Instances(train);
159                       
160                        // get traindata per cluster
161                        int cnumber;
162                        for ( int j=0; j < ctrain.numInstances(); j++ ) {
163                                // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n
164                                //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1;
165                               
166                                cnumber = clusterer.clusterInstance(ctrain.get(j));
167                                // add training data to list of instances for this cluster number
168                                if ( !ctraindata.containsKey(cnumber) ) {
169                                        ctraindata.put(cnumber, new Instances(traindata));
170                                        ctraindata.get(cnumber).delete();
171                                }
172                                ctraindata.get(cnumber).add(traindata.get(j));
173                        }
174                       
175                        // train one classifier per cluster, we get the clusternumber from the traindata
176                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
177                        while ( clusternumber.hasNext() ) {
178                                cnumber = clusternumber.next();                 
179                                cclassifier.put(cnumber,setupClassifier());
180                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
181                               
182                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
183                        }
184                }
185        }
186}
Note: See TracBrowser for help on using the repository browser.