source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java @ 43

Last change on this file since 43 was 41, checked in by sherbold, 9 years ago
  • formatted code and added copyrights
  • Property svn:mime-type set to text/plain
File size: 8.6 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.io.PrintStream;
18import java.util.HashMap;
19import java.util.HashSet;
20import java.util.Iterator;
21import java.util.Map.Entry;
22import java.util.Set;
23import java.util.logging.Level;
24
25import org.apache.commons.io.output.NullOutputStream;
26
27import de.ugoe.cs.util.console.Console;
28import weka.classifiers.AbstractClassifier;
29import weka.classifiers.Classifier;
30import weka.clusterers.EM;
31import weka.core.DenseInstance;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.filters.Filter;
35import weka.filters.unsupervised.attribute.Remove;
36
37/**
38 * WekaLocalEMTraining
39 *
40 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering.
41 *
42 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster
43 * 3. match test data instance to a cluster, then classify with classifier from the cluster
44 *
45 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/>
46 *
47 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining"
48 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
49 */
50public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
51
52    private final TraindatasetCluster classifier = new TraindatasetCluster();
53
54    @Override
55    public Classifier getClassifier() {
56        return classifier;
57    }
58
59    @Override
60    public void apply(Instances traindata) {
61        PrintStream errStr = System.err;
62        System.setErr(new PrintStream(new NullOutputStream()));
63        try {
64            classifier.buildClassifier(traindata);
65        }
66        catch (Exception e) {
67            throw new RuntimeException(e);
68        }
69        finally {
70            System.setErr(errStr);
71        }
72    }
73
74    public class TraindatasetCluster extends AbstractClassifier {
75
76        private static final long serialVersionUID = 1L;
77
78        private EM clusterer = null;
79
80        private HashMap<Integer, Classifier> cclassifier;
81        private HashMap<Integer, Instances> ctraindata;
82
83        /**
84         * Helper method that gives us a clean instance copy with the values of the instancelist of
85         * the first parameter.
86         *
87         * @param instancelist
88         *            with attributes
89         * @param instance
90         *            with only values
91         * @return copy of the instance
92         */
93        private Instance createInstance(Instances instances, Instance instance) {
94            // attributes for feeding instance to classifier
95            Set<String> attributeNames = new HashSet<>();
96            for (int j = 0; j < instances.numAttributes(); j++) {
97                attributeNames.add(instances.attribute(j).name());
98            }
99
100            double[] values = new double[instances.numAttributes()];
101            int index = 0;
102            for (int j = 0; j < instance.numAttributes(); j++) {
103                if (attributeNames.contains(instance.attribute(j).name())) {
104                    values[index] = instance.value(j);
105                    index++;
106                }
107            }
108
109            Instances tmp = new Instances(instances);
110            tmp.clear();
111            Instance instCopy = new DenseInstance(instance.weight(), values);
112            instCopy.setDataset(tmp);
113
114            return instCopy;
115        }
116
117        @Override
118        public double classifyInstance(Instance instance) {
119            double ret = 0;
120            try {
121                // 1. copy the instance (keep the class attribute)
122                Instances traindata = ctraindata.get(0);
123                Instance classInstance = createInstance(traindata, instance);
124
125                // 2. remove class attribute before clustering
126                Remove filter = new Remove();
127                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
128                filter.setInputFormat(traindata);
129                traindata = Filter.useFilter(traindata, filter);
130
131                // 3. copy the instance (without the class attribute) for clustering
132                Instance clusterInstance = createInstance(traindata, instance);
133
134                // 4. match instance without class attribute to a cluster number
135                int cnum = clusterer.clusterInstance(clusterInstance);
136
137                // 5. classify instance with class attribute to the classifier of that cluster
138                // number
139                ret = cclassifier.get(cnum).classifyInstance(classInstance);
140
141            }
142            catch (Exception e) {
143                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
144                throw new RuntimeException(e);
145            }
146            return ret;
147        }
148
149        @Override
150        public void buildClassifier(Instances traindata) throws Exception {
151
152            // 1. copy training data
153            Instances train = new Instances(traindata);
154
155            // 2. remove class attribute for clustering
156            Remove filter = new Remove();
157            filter.setAttributeIndices("" + (train.classIndex() + 1));
158            filter.setInputFormat(train);
159            train = Filter.useFilter(train, filter);
160
161            // new objects
162            cclassifier = new HashMap<Integer, Classifier>();
163            ctraindata = new HashMap<Integer, Instances>();
164
165            Instances ctrain;
166            int maxNumClusters = train.size();
167            boolean sufficientInstancesInEachCluster;
168            do { // while(onlyTarget)
169                sufficientInstancesInEachCluster = true;
170                clusterer = new EM();
171                clusterer.setMaximumNumberOfClusters(maxNumClusters);
172                clusterer.buildClusterer(train);
173
174                // 4. get cluster membership of our traindata
175                // AddCluster cfilter = new AddCluster();
176                // cfilter.setClusterer(clusterer);
177                // cfilter.setInputFormat(train);
178                // Instances ctrain = Filter.useFilter(train, cfilter);
179
180                ctrain = new Instances(train);
181                ctraindata = new HashMap<>();
182
183                // get traindata per cluster
184                for (int j = 0; j < ctrain.numInstances(); j++) {
185                    // get the cluster number from the attributes, subract 1 because if we
186                    // clusterInstance we get 0-n, and this is 1-n
187                    // cnumber =
188                    // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster",
189                    // "")) - 1;
190
191                    int cnumber = clusterer.clusterInstance(ctrain.get(j));
192                    // add training data to list of instances for this cluster number
193                    if (!ctraindata.containsKey(cnumber)) {
194                        ctraindata.put(cnumber, new Instances(traindata));
195                        ctraindata.get(cnumber).delete();
196                    }
197                    ctraindata.get(cnumber).add(traindata.get(j));
198                }
199
200                for (Entry<Integer, Instances> entry : ctraindata.entrySet()) {
201                    Instances instances = entry.getValue();
202                    int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
203                    for (int count : counts) {
204                        sufficientInstancesInEachCluster &= count > 0;
205                    }
206                    sufficientInstancesInEachCluster &= instances.numInstances() >= 5;
207                }
208                maxNumClusters = clusterer.numberOfClusters() - 1;
209            }
210            while (!sufficientInstancesInEachCluster);
211
212            // train one classifier per cluster, we get the cluster number from the training data
213            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
214            while (clusternumber.hasNext()) {
215                int cnumber = clusternumber.next();
216                cclassifier.put(cnumber, setupClassifier());
217                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
218
219                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
220            }
221        }
222    }
223}
Note: See TracBrowser for help on using the repository browser.