source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/SeparatabilitySelection.java @ 19

Last change on this file since 19 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 3.1 KB
Line 
1package de.ugoe.cs.cpdp.dataselection;
2
3import java.util.Arrays;
4import java.util.Random;
5
6import org.apache.commons.collections4.list.SetUniqueList;
7
8import weka.classifiers.Evaluation;
9import weka.classifiers.functions.Logistic;
10import weka.core.DenseInstance;
11import weka.core.Instance;
12import weka.core.Instances;
13
14/**
15 * A setwise data selection strategy based on the separatability of the training data from the test data after Z. He, F. Peters, T. Menzies, Y. Yang: Learning from Open-Source Projects: An Empirical Study on Defect Prediction.
16 * <br><br>
17 * This is calculated through the error of a logistic regression classifier that tries to separate the sets.
18 * @author Steffen Herbold
19 */
20public class SeparatabilitySelection implements ISetWiseDataselectionStrategy {
21
22        /**
23         * size of the random sample that is drawn from both test data and training data
24         */
25        private final int sampleSize = 500;
26       
27        /**
28         * number of repetitions of the sample drawing
29         */
30        private final int maxRep = 10;
31       
32        /**
33         * number of neighbors that are selected
34         */
35        private int neighbors = 10;
36       
37        /**
38         * Sets the number of neighbors that are selected.
39         */
40        @Override
41        public void setParameter(String parameters) {
42                if( !"".equals(parameters) ) {
43                        neighbors = Integer.parseInt(parameters);
44                }
45        }
46
47        /**
48         * @see de.ugoe.cs.cpdp.dataselection.SetWiseDataselectionStrategy#apply(weka.core.Instances, org.apache.commons.collections4.list.SetUniqueList)
49         */
50        @Override
51        public void apply(Instances testdata, SetUniqueList<Instances> traindataSet) {
52                final Random rand = new Random(1);
53               
54                // calculate distances between testdata and traindata
55                final double[] distances = new double[traindataSet.size()];
56               
57                int i=0;
58                for( Instances traindata : traindataSet ) {
59                        double distance = 0.0;
60                        for( int rep=0; rep<maxRep ; rep++ ) {
61                                // sample instances
62                                Instances sample = new Instances(testdata);
63                                for( int j=0; j<sampleSize; j++ ) {
64                                        Instance inst = new DenseInstance(testdata.instance(rand.nextInt(testdata.numInstances())));
65                                        inst.setDataset(sample);
66                                        inst.setClassValue(1.0);
67                                        sample.add(inst);
68                                        inst = new DenseInstance(traindata.instance(rand.nextInt(traindata.numInstances())));
69                                        inst.setDataset(sample);
70                                        inst.setClassValue(0.0);
71                                        sample.add(inst);
72                                }
73                               
74                                // calculate separation
75                                Evaluation eval;
76                                try {
77                                        eval = new Evaluation(sample);
78                                        eval.crossValidateModel(new Logistic(), sample, 5, rand);
79                                } catch (Exception e) {
80                                        throw new RuntimeException("cross-validation during calculation of separatability failed", e);
81                                }
82                                distance += eval.pctCorrect()/100.0;
83                        }
84                        distances[i++] = 2*((distance/maxRep)-0.5);
85                }
86               
87                // select closest neighbors
88                final double[] distancesCopy = Arrays.copyOf(distances, distances.length);
89                Arrays.sort(distancesCopy);
90                final double cutoffDistance = distancesCopy[neighbors];
91               
92                for( i=traindataSet.size()-1; i>=0 ; i-- ) {
93                        if( distances[i]>cutoffDistance ) {
94                                traindataSet.remove(i);
95                        }
96                }
97        }
98}
Note: See TracBrowser for help on using the repository browser.