source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/PointWiseEMClusterSelection.java

Last change on this file was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 6.2 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataselection;
16
17import java.util.LinkedList;
18import java.util.List;
19import java.util.logging.Level;
20
21import org.apache.commons.collections4.list.SetUniqueList;
22
23import weka.clusterers.EM;
24import weka.core.Instances;
25import weka.filters.Filter;
26import weka.filters.unsupervised.attribute.AddCluster;
27import weka.filters.unsupervised.attribute.Remove;
28import de.ugoe.cs.util.console.Console;
29
30/**
31 * Use in Config:
32 *
33 * Specify number of clusters -N = Num Clusters
34 * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/>
35 *
36 * Try to determine the number of clusters: -I 10 = max iterations -X 5 = 5 folds for cross
37 * evaluation -max = max number of clusters
38 * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/>
39 *
40 * Don't forget to add: <preprocessor name="Normalization" param=""/>
41 */
42public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy {
43
44    /**
45     * paramters passed to the selection
46     */
47    private String[] params;
48
49    /*
50     * (non-Javadoc)
51     *
52     * @see de.ugoe.cs.cpdp.IParameterizable#setParameter(java.lang.String)
53     */
54    @Override
55    public void setParameter(String parameters) {
56        params = parameters.split(" ");
57    }
58
59    /**
60     * 1. Cluster the traindata 2. for each instance in the testdata find the assigned cluster 3.
61     * select only traindata from the clusters we found in our testdata
62     *
63     * @returns the selected training data
64     */
65    @Override
66    public Instances apply(Instances testdata, Instances traindata) {
67        // final Attribute classAttribute = testdata.classAttribute();
68
69        final List<Integer> selectedCluster =
70            SetUniqueList.setUniqueList(new LinkedList<Integer>());
71
72        // 1. copy train- and testdata
73        Instances train = new Instances(traindata);
74        Instances test = new Instances(testdata);
75
76        Instances selected = null;
77
78        try {
79            // remove class attribute from traindata
80            Remove filter = new Remove();
81            filter.setAttributeIndices("" + (train.classIndex() + 1));
82            filter.setInputFormat(train);
83            train = Filter.useFilter(train, filter);
84
85            Console.traceln(Level.INFO, String.format("starting clustering"));
86
87            // 3. cluster data
88            EM clusterer = new EM();
89            clusterer.setOptions(params);
90            clusterer.buildClusterer(train);
91            int numClusters = clusterer.getNumClusters();
92            if (numClusters == -1) {
93                Console.traceln(Level.INFO, String.format("we have unlimited clusters"));
94            }
95            else {
96                Console.traceln(Level.INFO, String.format("we have: " + numClusters + " clusters"));
97            }
98
99            // 4. classify testdata, save cluster int
100
101            // remove class attribute from testdata?
102            Remove filter2 = new Remove();
103            filter2.setAttributeIndices("" + (test.classIndex() + 1));
104            filter2.setInputFormat(test);
105            test = Filter.useFilter(test, filter2);
106
107            int cnum;
108            for (int i = 0; i < test.numInstances(); i++) {
109                cnum = ((EM) clusterer).clusterInstance(test.get(i));
110
111                // we dont want doubles (maybe use a hashset instead of list?)
112                if (!selectedCluster.contains(cnum)) {
113                    selectedCluster.add(cnum);
114                    // Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum));
115                }
116            }
117
118            Console.traceln(Level.INFO, String
119                .format("our testdata is in: " + selectedCluster.size() + " different clusters"));
120
121            // 5. get cluster membership of our traindata
122            AddCluster cfilter = new AddCluster();
123            cfilter.setClusterer(clusterer);
124            cfilter.setInputFormat(train);
125            Instances ctrain = Filter.useFilter(train, cfilter);
126
127            // 6. for all traindata get the cluster int, if it is in our list of testdata cluster
128            // int add the traindata
129            // of this cluster to our returned traindata
130            int cnumber;
131            selected = new Instances(traindata);
132            selected.delete();
133
134            for (int j = 0; j < ctrain.numInstances(); j++) {
135                // get the cluster number from the attributes
136                cnumber = Integer.parseInt(ctrain.get(j)
137                    .stringValue(ctrain.get(j).numAttributes() - 1).replace("cluster", ""));
138
139                // Console.traceln(Level.INFO,
140                // String.format("instance "+j+" is in cluster: "+cnumber));
141                if (selectedCluster.contains(cnumber)) {
142                    // this only works if the index does not change
143                    selected.add(traindata.get(j));
144                    // check for differences, just one attribute, we are pretty sure the index does
145                    // not change
146                    if (traindata.get(j).value(3) != ctrain.get(j).value(3)) {
147                        Console.traceln(Level.WARNING, String
148                            .format("we have a difference between train an ctrain!"));
149                    }
150                }
151            }
152
153            Console.traceln(Level.INFO, String.format("that leaves us with: " +
154                selected.numInstances() + " traindata instances from " + traindata.numInstances()));
155        }
156        catch (Exception e) {
157            Console.traceln(Level.WARNING, String.format("ERROR"));
158            throw new RuntimeException("error in pointwise em", e);
159        }
160
161        return selected;
162    }
163
164}
Note: See TracBrowser for help on using the repository browser.