source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/PointWiseEMClusterSelection.java @ 7

Last change on this file since 7 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 4.6 KB
Line 
1package de.ugoe.cs.cpdp.dataselection;
2
3import java.util.LinkedList;
4import java.util.List;
5import java.util.logging.Level;
6
7import org.apache.commons.collections4.list.SetUniqueList;
8
9import weka.clusterers.EM;
10import weka.core.Instances;
11import weka.filters.Filter;
12import weka.filters.unsupervised.attribute.AddCluster;
13import weka.filters.unsupervised.attribute.Remove;
14import de.ugoe.cs.util.console.Console;
15
16
17/**
18 * Use in Config:
19 *
20 * Specify number of clusters
21 * -N = Num Clusters
22 * <pointwiseselector name="PointWiseEMClusterSelection" param="-N 10"/>
23 *
24 * Try to determine the number of clusters:
25 * -I 10 = max iterations
26 * -X 5 = 5 folds for cross evaluation
27 * -max = max number of clusters
28 * <pointwiseselector name="PointWiseEMClusterSelection" param="-I 10 -X 5 -max 300"/>
29 *
30 * Don't forget to add:
31 * <preprocessor name="Normalization" param=""/>
32 */
33public class PointWiseEMClusterSelection implements IPointWiseDataselectionStrategy {
34       
35        private String[] params;
36       
37        @Override
38        public void setParameter(String parameters) {
39                params = parameters.split(" ");
40        }
41
42       
43        /**
44         * 1. Cluster the traindata
45         * 2. for each instance in the testdata find the assigned cluster
46         * 3. select only traindata from the clusters we found in our testdata
47         *
48         * @returns the selected training data
49         */
50        @Override
51        public Instances apply(Instances testdata, Instances traindata) {
52                //final Attribute classAttribute = testdata.classAttribute();
53               
54                final List<Integer> selectedCluster = SetUniqueList.setUniqueList(new LinkedList<Integer>());
55
56                // 1. copy train- and testdata
57                Instances train = new Instances(traindata);
58                Instances test = new Instances(testdata);
59               
60                Instances selected = null;
61               
62                try {
63                        // remove class attribute from traindata
64                        Remove filter = new Remove();
65                        filter.setAttributeIndices("" + (train.classIndex() + 1));
66                        filter.setInputFormat(train);
67                        train = Filter.useFilter(train, filter);
68                       
69                        Console.traceln(Level.INFO, String.format("starting clustering"));
70                       
71                        // 3. cluster data
72                        EM clusterer = new EM();
73                        clusterer.setOptions(params);
74                        clusterer.buildClusterer(train);
75                        int numClusters = clusterer.getNumClusters();
76                        if ( numClusters == -1) {
77                                Console.traceln(Level.INFO, String.format("we have unlimited clusters"));
78                        }else {
79                                Console.traceln(Level.INFO, String.format("we have: "+numClusters+" clusters"));
80                        }
81                       
82                       
83                        // 4. classify testdata, save cluster int
84                       
85                        // remove class attribute from testdata?
86                        Remove filter2 = new Remove();
87                        filter2.setAttributeIndices("" + (test.classIndex() + 1));
88                        filter2.setInputFormat(test);
89                        test = Filter.useFilter(test, filter2);
90                       
91                        int cnum;
92                        for( int i=0; i < test.numInstances(); i++ ) {
93                                cnum = ((EM)clusterer).clusterInstance(test.get(i));
94
95                                // we dont want doubles (maybe use a hashset instead of list?)
96                                if ( !selectedCluster.contains(cnum) ) {
97                                        selectedCluster.add(cnum);
98                                        //Console.traceln(Level.INFO, String.format("assigned to cluster: "+cnum));
99                                }
100                        }
101                       
102                        Console.traceln(Level.INFO, String.format("our testdata is in: "+selectedCluster.size()+" different clusters"));
103                       
104                        // 5. get cluster membership of our traindata
105                        AddCluster cfilter = new AddCluster();
106                        cfilter.setClusterer(clusterer);
107                        cfilter.setInputFormat(train);
108                        Instances ctrain = Filter.useFilter(train, cfilter);
109                       
110                       
111                        // 6. for all traindata get the cluster int, if it is in our list of testdata cluster int add the traindata
112                        // of this cluster to our returned traindata
113                        int cnumber;
114                        selected = new Instances(traindata);
115                        selected.delete();
116                       
117                        for ( int j=0; j < ctrain.numInstances(); j++ ) {
118                                // get the cluster number from the attributes
119                                cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", ""));
120                               
121                                //Console.traceln(Level.INFO, String.format("instance "+j+" is in cluster: "+cnumber));
122                                if ( selectedCluster.contains(cnumber) ) {
123                                        // this only works if the index does not change
124                                        selected.add(traindata.get(j));
125                                        // check for differences, just one attribute, we are pretty sure the index does not change
126                                        if ( traindata.get(j).value(3) != ctrain.get(j).value(3) ) {
127                                                Console.traceln(Level.WARNING, String.format("we have a difference between train an ctrain!"));
128                                        }
129                                }
130                        }
131                       
132                        Console.traceln(Level.INFO, String.format("that leaves us with: "+selected.numInstances()+" traindata instances from "+traindata.numInstances()));
133                }catch( Exception e ) {
134                        Console.traceln(Level.WARNING, String.format("ERROR"));
135                        throw new RuntimeException("error in pointwise em", e);
136                }
137       
138                return selected;
139        }
140
141}
Note: See TracBrowser for help on using the repository browser.