source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java @ 135

Last change on this file since 135 was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 9.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.HashMap;
18import java.util.HashSet;
19import java.util.Iterator;
20import java.util.Map.Entry;
21import java.util.Set;
22import java.util.logging.Level;
23
24import de.ugoe.cs.util.console.Console;
25import weka.classifiers.AbstractClassifier;
26import weka.classifiers.Classifier;
27import weka.clusterers.EM;
28import weka.core.DenseInstance;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.filters.Filter;
32import weka.filters.unsupervised.attribute.Remove;
33
34/**
35 * <p>
36 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering.
37 * </p>
38 * <ol>
39 * <li>Cluster training data</li>
40 * <li>for each cluster train a classifier with training data from cluster</li>
41 * <li>match test data instance to a cluster, then classify with classifier from the cluster</li>
42 * </ol>
43 *
44 * XML configuration:
45 *
46 * <pre>
47 * {@code
48 * <trainer name="WekaLocalEMTraining" param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
49 * }
50 * </pre>
51 */
52public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
53
54    /**
55     * the classifier
56     */
57    private final TraindatasetCluster classifier = new TraindatasetCluster();
58
59    /*
60     * (non-Javadoc)
61     *
62     * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier()
63     */
64    @Override
65    public Classifier getClassifier() {
66        return classifier;
67    }
68
69    /*
70     * (non-Javadoc)
71     *
72     * @see de.ugoe.cs.cpdp.training.ITrainingStrategy#apply(weka.core.Instances)
73     */
74    @Override
75    public void apply(Instances traindata) {
76        try {
77            classifier.buildClassifier(traindata);
78        }
79        catch (Exception e) {
80            throw new RuntimeException(e);
81        }
82    }
83
84    /**
85     * <p>
86     * Weka classifier for the local model with EM clustering.
87     * </p>
88     *
89     * @author Alexander Trautsch
90     */
91    public class TraindatasetCluster extends AbstractClassifier {
92
93        /**
94         * default serializtion ID
95         */
96        private static final long serialVersionUID = 1L;
97
98        /**
99         * EM clusterer used
100         */
101        private EM clusterer = null;
102
103        /**
104         * classifiers for each cluster
105         */
106        private HashMap<Integer, Classifier> cclassifier;
107
108        /**
109         * training data for each cluster
110         */
111        private HashMap<Integer, Instances> ctraindata;
112
113        /**
114         * Helper method that gives us a clean instance copy with the values of the instancelist of
115         * the first parameter.
116         *
117         * @param instancelist
118         *            with attributes
119         * @param instance
120         *            with only values
121         * @return copy of the instance
122         */
123        private Instance createInstance(Instances instances, Instance instance) {
124            // attributes for feeding instance to classifier
125            Set<String> attributeNames = new HashSet<>();
126            for (int j = 0; j < instances.numAttributes(); j++) {
127                attributeNames.add(instances.attribute(j).name());
128            }
129
130            double[] values = new double[instances.numAttributes()];
131            int index = 0;
132            for (int j = 0; j < instance.numAttributes(); j++) {
133                if (attributeNames.contains(instance.attribute(j).name())) {
134                    values[index] = instance.value(j);
135                    index++;
136                }
137            }
138
139            Instances tmp = new Instances(instances);
140            tmp.clear();
141            Instance instCopy = new DenseInstance(instance.weight(), values);
142            instCopy.setDataset(tmp);
143
144            return instCopy;
145        }
146
147        /*
148         * (non-Javadoc)
149         *
150         * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
151         */
152        @Override
153        public double classifyInstance(Instance instance) {
154            double ret = 0;
155            try {
156                // 1. copy the instance (keep the class attribute)
157                Instances traindata = ctraindata.get(0);
158                Instance classInstance = createInstance(traindata, instance);
159
160                // 2. remove class attribute before clustering
161                Remove filter = new Remove();
162                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
163                filter.setInputFormat(traindata);
164                traindata = Filter.useFilter(traindata, filter);
165
166                // 3. copy the instance (without the class attribute) for clustering
167                Instance clusterInstance = createInstance(traindata, instance);
168
169                // 4. match instance without class attribute to a cluster number
170                int cnum = clusterer.clusterInstance(clusterInstance);
171
172                // 5. classify instance with class attribute to the classifier of that cluster
173                // number
174                ret = cclassifier.get(cnum).classifyInstance(classInstance);
175
176            }
177            catch (Exception e) {
178                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
179                throw new RuntimeException(e);
180            }
181            return ret;
182        }
183
184        /*
185         * (non-Javadoc)
186         *
187         * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
188         */
189        @Override
190        public void buildClassifier(Instances traindata) throws Exception {
191
192            // 1. copy training data
193            Instances train = new Instances(traindata);
194
195            // 2. remove class attribute for clustering
196            Remove filter = new Remove();
197            filter.setAttributeIndices("" + (train.classIndex() + 1));
198            filter.setInputFormat(train);
199            train = Filter.useFilter(train, filter);
200
201            // new objects
202            cclassifier = new HashMap<Integer, Classifier>();
203            ctraindata = new HashMap<Integer, Instances>();
204
205            Instances ctrain;
206            int maxNumClusters = train.size();
207            boolean sufficientInstancesInEachCluster;
208            do { // while(onlyTarget)
209                sufficientInstancesInEachCluster = true;
210                clusterer = new EM();
211                clusterer.setMaximumNumberOfClusters(maxNumClusters);
212                clusterer.buildClusterer(train);
213
214                // 4. get cluster membership of our traindata
215                // AddCluster cfilter = new AddCluster();
216                // cfilter.setClusterer(clusterer);
217                // cfilter.setInputFormat(train);
218                // Instances ctrain = Filter.useFilter(train, cfilter);
219
220                ctrain = new Instances(train);
221                ctraindata = new HashMap<>();
222
223                // get traindata per cluster
224                for (int j = 0; j < ctrain.numInstances(); j++) {
225                    // get the cluster number from the attributes, subract 1 because if we
226                    // clusterInstance we get 0-n, and this is 1-n
227                    // cnumber =
228                    // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster",
229                    // "")) - 1;
230
231                    int cnumber = clusterer.clusterInstance(ctrain.get(j));
232                    // add training data to list of instances for this cluster number
233                    if (!ctraindata.containsKey(cnumber)) {
234                        ctraindata.put(cnumber, new Instances(traindata));
235                        ctraindata.get(cnumber).delete();
236                    }
237                    ctraindata.get(cnumber).add(traindata.get(j));
238                }
239
240                for (Entry<Integer, Instances> entry : ctraindata.entrySet()) {
241                    Instances instances = entry.getValue();
242                    int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
243                    for (int count : counts) {
244                        sufficientInstancesInEachCluster &= count > 0;
245                    }
246                    sufficientInstancesInEachCluster &= instances.numInstances() >= 5;
247                }
248                maxNumClusters = clusterer.numberOfClusters() - 1;
249            }
250            while (!sufficientInstancesInEachCluster);
251
252            // train one classifier per cluster, we get the cluster number from the training data
253            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
254            while (clusternumber.hasNext()) {
255                int cnumber = clusternumber.next();
256                cclassifier.put(cnumber, setupClassifier());
257                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
258
259                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
260            }
261        }
262    }
263}
Note: See TracBrowser for help on using the repository browser.