// Copyright 2015 Georg-August-Universität Göttingen, Germany // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package de.ugoe.cs.cpdp.training; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map.Entry; import java.util.Set; import java.util.logging.Level; import de.ugoe.cs.util.console.Console; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.clusterers.EM; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; /** *

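* <p>
* Local trainer that uses EM clustering to partition the data. EM is currently the only
* supported clustering algorithm.
* </p>
* <ol>
* <li>Cluster the training data.</li>
* <li>For each cluster, train a classifier with the training data from that cluster.</li>
* <li>Match each test data instance to a cluster, then classify it with the classifier of that
* cluster.</li>
* </ol>
*
* XML configuration (TODO: the example trainer element is missing from this comment; restore it
* from the project documentation):
*
* <pre>
* {@code
*
* }
* </pre>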
*/ public class WekaLocalEMTraining extends WekaBaseTraining implements ITrainingStrategy { /** * the classifier */ private final TraindatasetCluster classifier = new TraindatasetCluster(); /* * (non-Javadoc) * * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier() */ @Override public Classifier getClassifier() { return classifier; } /* * (non-Javadoc) * * @see de.ugoe.cs.cpdp.training.ITrainingStrategy#apply(weka.core.Instances) */ @Override public void apply(Instances traindata) { try { classifier.buildClassifier(traindata); } catch (Exception e) { throw new RuntimeException(e); } } /** *

* Weka classifier for the local model with EM clustering. *

* * @author Alexander Trautsch */ public class TraindatasetCluster extends AbstractClassifier { /** * default serializtion ID */ private static final long serialVersionUID = 1L; /** * EM clusterer used */ private EM clusterer = null; /** * classifiers for each cluster */ private HashMap cclassifier; /** * training data for each cluster */ private HashMap ctraindata; /** * Helper method that gives us a clean instance copy with the values of the instancelist of * the first parameter. * * @param instancelist * with attributes * @param instance * with only values * @return copy of the instance */ private Instance createInstance(Instances instances, Instance instance) { // attributes for feeding instance to classifier Set attributeNames = new HashSet<>(); for (int j = 0; j < instances.numAttributes(); j++) { attributeNames.add(instances.attribute(j).name()); } double[] values = new double[instances.numAttributes()]; int index = 0; for (int j = 0; j < instance.numAttributes(); j++) { if (attributeNames.contains(instance.attribute(j).name())) { values[index] = instance.value(j); index++; } } Instances tmp = new Instances(instances); tmp.clear(); Instance instCopy = new DenseInstance(instance.weight(), values); instCopy.setDataset(tmp); return instCopy; } /* * (non-Javadoc) * * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance) */ @Override public double classifyInstance(Instance instance) { double ret = 0; try { // 1. copy the instance (keep the class attribute) Instances traindata = ctraindata.get(0); Instance classInstance = createInstance(traindata, instance); // 2. remove class attribute before clustering Remove filter = new Remove(); filter.setAttributeIndices("" + (traindata.classIndex() + 1)); filter.setInputFormat(traindata); traindata = Filter.useFilter(traindata, filter); // 3. copy the instance (without the class attribute) for clustering Instance clusterInstance = createInstance(traindata, instance); // 4. 
match instance without class attribute to a cluster number int cnum = clusterer.clusterInstance(clusterInstance); // 5. classify instance with class attribute to the classifier of that cluster // number ret = cclassifier.get(cnum).classifyInstance(classInstance); } catch (Exception e) { Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!")); throw new RuntimeException(e); } return ret; } /* * (non-Javadoc) * * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances) */ @Override public void buildClassifier(Instances traindata) throws Exception { // 1. copy training data Instances train = new Instances(traindata); // 2. remove class attribute for clustering Remove filter = new Remove(); filter.setAttributeIndices("" + (train.classIndex() + 1)); filter.setInputFormat(train); train = Filter.useFilter(train, filter); // new objects cclassifier = new HashMap(); ctraindata = new HashMap(); Instances ctrain; int maxNumClusters = train.size(); boolean sufficientInstancesInEachCluster; do { // while(onlyTarget) sufficientInstancesInEachCluster = true; clusterer = new EM(); clusterer.setMaximumNumberOfClusters(maxNumClusters); clusterer.buildClusterer(train); // 4. 
get cluster membership of our traindata // AddCluster cfilter = new AddCluster(); // cfilter.setClusterer(clusterer); // cfilter.setInputFormat(train); // Instances ctrain = Filter.useFilter(train, cfilter); ctrain = new Instances(train); ctraindata = new HashMap<>(); // get traindata per cluster for (int j = 0; j < ctrain.numInstances(); j++) { // get the cluster number from the attributes, subract 1 because if we // clusterInstance we get 0-n, and this is 1-n // cnumber = // Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", // "")) - 1; int cnumber = clusterer.clusterInstance(ctrain.get(j)); // add training data to list of instances for this cluster number if (!ctraindata.containsKey(cnumber)) { ctraindata.put(cnumber, new Instances(traindata)); ctraindata.get(cnumber).delete(); } ctraindata.get(cnumber).add(traindata.get(j)); } for (Entry entry : ctraindata.entrySet()) { Instances instances = entry.getValue(); int[] counts = instances.attributeStats(instances.classIndex()).nominalCounts; for (int count : counts) { sufficientInstancesInEachCluster &= count > 0; } sufficientInstancesInEachCluster &= instances.numInstances() >= 5; } maxNumClusters = clusterer.numberOfClusters() - 1; } while (!sufficientInstancesInEachCluster); // train one classifier per cluster, we get the cluster number from the training data Iterator clusternumber = ctraindata.keySet().iterator(); while (clusternumber.hasNext()) { int cnumber = clusternumber.next(); cclassifier.put(cnumber, setupClassifier()); cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); } } } }