// Copyright 2015 Georg-August-Universität Göttingen, Germany
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package de.ugoe.cs.cpdp.dataselection;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Level;
import org.apache.commons.collections4.list.SetUniqueList;
import weka.clusterers.EM;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.AddCluster;
import weka.filters.unsupervised.attribute.Remove;
import de.ugoe.cs.util.console.Console;
/**
* Use in Config:
*
* Specify the number of clusters: -N = number of clusters
*
* Or try to determine the number of clusters automatically:
* -I 10 = at most 10 iterations,
* -X 5 = 5 folds for cross-validation,
* -max = maximum number of clusters
*
* Don't forget to add:
*/
public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy {
private String[] params;
@Override
public void setParameter(String parameters) {
params = parameters.split(" ");
}
/**
* 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3.
* select only traindata from the clusters we found in our testdata
*
* @returns the selected training data
*/
@Override
public Instances apply(Instances testdata, Instances traindata) {
// final Attribute classAttribute = testdata.classAttribute();
final List selectedCluster =
SetUniqueList.setUniqueList(new LinkedList());
// 1. copy train- and testdata
Instances train = new Instances(traindata);
Instances test = new Instances(testdata);
Instances selected = null;
try {
// remove class attribute from traindata
Remove filter = new Remove();
filter.setAttributeIndices("" + (train.classIndex() + 1));
filter.setInputFormat(train);
train = Filter.useFilter(train, filter);
Console.traceln(Level.INFO, String.format("starting clustering"));
// 3. cluster data
EM clusterer = new EM();
clusterer.setOptions(params);
clusterer.buildClusterer(train);
int numClusters = clusterer.getNumClusters();
if (numClusters == -1) {
Console.traceln(Level.INFO, String.format("we have unlimited clusters"));
}
else {
Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters"));
}
// 4. classify testdata, save cluster int
// remove class attribute from testdata?
Remove filter2 = new Remove();
filter2.setAttributeIndices("" + (test.classIndex() + 1));
filter2.setInputFormat(test);
test = Filter.useFilter(test, filter2);
int cnum;
for (int i = 0; i < test.numInstances(); i++) {
cnum = ((EM) clusterer).clusterInstance(test.get(i));
// we dont want doubles (maybe use a hashset instead of list?)
if (!selectedCluster.contains(cnum)) {
selectedCluster.add(cnum);
// Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum));
}
}
Console.traceln(Level.INFO,
String.format("our testdata is in: " + selectedCluster.size() +
" different clusters"));
// 5. get cluster membership of our traindata
AddCluster cfilter = new AddCluster();
cfilter.setClusterer(clusterer);
cfilter.setInputFormat(train);
Instances ctrain = Filter.useFilter(train, cfilter);
// 6. for all traindata get the cluster int, if it is in our list of testdata cluster
// int add the traindata
// of this cluster to our returned traindata
int cnumber;
selected = new Instances(traindata);
selected.delete();
for (int j = 0; j < ctrain.numInstances(); j++) {
// get the cluster number from the attributes
cnumber =
Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1)
.replace("cluster", ""));
// Console.traceln(Level.INFO,
// String.format("instance "+j+" is in cluster: "+cnumber));
if (selectedCluster.contains(cnumber)) {
// this only works if the index does not change
selected.add(traindata.get(j));
// check for differences, just one attribute, we are pretty sure the index does
// not change
if (traindata.get(j).value(3) != ctrain.get(j).value(3)) {
Console.traceln(Level.WARNING, String
.format("we have a difference between train an ctrain!"));
}
}
}
Console.traceln(Level.INFO,
String.format("that leaves us with: " + selected.numInstances() +
" traindata instances from " + traindata.numInstances()));
}
catch (Exception e) {
Console.traceln(Level.WARNING, String.format("ERROR"));
throw new RuntimeException("error in pointwise em", e);
}
return selected;
}
}