source: trunk/CrossPare/src/de/ugoe/cs/cpdp/eval/AbstractWekaEvaluation.java @ 42

Last change on this file since 42 was 41, checked in by sherbold, 9 years ago
  • formatted code and added copyrights
  • Property svn:mime-type set to text/plain
File size: 12.8 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.eval;
16
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;

import de.ugoe.cs.cpdp.training.ITrainer;
import de.ugoe.cs.cpdp.training.IWekaCompatibleTrainer;
import de.ugoe.cs.util.StringTools;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instances;
31
32/**
33 * Base class for the evaluation of results of classifiers compatible with the {@link Classifier}
34 * interface. For each classifier, the following metrics are calculated:
35 * <ul>
36 * <li>succHe: Success with recall>0.7, precision>0.5</li>
37 * <li>succZi: Success with recall>0.7, precision>0.7</li>
38 * <li>succG75: Success with gscore>0.75</li>
39 * <li>succG60: Success with gscore>0.6</li>
40 * <li>error</li>
41 * <li>recall</li>
42 * <li>precision</li>
43 * <li>fscore</li>
44 * <li>gscore</li>
45 * <li>AUC</li>
46 * <li>AUCEC (weighted by LOC, if applicable; 0.0 if LOC not available)</li>
47 * <li>tpr: true positive rate</li>
48 * <li>tnr: true negative rate</li>
49 * <li>tp: true positives</li>
50 * <li>fp: false positives</li>
51 * <li>tn: true negatives</li>
52 * <li>fn: false negatives</li>
53 * <li>errortrain: training error</li>
54 * <li>recalltrain: training recall</li>
55 * <li>precisiontrain: training precision</li>
56 * <li>succHetrain: training success with recall>0.7 and precision>0.5
57 * </ul>
58 *
59 * @author Steffen Herbold
60 */
61public abstract class AbstractWekaEvaluation implements IEvaluationStrategy {
62
63    /**
64     * writer for the evaluation results
65     */
66    private PrintWriter output = new PrintWriter(System.out);
67
68    private boolean outputIsSystemOut = true;
69
70    /**
71     * Creates the weka evaluator. Allows the creation of the evaluator in different ways, e.g., for
72     * cross-validation or evaluation on the test data.
73     *
74     * @param testdata
75     *            test data
76     * @param classifier
77     *            classifier used
78     * @return evaluator
79     */
80    protected abstract Evaluation createEvaluator(Instances testdata, Classifier classifier);
81
82    /*
83     * (non-Javadoc)
84     *
85     * @see de.ugoe.cs.cpdp.eval.EvaluationStrategy#apply(weka.core.Instances, weka.core.Instances,
86     * java.util.List, boolean)
87     */
88    @Override
89    public void apply(Instances testdata,
90                      Instances traindata,
91                      List<ITrainer> trainers,
92                      boolean writeHeader)
93    {
94        final List<Classifier> classifiers = new LinkedList<Classifier>();
95        for (ITrainer trainer : trainers) {
96            if (trainer instanceof IWekaCompatibleTrainer) {
97                classifiers.add(((IWekaCompatibleTrainer) trainer).getClassifier());
98            }
99            else {
100                throw new RuntimeException("The selected evaluator only support Weka classifiers");
101            }
102        }
103
104        if (writeHeader) {
105            output.append("version,size_test,size_training");
106            for (ITrainer trainer : trainers) {
107                output.append(",succHe_" + ((IWekaCompatibleTrainer) trainer).getName());
108                output.append(",succZi_" + ((IWekaCompatibleTrainer) trainer).getName());
109                output.append(",succG75_" + ((IWekaCompatibleTrainer) trainer).getName());
110                output.append(",succG60_" + ((IWekaCompatibleTrainer) trainer).getName());
111                output.append(",error_" + ((IWekaCompatibleTrainer) trainer).getName());
112                output.append(",recall_" + ((IWekaCompatibleTrainer) trainer).getName());
113                output.append(",precision_" + ((IWekaCompatibleTrainer) trainer).getName());
114                output.append(",fscore_" + ((IWekaCompatibleTrainer) trainer).getName());
115                output.append(",gscore_" + ((IWekaCompatibleTrainer) trainer).getName());
116                output.append(",mcc_" + ((IWekaCompatibleTrainer) trainer).getName());
117                output.append(",auc_" + ((IWekaCompatibleTrainer) trainer).getName());
118                output.append(",aucec_" + ((IWekaCompatibleTrainer) trainer).getName());
119                output.append(",tpr_" + ((IWekaCompatibleTrainer) trainer).getName());
120                output.append(",tnr_" + ((IWekaCompatibleTrainer) trainer).getName());
121                output.append(",tp_" + ((IWekaCompatibleTrainer) trainer).getName());
122                output.append(",fn_" + ((IWekaCompatibleTrainer) trainer).getName());
123                output.append(",tn_" + ((IWekaCompatibleTrainer) trainer).getName());
124                output.append(",fp_" + ((IWekaCompatibleTrainer) trainer).getName());
125                output.append(",trainerror_" + ((IWekaCompatibleTrainer) trainer).getName());
126                output.append(",trainrecall_" + ((IWekaCompatibleTrainer) trainer).getName());
127                output.append(",trainprecision_" + ((IWekaCompatibleTrainer) trainer).getName());
128                output.append(",trainsuccHe_" + ((IWekaCompatibleTrainer) trainer).getName());
129            }
130            output.append(StringTools.ENDLINE);
131        }
132
133        output.append(testdata.relationName());
134        output.append("," + testdata.numInstances());
135        output.append("," + traindata.numInstances());
136
137        Evaluation eval = null;
138        Evaluation evalTrain = null;
139        for (Classifier classifier : classifiers) {
140            eval = createEvaluator(testdata, classifier);
141            evalTrain = createEvaluator(traindata, classifier);
142
143            double pf =
144                eval.numFalsePositives(1) / (eval.numFalsePositives(1) + eval.numTrueNegatives(1));
145            double gmeasure = 2 * eval.recall(1) * (1.0 - pf) / (eval.recall(1) + (1.0 - pf));
146            double mcc =
147                (eval.numTruePositives(1) * eval.numTrueNegatives(1) - eval.numFalsePositives(1) *
148                    eval.numFalseNegatives(1)) /
149                    Math.sqrt((eval.numTruePositives(1) + eval.numFalsePositives(1)) *
150                        (eval.numTruePositives(1) + eval.numFalseNegatives(1)) *
151                        (eval.numTrueNegatives(1) + eval.numFalsePositives(1)) *
152                        (eval.numTrueNegatives(1) + eval.numFalseNegatives(1)));
153            double aucec = calculateReviewEffort(testdata, classifier);
154
155            if (eval.recall(1) >= 0.7 && eval.precision(1) >= 0.5) {
156                output.append(",1");
157            }
158            else {
159                output.append(",0");
160            }
161
162            if (eval.recall(1) >= 0.7 && eval.precision(1) >= 0.7) {
163                output.append(",1");
164            }
165            else {
166                output.append(",0");
167            }
168
169            if (gmeasure > 0.75) {
170                output.append(",1");
171            }
172            else {
173                output.append(",0");
174            }
175
176            if (gmeasure > 0.6) {
177                output.append(",1");
178            }
179            else {
180                output.append(",0");
181            }
182
183            output.append("," + eval.errorRate());
184            output.append("," + eval.recall(1));
185            output.append("," + eval.precision(1));
186            output.append("," + eval.fMeasure(1));
187            output.append("," + gmeasure);
188            output.append("," + mcc);
189            output.append("," + eval.areaUnderROC(1));
190            output.append("," + aucec);
191            output.append("," + eval.truePositiveRate(1));
192            output.append("," + eval.trueNegativeRate(1));
193            output.append("," + eval.numTruePositives(1));
194            output.append("," + eval.numFalseNegatives(1));
195            output.append("," + eval.numTrueNegatives(1));
196            output.append("," + eval.numFalsePositives(1));
197            output.append("," + evalTrain.errorRate());
198            output.append("," + evalTrain.recall(1));
199            output.append("," + evalTrain.precision(1));
200            if (evalTrain.recall(1) >= 0.7 && evalTrain.precision(1) >= 0.5) {
201                output.append(",1");
202            }
203            else {
204                output.append(",0");
205            }
206        }
207
208        output.append(StringTools.ENDLINE);
209        output.flush();
210    }
211
212    private double calculateReviewEffort(Instances testdata, Classifier classifier) {
213
214        final Attribute loc = testdata.attribute("loc");
215        if (loc == null) {
216            return 0.0;
217        }
218
219        final List<Integer> bugPredicted = new ArrayList<>();
220        final List<Integer> nobugPredicted = new ArrayList<>();
221        double totalLoc = 0.0d;
222        int totalBugs = 0;
223        for (int i = 0; i < testdata.numInstances(); i++) {
224            try {
225                if (Double.compare(classifier.classifyInstance(testdata.instance(i)), 0.0d) == 0) {
226                    nobugPredicted.add(i);
227                }
228                else {
229                    bugPredicted.add(i);
230                }
231            }
232            catch (Exception e) {
233                throw new RuntimeException(
234                                           "unexpected error during the evaluation of the review effort",
235                                           e);
236            }
237            if (Double.compare(testdata.instance(i).classValue(), 1.0d) == 0) {
238                totalBugs++;
239            }
240            totalLoc += testdata.instance(i).value(loc);
241        }
242
243        final List<Double> reviewLoc = new ArrayList<>(testdata.numInstances());
244        final List<Double> bugsFound = new ArrayList<>(testdata.numInstances());
245
246        double currentBugsFound = 0;
247
248        while (!bugPredicted.isEmpty()) {
249            double minLoc = Double.MAX_VALUE;
250            int minIndex = -1;
251            for (int i = 0; i < bugPredicted.size(); i++) {
252                double currentLoc = testdata.instance(bugPredicted.get(i)).value(loc);
253                if (currentLoc < minLoc) {
254                    minIndex = i;
255                    minLoc = currentLoc;
256                }
257            }
258            if (minIndex != -1) {
259                reviewLoc.add(minLoc / totalLoc);
260
261                currentBugsFound += testdata.instance(bugPredicted.get(minIndex)).classValue();
262                bugsFound.add(currentBugsFound);
263
264                bugPredicted.remove(minIndex);
265            }
266            else {
267                throw new RuntimeException("Shouldn't happen!");
268            }
269        }
270
271        while (!nobugPredicted.isEmpty()) {
272            double minLoc = Double.MAX_VALUE;
273            int minIndex = -1;
274            for (int i = 0; i < nobugPredicted.size(); i++) {
275                double currentLoc = testdata.instance(nobugPredicted.get(i)).value(loc);
276                if (currentLoc < minLoc) {
277                    minIndex = i;
278                    minLoc = currentLoc;
279                }
280            }
281            if (minIndex != -1) {
282                reviewLoc.add(minLoc / totalLoc);
283
284                currentBugsFound += testdata.instance(nobugPredicted.get(minIndex)).classValue();
285                bugsFound.add(currentBugsFound);
286                nobugPredicted.remove(minIndex);
287            }
288            else {
289                throw new RuntimeException("Shouldn't happen!");
290            }
291        }
292
293        double auc = 0.0;
294        for (int i = 0; i < bugsFound.size(); i++) {
295            auc += reviewLoc.get(i) * bugsFound.get(i) / totalBugs;
296        }
297
298        return auc;
299    }
300
301    /*
302     * (non-Javadoc)
303     *
304     * @see de.ugoe.cs.cpdp.Parameterizable#setParameter(java.lang.String)
305     */
306    @Override
307    public void setParameter(String parameters) {
308        if (output != null && !outputIsSystemOut) {
309            output.close();
310        }
311        if ("system.out".equals(parameters) || "".equals(parameters)) {
312            output = new PrintWriter(System.out);
313            outputIsSystemOut = true;
314        }
315        else {
316            try {
317                output = new PrintWriter(new FileOutputStream(parameters));
318                outputIsSystemOut = false;
319            }
320            catch (FileNotFoundException e) {
321                throw new RuntimeException(e);
322            }
323        }
324    }
325}
Note: See TracBrowser for help on using the repository browser.