1 | // Copyright 2015 Georg-August-Universität Göttingen, Germany |
---|
2 | // |
---|
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
---|
4 | // you may not use this file except in compliance with the License. |
---|
5 | // You may obtain a copy of the License at |
---|
6 | // |
---|
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
---|
8 | // |
---|
9 | // Unless required by applicable law or agreed to in writing, software |
---|
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
---|
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
---|
12 | // See the License for the specific language governing permissions and |
---|
13 | // limitations under the License. |
---|
14 | |
---|
15 | package de.ugoe.cs.cpdp.dataselection; |
---|
16 | |
---|
17 | import java.util.LinkedList; |
---|
18 | import java.util.List; |
---|
19 | import java.util.logging.Level; |
---|
20 | |
---|
21 | import org.apache.commons.collections4.list.SetUniqueList; |
---|
22 | |
---|
23 | import weka.clusterers.EM; |
---|
24 | import weka.core.Instances; |
---|
25 | import weka.filters.Filter; |
---|
26 | import weka.filters.unsupervised.attribute.AddCluster; |
---|
27 | import weka.filters.unsupervised.attribute.Remove; |
---|
28 | import de.ugoe.cs.util.console.Console; |
---|
29 | |
---|
30 | /** |
---|
31 | * Use in Config: |
---|
32 | * |
---|
33 | * Specify number of clusters -N = Num Clusters <pointwiseselector |
---|
34 | * name="PointWiseEMClusterSelection" param="-N 10"/> |
---|
35 | * |
---|
36 | * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross |
---|
37 | * evaluation -max = max number of clusters <pointwiseselector name="PointWiseEMClusterSelection" |
---|
38 | * param="-I 10 -X 5 -max 300"/> |
---|
39 | * |
---|
40 | * Don't forget to add: <preprocessor name="Normalization" param=""/> |
---|
41 | */ |
---|
42 | public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy { |
---|
43 | |
---|
44 | private String[] params; |
---|
45 | |
---|
46 | @Override |
---|
47 | public void setParameter(String parameters) { |
---|
48 | params = parameters.split(" "); |
---|
49 | } |
---|
50 | |
---|
51 | /** |
---|
52 | * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3. |
---|
53 | * select only traindata from the clusters we found in our testdata |
---|
54 | * |
---|
55 | * @returns the selected training data |
---|
56 | */ |
---|
57 | @Override |
---|
58 | public Instances apply(Instances testdata, Instances traindata) { |
---|
59 | // final Attribute classAttribute = testdata.classAttribute(); |
---|
60 | |
---|
61 | final List<Integer> selectedCluster = |
---|
62 | SetUniqueList.setUniqueList(new LinkedList<Integer>()); |
---|
63 | |
---|
64 | // 1. copy train- and testdata |
---|
65 | Instances train = new Instances(traindata); |
---|
66 | Instances test = new Instances(testdata); |
---|
67 | |
---|
68 | Instances selected = null; |
---|
69 | |
---|
70 | try { |
---|
71 | // remove class attribute from traindata |
---|
72 | Remove filter = new Remove(); |
---|
73 | filter.setAttributeIndices("" + (train.classIndex() + 1)); |
---|
74 | filter.setInputFormat(train); |
---|
75 | train = Filter.useFilter(train, filter); |
---|
76 | |
---|
77 | Console.traceln(Level.INFO, String.format("starting clustering")); |
---|
78 | |
---|
79 | // 3. cluster data |
---|
80 | EM clusterer = new EM(); |
---|
81 | clusterer.setOptions(params); |
---|
82 | clusterer.buildClusterer(train); |
---|
83 | int numClusters = clusterer.getNumClusters(); |
---|
84 | if (numClusters == -1) { |
---|
85 | Console.traceln(Level.INFO, String.format("we have unlimited clusters")); |
---|
86 | } |
---|
87 | else { |
---|
88 | Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters")); |
---|
89 | } |
---|
90 | |
---|
91 | // 4. classify testdata, save cluster int |
---|
92 | |
---|
93 | // remove class attribute from testdata? |
---|
94 | Remove filter2 = new Remove(); |
---|
95 | filter2.setAttributeIndices("" + (test.classIndex() + 1)); |
---|
96 | filter2.setInputFormat(test); |
---|
97 | test = Filter.useFilter(test, filter2); |
---|
98 | |
---|
99 | int cnum; |
---|
100 | for (int i = 0; i < test.numInstances(); i++) { |
---|
101 | cnum = ((EM) clusterer).clusterInstance(test.get(i)); |
---|
102 | |
---|
103 | // we dont want doubles (maybe use a hashset instead of list?) |
---|
104 | if (!selectedCluster.contains(cnum)) { |
---|
105 | selectedCluster.add(cnum); |
---|
106 | // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum)); |
---|
107 | } |
---|
108 | } |
---|
109 | |
---|
110 | Console.traceln(Level.INFO, |
---|
111 | String.format("our testdata is in: " + selectedCluster.size() + |
---|
112 | " different clusters")); |
---|
113 | |
---|
114 | // 5. get cluster membership of our traindata |
---|
115 | AddCluster cfilter = new AddCluster(); |
---|
116 | cfilter.setClusterer(clusterer); |
---|
117 | cfilter.setInputFormat(train); |
---|
118 | Instances ctrain = Filter.useFilter(train, cfilter); |
---|
119 | |
---|
120 | // 6. for all traindata get the cluster int, if it is in our list of testdata cluster |
---|
121 | // int add the traindata |
---|
122 | // of this cluster to our returned traindata |
---|
123 | int cnumber; |
---|
124 | selected = new Instances(traindata); |
---|
125 | selected.delete(); |
---|
126 | |
---|
127 | for (int j = 0; j < ctrain.numInstances(); j++) { |
---|
128 | // get the cluster number from the attributes |
---|
129 | cnumber = |
---|
130 | Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1) |
---|
131 | .replace("cluster", "")); |
---|
132 | |
---|
133 | // Console.traceln(Level.INFO, |
---|
134 | // String.format("instance "+j+" is in cluster: "+cnumber)); |
---|
135 | if (selectedCluster.contains(cnumber)) { |
---|
136 | // this only works if the index does not change |
---|
137 | selected.add(traindata.get(j)); |
---|
138 | // check for differences, just one attribute, we are pretty sure the index does |
---|
139 | // not change |
---|
140 | if (traindata.get(j).value(3) != ctrain.get(j).value(3)) { |
---|
141 | Console.traceln(Level.WARNING, String |
---|
142 | .format("we have a difference between train an ctrain!")); |
---|
143 | } |
---|
144 | } |
---|
145 | } |
---|
146 | |
---|
147 | Console.traceln(Level.INFO, |
---|
148 | String.format("that leaves us with: " + selected.numInstances() + |
---|
149 | " traindata instances from " + traindata.numInstances())); |
---|
150 | } |
---|
151 | catch (Exception e) { |
---|
152 | Console.traceln(Level.WARNING, String.format("ERROR")); |
---|
153 | throw new RuntimeException("error in pointwise em", e); |
---|
154 | } |
---|
155 | |
---|
156 | return selected; |
---|
157 | } |
---|
158 | |
---|
159 | } |
---|