source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalEMTraining.java @ 110

Last change on this file since 110 was 99, checked in by sherbold, 9 years ago
  • improved error reporting
  • Property svn:mime-type set to text/plain
File size: 8.4 KB
Line 
// Copyright 2015 Georg-August-Universität Göttingen, Germany
//
//   Licensed under the Apache License, Version 2.0 (the "License");
//   you may not use this file except in compliance with the License.
//   You may obtain a copy of the License at
//
//       http://www.apache.org/licenses/LICENSE-2.0
//
//   Unless required by applicable law or agreed to in writing, software
//   distributed under the License is distributed on an "AS IS" BASIS,
//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//   See the License for the specific language governing permissions and
//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.HashMap;
18import java.util.HashSet;
19import java.util.Iterator;
20import java.util.Map.Entry;
21import java.util.Set;
22import java.util.logging.Level;
23
24import de.ugoe.cs.util.console.Console;
25import weka.classifiers.AbstractClassifier;
26import weka.classifiers.Classifier;
27import weka.clusterers.EM;
28import weka.core.DenseInstance;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.filters.Filter;
32import weka.filters.unsupervised.attribute.Remove;
33
34/**
35 * WekaLocalEMTraining
36 *
37 * Local Trainer with EM Clustering for data partitioning. Currently supports only EM Clustering.
38 *
39 * 1. Cluster training data 2. for each cluster train a classifier with training data from cluster
40 * 3. match test data instance to a cluster, then classify with classifier from the cluster
41 *
42 * XML configuration: <!-- because of clustering --> <preprocessor name="Normalization" param=""/>
43 *
44 * <!-- cluster trainer --> <trainer name="WekaLocalEMTraining"
45 * param="NaiveBayes weka.classifiers.bayes.NaiveBayes" />
46 */
47public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy {
48
49    private final TraindatasetCluster classifier = new TraindatasetCluster();
50
51    @Override
52    public Classifier getClassifier() {
53        return classifier;
54    }
55
56    @Override
57    public void apply(Instances traindata) {
58        try {
59            classifier.buildClassifier(traindata);
60        }
61        catch (Exception e) {
62            throw new RuntimeException(e);
63        }
64    }
65
66    public class TraindatasetCluster extends AbstractClassifier {
67
68        private static final long serialVersionUID = 1L;
69
70        private EM clusterer = null;
71
72        private HashMap<Integer, Classifier> cclassifier;
73        private HashMap<Integer, Instances> ctraindata;
74
75        /**
76         * Helper method that gives us a clean instance copy with the values of the instancelist of
77         * the first parameter.
78         *
79         * @param instancelist
80         *            with attributes
81         * @param instance
82         *            with only values
83         * @return copy of the instance
84         */
85        private Instance createInstance(Instances instances, Instance instance) {
86            // attributes for feeding instance to classifier
87            Set<String> attributeNames = new HashSet<>();
88            for (int j = 0; j < instances.numAttributes(); j++) {
89                attributeNames.add(instances.attribute(j).name());
90            }
91
92            double[] values = new double[instances.numAttributes()];
93            int index = 0;
94            for (int j = 0; j < instance.numAttributes(); j++) {
95                if (attributeNames.contains(instance.attribute(j).name())) {
96                    values[index] = instance.value(j);
97                    index++;
98                }
99            }
100
101            Instances tmp = new Instances(instances);
102            tmp.clear();
103            Instance instCopy = new DenseInstance(instance.weight(), values);
104            instCopy.setDataset(tmp);
105
106            return instCopy;
107        }
108
109        @Override
110        public double classifyInstance(Instance instance) {
111            double ret = 0;
112            try {
113                // 1. copy the instance (keep the class attribute)
114                Instances traindata = ctraindata.get(0);
115                Instance classInstance = createInstance(traindata, instance);
116
117                // 2. remove class attribute before clustering
118                Remove filter = new Remove();
119                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
120                filter.setInputFormat(traindata);
121                traindata = Filter.useFilter(traindata, filter);
122
123                // 3. copy the instance (without the class attribute) for clustering
124                Instance clusterInstance = createInstance(traindata, instance);
125
126                // 4. match instance without class attribute to a cluster number
127                int cnum = clusterer.clusterInstance(clusterInstance);
128
129                // 5. classify instance with class attribute to the classifier of that cluster
130                // number
131                ret = cclassifier.get(cnum).classifyInstance(classInstance);
132
133            }
134            catch (Exception e) {
135                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
136                throw new RuntimeException(e);
137            }
138            return ret;
139        }
140
141        @Override
142        public void buildClassifier(Instances traindata) throws Exception {
143
144            // 1. copy training data
145            Instances train = new Instances(traindata);
146
147            // 2. remove class attribute for clustering
148            Remove filter = new Remove();
149            filter.setAttributeIndices("" + (train.classIndex() + 1));
150            filter.setInputFormat(train);
151            train = Filter.useFilter(train, filter);
152
153            // new objects
154            cclassifier = new HashMap<Integer, Classifier>();
155            ctraindata = new HashMap<Integer, Instances>();
156
157            Instances ctrain;
158            int maxNumClusters = train.size();
159            boolean sufficientInstancesInEachCluster;
160            do { // while(onlyTarget)
161                sufficientInstancesInEachCluster = true;
162                clusterer = new EM();
163                clusterer.setMaximumNumberOfClusters(maxNumClusters);
164                clusterer.buildClusterer(train);
165
166                // 4. get cluster membership of our traindata
167                // AddCluster cfilter = new AddCluster();
168                // cfilter.setClusterer(clusterer);
169                // cfilter.setInputFormat(train);
170                // Instances ctrain = Filter.useFilter(train, cfilter);
171
172                ctrain = new Instances(train);
173                ctraindata = new HashMap<>();
174
175                // get traindata per cluster
176                for (int j = 0; j < ctrain.numInstances(); j++) {
177                    // get the cluster number from the attributes, subract 1 because if we
178                    // clusterInstance we get 0-n, and this is 1-n
179                    // cnumber =
180                    // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster",
181                    // "")) - 1;
182
183                    int cnumber = clusterer.clusterInstance(ctrain.get(j));
184                    // add training data to list of instances for this cluster number
185                    if (!ctraindata.containsKey(cnumber)) {
186                        ctraindata.put(cnumber, new Instances(traindata));
187                        ctraindata.get(cnumber).delete();
188                    }
189                    ctraindata.get(cnumber).add(traindata.get(j));
190                }
191
192                for (Entry<Integer, Instances> entry : ctraindata.entrySet()) {
193                    Instances instances = entry.getValue();
194                    int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts;
195                    for (int count : counts) {
196                        sufficientInstancesInEachCluster &= count > 0;
197                    }
198                    sufficientInstancesInEachCluster &= instances.numInstances() >= 5;
199                }
200                maxNumClusters = clusterer.numberOfClusters() - 1;
201            }
202            while (!sufficientInstancesInEachCluster);
203
204            // train one classifier per cluster, we get the cluster number from the training data
205            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
206            while (clusternumber.hasNext()) {
207                int cnumber = clusternumber.next();
208                cclassifier.put(cnumber, setupClassifier());
209                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
210
211                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
212            }
213        }
214    }
215}
Note: See TracBrowser for help on using the repository browser.