source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/PointWiseEMClusterSelection.java @ 50

Last change on this file since 50 was 41, checked in by sherbold, 9 years ago
  • formatted code and added copyrights
  • Property svn:mime-type set to text/plain
File size: 6.1 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataselection;
16
17import java.util.LinkedList;
18import java.util.List;
19import java.util.logging.Level;
20
21import org.apache.commons.collections4.list.SetUniqueList;
22
23import weka.clusterers.EM;
24import weka.core.Instances;
25import weka.filters.Filter;
26import weka.filters.unsupervised.attribute.AddCluster;
27import weka.filters.unsupervised.attribute.Remove;
28import de.ugoe.cs.util.console.Console;
29
30/**
31 * Use in Config:
32 *
33 * Specify number of clusters -N = Num Clusters <pointwiseselector
34 * name="PointWiseEMClusterSelection" param="-N 10"/>
35 *
36 * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross
37 * evaluation -max = max number of clusters <pointwiseselector name="PointWiseEMClusterSelection"
38 * param="-I 10 -X 5 -max 300"/>
39 *
40 * Don't forget to add: <preprocessor name="Normalization" param=""/>
41 */
42public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy {
43
44    private String[] params;
45
46    @Override
47    public void setParameter(String parameters) {
48        params = parameters.split(" ");
49    }
50
51    /**
52     * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3.
53     * select only traindata from the clusters we found in our testdata
54     *
55     * @returns the selected training data
56     */
57    @Override
58    public Instances apply(Instances testdata, Instances traindata) {
59        // final Attribute classAttribute = testdata.classAttribute();
60
61        final List<Integer> selectedCluster =
62            SetUniqueList.setUniqueList(new LinkedList<Integer>());
63
64        // 1. copy train- and testdata
65        Instances train = new Instances(traindata);
66        Instances test = new Instances(testdata);
67
68        Instances selected = null;
69
70        try {
71            // remove class attribute from traindata
72            Remove filter = new Remove();
73            filter.setAttributeIndices("" + (train.classIndex() + 1));
74            filter.setInputFormat(train);
75            train = Filter.useFilter(train, filter);
76
77            Console.traceln(Level.INFO, String.format("starting clustering"));
78
79            // 3. cluster data
80            EM clusterer = new EM();
81            clusterer.setOptions(params);
82            clusterer.buildClusterer(train);
83            int numClusters = clusterer.getNumClusters();
84            if (numClusters == -1) {
85                Console.traceln(Level.INFO, String.format("we have unlimited clusters"));
86            }
87            else {
88                Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters"));
89            }
90
91            // 4. classify testdata, save cluster int
92
93            // remove class attribute from testdata?
94            Remove filter2 = new Remove();
95            filter2.setAttributeIndices("" + (test.classIndex() + 1));
96            filter2.setInputFormat(test);
97            test = Filter.useFilter(test, filter2);
98
99            int cnum;
100            for (int i = 0; i < test.numInstances(); i++) {
101                cnum = ((EM) clusterer).clusterInstance(test.get(i));
102
103                // we dont want doubles (maybe use a hashset instead of list?)
104                if (!selectedCluster.contains(cnum)) {
105                    selectedCluster.add(cnum);
106                    // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum));
107                }
108            }
109
110            Console.traceln(Level.INFO,
111                            String.format("our testdata is in: " + selectedCluster.size() +
112                                " different clusters"));
113
114            // 5. get cluster membership of our traindata
115            AddCluster cfilter = new AddCluster();
116            cfilter.setClusterer(clusterer);
117            cfilter.setInputFormat(train);
118            Instances ctrain = Filter.useFilter(train, cfilter);
119
120            // 6. for all traindata get the cluster int, if it is in our list of testdata cluster
121            // int add the traindata
122            // of this cluster to our returned traindata
123            int cnumber;
124            selected = new Instances(traindata);
125            selected.delete();
126
127            for (int j = 0; j < ctrain.numInstances(); j++) {
128                // get the cluster number from the attributes
129                cnumber =
130                    Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes() - 1)
131                        .replace("cluster", ""));
132
133                // Console.traceln(Level.INFO,
134                // String.format("instance "+j+" is in cluster: "+cnumber));
135                if (selectedCluster.contains(cnumber)) {
136                    // this only works if the index does not change
137                    selected.add(traindata.get(j));
138                    // check for differences, just one attribute, we are pretty sure the index does
139                    // not change
140                    if (traindata.get(j).value(3) != ctrain.get(j).value(3)) {
141                        Console.traceln(Level.WARNING, String
142                            .format("we have a difference between train an ctrain!"));
143                    }
144                }
145            }
146
147            Console.traceln(Level.INFO,
148                            String.format("that leaves us with: " + selected.numInstances() +
149                                " traindata instances from " + traindata.numInstances()));
150        }
151        catch (Exception e) {
152            Console.traceln(Level.WARNING, String.format("ERROR"));
153            throw new RuntimeException("error in pointwise em", e);
154        }
155
156        return selected;
157    }
158
159}
Note: See TracBrowser for help on using the repository browser.