1 | package de.ugoe.cs.cpdp.dataselection; |
---|
2 | |
---|
3 | import java.util.LinkedList; |
---|
4 | import java.util.List; |
---|
5 | import java.util.logging.Level; |
---|
6 | |
---|
7 | import org.apache.commons.collections4.list.SetUniqueList; |
---|
8 | |
---|
9 | import weka.clusterers.EM; |
---|
10 | import weka.core.Instances; |
---|
11 | import weka.filters.Filter; |
---|
12 | import weka.filters.unsupervised.attribute.AddCluster; |
---|
13 | import weka.filters.unsupervised.attribute.Remove; |
---|
14 | import de.ugoe.cs.util.console.Console; |
---|
15 | |
---|
16 | |
---|
17 | /** |
---|
18 | * Use in Config: |
---|
19 | * |
---|
20 | * Specify number of clusters |
---|
21 | * -N = Num Clusters |
---|
22 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/> |
---|
23 | * |
---|
24 | * Try to determine the number of clusters: |
---|
25 | * -I 10 = max iterations |
---|
26 | * -X 5 = 5 folds for cross evaluation |
---|
27 | * -max = max number of clusters |
---|
28 | * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/> |
---|
29 | * |
---|
30 | * Don't forget to add: |
---|
31 | * <preprocessor name="Normalization" param=""/> |
---|
32 | */ |
---|
33 | public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { |
---|
34 | |
---|
35 | private String[] params; |
---|
36 | |
---|
37 | @Override |
---|
38 | public void setParameter(String parameters) { |
---|
39 | params = parameters.split(" "); |
---|
40 | } |
---|
41 | |
---|
42 | |
---|
43 | /** |
---|
44 | * 1. Cluster the traindata |
---|
45 | * 2. for each instance in the testdata find the assigned cluster |
---|
46 | * 3. select only traindata from the clusters we found in our testdata |
---|
47 | * |
---|
48 | * @returns the selected training data |
---|
49 | */ |
---|
50 | @Override |
---|
51 | public Instances apply(Instances testdata, Instances traindata) { |
---|
52 | //final Attribute classAttribute = testdata.classAttribute(); |
---|
53 | |
---|
54 | final List<Integer> selectedCluster = SetUniqueList.setUniqueList(new LinkedList<Integer>()); |
---|
55 | |
---|
56 | // 1. copy train- and testdata |
---|
57 | Instances train = new Instances(traindata); |
---|
58 | Instances test = new Instances(testdata); |
---|
59 | |
---|
60 | Instances selected = null; |
---|
61 | |
---|
62 | try { |
---|
63 | // remove class attribute from traindata |
---|
64 | Remove filter = new Remove(); |
---|
65 | filter.setAttributeIndices("" + (train.classIndex() + 1)); |
---|
66 | filter.setInputFormat(train); |
---|
67 | train = Filter.useFilter(train, filter); |
---|
68 | |
---|
69 | Console.traceln(Level.INFO, String.format("starting clustering")); |
---|
70 | |
---|
71 | // 3. cluster data |
---|
72 | EM clusterer = new EM(); |
---|
73 | clusterer.setOptions(params); |
---|
74 | clusterer.buildClusterer(train); |
---|
75 | int numClusters = clusterer.getNumClusters(); |
---|
76 | if ( numClusters == -1) { |
---|
77 | Console.traceln(Level.INFO, String.format("we have unlimited clusters")); |
---|
78 | }else { |
---|
79 | Console.traceln(Level.INFO, String.format("we have: "+numClusters+" clusters")); |
---|
80 | } |
---|
81 | |
---|
82 | |
---|
83 | // 4. classify testdata, save cluster int |
---|
84 | |
---|
85 | // remove class attribute from testdata? |
---|
86 | Remove filter2 = new Remove(); |
---|
87 | filter2.setAttributeIndices("" + (test.classIndex() + 1)); |
---|
88 | filter2.setInputFormat(test); |
---|
89 | test = Filter.useFilter(test, filter2); |
---|
90 | |
---|
91 | int cnum; |
---|
92 | for( int i=0; i < test.numInstances(); i++ ) { |
---|
93 | cnum = ((EM)clusterer).clusterInstance(test.get(i)); |
---|
94 | |
---|
95 | // we dont want doubles (maybe use a hashset instead of list?) |
---|
96 | if ( !selectedCluster.contains(cnum) ) { |
---|
97 | selectedCluster.add(cnum); |
---|
98 | //Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); |
---|
99 | } |
---|
100 | } |
---|
101 | |
---|
102 | Console.traceln(Level.INFO, String.format("our testdata is in: "+selectedCluster.size()+" different clusters")); |
---|
103 | |
---|
104 | // 5. get cluster membership of our traindata |
---|
105 | AddCluster cfilter = new AddCluster(); |
---|
106 | cfilter.setClusterer(clusterer); |
---|
107 | cfilter.setInputFormat(train); |
---|
108 | Instances ctrain = Filter.useFilter(train, cfilter); |
---|
109 | |
---|
110 | |
---|
111 | // 6. for all traindata get the cluster int, if it is in our list of testdata cluster int add the traindata |
---|
112 | // of this cluster to our returned traindata |
---|
113 | int cnumber; |
---|
114 | selected = new Instances(traindata); |
---|
115 | selected.delete(); |
---|
116 | |
---|
117 | for ( int j=0; j < ctrain.numInstances(); j++ ) { |
---|
118 | // get the cluster number from the attributes |
---|
119 | cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")); |
---|
120 | |
---|
121 | //Console.traceln(Level.INFO, String.format("instance "+j+" is in cluster: "+cnumber)); |
---|
122 | if ( selectedCluster.contains(cnumber) ) { |
---|
123 | // this only works if the index does not change |
---|
124 | selected.add(traindata.get(j)); |
---|
125 | // check for differences, just one attribute, we are pretty sure the index does not change |
---|
126 | if ( traindata.get(j).value(3) != ctrain.get(j).value(3) ) { |
---|
127 | Console.traceln(Level.WARNING, String.format("we have a difference between train an ctrain!")); |
---|
128 | } |
---|
129 | } |
---|
130 | } |
---|
131 | |
---|
132 | Console.traceln(Level.INFO, String.format("that leaves us with: "+selected.numInstances()+" traindata instances from "+traindata.numInstances())); |
---|
133 | }catch( Exception e ) { |
---|
134 | Console.traceln(Level.WARNING, String.format("ERROR")); |
---|
135 | throw new RuntimeException("error in pointwise em", e); |
---|
136 | } |
---|
137 | |
---|
138 | return selected; |
---|
139 | } |
---|
140 | |
---|
141 | } |
---|