source: trunk/CrossPare/src/de/ugoe/cs/cpdp/eval/AbstractWekaEvaluation.java @ 4

Last change on this file since 4 was 2, checked in by sherbold, 10 years ago
  • initial commit
  • Property svn:mime-type set to text/plain
File size: 9.0 KB
Line 
1package de.ugoe.cs.cpdp.eval;
2
3import java.io.FileNotFoundException;
4import java.io.FileOutputStream;
5import java.io.PrintWriter;
6import java.util.ArrayList;
7import java.util.LinkedList;
8import java.util.List;
9
10import de.ugoe.cs.cpdp.training.ITrainer;
11import de.ugoe.cs.cpdp.training.WekaCompatibleTrainer;
12import de.ugoe.cs.util.StringTools;
13import weka.classifiers.Classifier;
14import weka.classifiers.Evaluation;
15import weka.core.Attribute;
16import weka.core.Instances;
17
18/**
19 * Base class for the evaluation of results of classifiers compatible with the {@link Classifier} interface.
20 * For each classifier, the following metrics are calculated:
21 * <ul>
22 *  <li>Success with recall>0.7, precision>0.5</li>
23 *  <li>Success with recall>0.7, precision>0.5</li>
24 *  <li>Success with gscore>0.75</li>
25 *  <li>Success with gscore>0.6</li>
26 *  <li>error rate</li>
27 *  <li>recall</li>
28 *  <li>precision</li>
29 *  <li>fscore</li>
30 *  <li>gscore</li>
31 *  <li>AUC</li>
32 *  <li>AUCEC (weighted by LOC, if applicable; 0.0 if LOC not available)</li>
33 *  <li>true positive rate</li>
34 *  <li>true negative rate</li>
35 *  <li>true positives</li>
36 *  <li>false positives</li>
37 *  <li>true negatives</li>
38 *  <li>false negatives</li>
39 * </ul>
40 * @author Steffen Herbold
41 */
42public abstract class AbstractWekaEvaluation implements IEvaluationStrategy {
43
44        /**
45         * writer for the evaluation results
46         */
47        private PrintWriter output = new PrintWriter(System.out);
48       
49        private boolean outputIsSystemOut = true;
50       
51        /**
52         * Creates the weka evaluator. Allows the creation of the evaluator in different ways, e.g., for cross-validation
53         * or evaluation on the test data.
54         * @param testdata test data
55         * @param classifier classifier used
56         * @return evaluator
57         */
58        protected abstract Evaluation createEvaluator(Instances testdata, Classifier classifier);
59       
60        /*
61         * (non-Javadoc)
62         * @see de.ugoe.cs.cpdp.eval.EvaluationStrategy#apply(weka.core.Instances, weka.core.Instances, java.util.List, boolean)
63         */
64        @Override
65        public void apply(Instances testdata, Instances traindata, List<ITrainer> trainers,
66                        boolean writeHeader) {
67                final List<Classifier> classifiers = new LinkedList<Classifier>();
68                for( ITrainer trainer : trainers ) {
69                        if( trainer instanceof WekaCompatibleTrainer ) {
70                                classifiers.add(((WekaCompatibleTrainer) trainer).getClassifier());
71                        } else {
72                                throw new RuntimeException("The selected evaluator only support Weka classifiers");
73                        }
74                }
75               
76                if( writeHeader ) {
77                        output.append("version,size_test,size_training");
78                        for( ITrainer trainer : trainers ) {
79                                output.append(",succHe_" + ((WekaCompatibleTrainer) trainer).getName());
80                                output.append(",succZi_" + ((WekaCompatibleTrainer) trainer).getName());
81                                output.append(",succG75_" + ((WekaCompatibleTrainer) trainer).getName());
82                                output.append(",succG60_" + ((WekaCompatibleTrainer) trainer).getName());
83                                output.append(",error_" + ((WekaCompatibleTrainer) trainer).getName());
84                                output.append(",recall_" + ((WekaCompatibleTrainer) trainer).getName());
85                                output.append(",precision_" + ((WekaCompatibleTrainer) trainer).getName());
86                                output.append(",fscore_" + ((WekaCompatibleTrainer) trainer).getName());
87                                output.append(",gscore_" + ((WekaCompatibleTrainer) trainer).getName());
88                                output.append(",mcc_" + ((WekaCompatibleTrainer) trainer).getName());
89                                output.append(",auc_" + ((WekaCompatibleTrainer) trainer).getName());
90                                output.append(",aucec_" + ((WekaCompatibleTrainer) trainer).getName());
91                                output.append(",tpr_" + ((WekaCompatibleTrainer) trainer).getName());
92                                output.append(",tnr_" + ((WekaCompatibleTrainer) trainer).getName());
93                                output.append(",tp_" + ((WekaCompatibleTrainer) trainer).getName());
94                                output.append(",fn_" + ((WekaCompatibleTrainer) trainer).getName());
95                                output.append(",tn_" + ((WekaCompatibleTrainer) trainer).getName());
96                                output.append(",fp_" + ((WekaCompatibleTrainer) trainer).getName());
97                        }
98                        output.append(StringTools.ENDLINE);
99                }
100               
101                output.append(testdata.relationName());
102                output.append("," + testdata.numInstances());   
103                output.append("," + traindata.numInstances());
104               
105                Evaluation eval = null;
106                for( Classifier classifier : classifiers ) {
107                        eval = createEvaluator(testdata, classifier);
108                       
109                        double pf = eval.numFalsePositives(1)/(eval.numFalsePositives(1)+eval.numTrueNegatives(1));
110                        double gmeasure = 2*eval.recall(1)*(1.0-pf)/(eval.recall(1)+(1.0-pf));
111                        double mcc = (eval.numTruePositives(1)*eval.numTrueNegatives(1)-eval.numFalsePositives(1)*eval.numFalseNegatives(1))/Math.sqrt((eval.numTruePositives(1)+eval.numFalsePositives(1))*(eval.numTruePositives(1)+eval.numFalseNegatives(1))*(eval.numTrueNegatives(1)+eval.numFalsePositives(1))*(eval.numTrueNegatives(1)+eval.numFalseNegatives(1)));
112                        double aucec = calculateReviewEffort(testdata, classifier);
113                       
114                        if( eval.recall(1)>=0.7 && eval.precision(1) >= 0.5 ) {
115                                output.append(",1");
116                        } else {
117                                output.append(",0");
118                        }
119                       
120                        if( eval.recall(1)>=0.7 && eval.precision(1) >= 0.7 ) {
121                                output.append(",1");
122                        } else {
123                                output.append(",0");
124                        }
125                       
126                        if( gmeasure>0.75 ) {
127                                output.append(",1");
128                        } else {
129                                output.append(",0");
130                        }
131                       
132                        if( gmeasure>0.6 ) {
133                                output.append(",1");
134                        } else {
135                                output.append(",0");
136                        }
137                       
138                        output.append("," + eval.errorRate());
139                        output.append("," + eval.recall(1));
140                        output.append("," + eval.precision(1));
141                        output.append("," + eval.fMeasure(1));
142                        output.append("," + gmeasure);
143                        output.append("," + mcc);
144                        output.append("," + eval.areaUnderROC(1));
145                        output.append("," + aucec);
146                        output.append("," + eval.truePositiveRate(1));
147                        output.append("," + eval.trueNegativeRate(1));
148                        output.append("," + eval.numTruePositives(1));
149                        output.append("," + eval.numFalseNegatives(1));
150                        output.append("," + eval.numTrueNegatives(1));
151                        output.append("," + eval.numFalsePositives(1));
152                }
153               
154                output.append(StringTools.ENDLINE);
155                output.flush();
156        }
157       
158        private double calculateReviewEffort(Instances testdata, Classifier classifier) {
159               
160                final Attribute loc = testdata.attribute("loc");
161                if( loc==null ) {
162                        return 0.0;
163                }
164                               
165                final List<Integer> bugPredicted = new ArrayList<>();
166                final List<Integer> nobugPredicted = new ArrayList<>();
167                double totalLoc = 0.0d;
168                int totalBugs = 0;
169                for( int i=0 ; i<testdata.numInstances() ; i++ ) {
170                        try {
171                                if( Double.compare(classifier.classifyInstance(testdata.instance(i)),0.0d)==0 ) {
172                                        nobugPredicted.add(i);
173                                } else {
174                                        bugPredicted.add(i);
175                                }
176                        } catch (Exception e) {
177                                throw new RuntimeException("unexpected error during the evaluation of the review effort", e);
178                        }
179                        if(Double.compare(testdata.instance(i).classValue(),1.0d)==0) {
180                                totalBugs++;
181                        }
182                        totalLoc += testdata.instance(i).value(loc);
183                }
184               
185                final List<Double> reviewLoc = new ArrayList<>(testdata.numInstances());
186                final List<Double> bugsFound = new ArrayList<>(testdata.numInstances());
187               
188                double currentBugsFound = 0;
189               
190                while( !bugPredicted.isEmpty() ) {
191                        double minLoc = Double.MAX_VALUE;
192                        int minIndex = -1;
193                        for( int i=0 ; i<bugPredicted.size() ; i++ ) {
194                                double currentLoc = testdata.instance(bugPredicted.get(i)).value(loc);
195                                if( currentLoc<minLoc ) {
196                                        minIndex = i;
197                                        minLoc = currentLoc;
198                                }
199                        }
200                        if( minIndex!=-1 ) {
201                                reviewLoc.add(minLoc/totalLoc);
202                               
203                                currentBugsFound += testdata.instance(bugPredicted.get(minIndex)).classValue();
204                                bugsFound.add(currentBugsFound);
205                               
206                                bugPredicted.remove(minIndex);
207                        } else {
208                                throw new RuntimeException("Shouldn't happen!");
209                        }
210                }
211               
212                while( !nobugPredicted.isEmpty() ) {
213                        double minLoc = Double.MAX_VALUE;
214                        int minIndex = -1;
215                        for( int i=0 ; i<nobugPredicted.size() ; i++ ) {
216                                double currentLoc = testdata.instance(nobugPredicted.get(i)).value(loc);
217                                if( currentLoc<minLoc ) {
218                                        minIndex = i;
219                                        minLoc = currentLoc;
220                                }
221                        }
222                        if( minIndex!=-1 ) {                           
223                                reviewLoc.add(minLoc/totalLoc);
224                               
225                                currentBugsFound += testdata.instance(nobugPredicted.get(minIndex)).classValue();
226                                bugsFound.add(currentBugsFound);
227                                nobugPredicted.remove(minIndex);
228                        } else {
229                                throw new RuntimeException("Shouldn't happen!");
230                        }
231                }
232               
233                double auc = 0.0;
234                for( int i=0 ; i<bugsFound.size() ; i++ ) {
235                        auc += reviewLoc.get(i)*bugsFound.get(i)/totalBugs;
236                }
237               
238                return auc;
239        }
240
241        /*
242         * (non-Javadoc)
243         * @see de.ugoe.cs.cpdp.Parameterizable#setParameter(java.lang.String)
244         */
245        @Override
246        public void setParameter(String parameters) {
247                if( output!=null && !outputIsSystemOut ) {
248                        output.close();
249                }
250                if( "system.out".equals(parameters) || "".equals(parameters) ) {
251                        output = new PrintWriter(System.out);
252                        outputIsSystemOut = true;
253                } else {
254                        try {
255                                output = new PrintWriter(new FileOutputStream(parameters));
256                                outputIsSystemOut = false;
257                        } catch (FileNotFoundException e) {
258                                throw new RuntimeException(e);
259                        }
260                }
261        }
262}
Note: See TracBrowser for help on using the repository browser.