source: trunk/CrossPare/src/de/ugoe/cs/cpdp/eval/AbstractWekaEvaluation.java @ 64

Last change on this file since 64 was 63, checked in by sherbold, 9 years ago
  • added fpr and fnr as metrics; using MCC now directly from weka
  • Property svn:mime-type set to text/plain
File size: 11.9 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.eval;
16
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;

import de.ugoe.cs.cpdp.training.ITrainer;
import de.ugoe.cs.cpdp.training.IWekaCompatibleTrainer;
import de.ugoe.cs.util.StringTools;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instances;
31
32/**
33 * Base class for the evaluation of results of classifiers compatible with the {@link Classifier}
34 * interface. For each classifier, the following metrics are calculated:
35 * <ul>
36 * <li>succHe: Success with recall>0.7, precision>0.5</li>
37 * <li>succZi: Success with recall>0.7, precision>0.7</li>
38 * <li>succG75: Success with gscore>0.75</li>
39 * <li>succG60: Success with gscore>0.6</li>
40 * <li>error</li>
41 * <li>recall</li>
42 * <li>precision</li>
43 * <li>fscore</li>
44 * <li>gscore</li>
45 * <li>AUC</li>
46 * <li>AUCEC (weighted by LOC, if applicable; 0.0 if LOC not available)</li>
47 * <li>tpr: true positive rate</li>
48 * <li>tnr: true negative rate</li>
49 * <li>tp: true positives</li>
50 * <li>fp: false positives</li>
51 * <li>tn: true negatives</li>
52 * <li>fn: false negatives</li>
53 * <li>errortrain: training error</li>
54 * <li>recalltrain: training recall</li>
55 * <li>precisiontrain: training precision</li>
56 * <li>succHetrain: training success with recall>0.7 and precision>0.5
57 * </ul>
58 *
59 * @author Steffen Herbold
60 */
61public abstract class AbstractWekaEvaluation implements IEvaluationStrategy {
62
63    /**
64     * writer for the evaluation results
65     */
66    private PrintWriter output = new PrintWriter(System.out);
67
68    private boolean outputIsSystemOut = true;
69
70    /**
71     * Creates the weka evaluator. Allows the creation of the evaluator in different ways, e.g., for
72     * cross-validation or evaluation on the test data.
73     *
74     * @param testdata
75     *            test data
76     * @param classifier
77     *            classifier used
78     * @return evaluator
79     */
80    protected abstract Evaluation createEvaluator(Instances testdata, Classifier classifier);
81
82    /*
83     * (non-Javadoc)
84     *
85     * @see de.ugoe.cs.cpdp.eval.EvaluationStrategy#apply(weka.core.Instances, weka.core.Instances,
86     * java.util.List, boolean)
87     */
88    @Override
89    public void apply(Instances testdata,
90                      Instances traindata,
91                      List<ITrainer> trainers,
92                      boolean writeHeader)
93    {
94        final List<Classifier> classifiers = new LinkedList<Classifier>();
95        for (ITrainer trainer : trainers) {
96            if (trainer instanceof IWekaCompatibleTrainer) {
97                classifiers.add(((IWekaCompatibleTrainer) trainer).getClassifier());
98            }
99            else {
100                throw new RuntimeException("The selected evaluator only support Weka classifiers");
101            }
102        }
103
104        if (writeHeader) {
105            output.append("version,size_test,size_training");
106            for (ITrainer trainer : trainers) {
107                output.append(",succHe_" + ((IWekaCompatibleTrainer) trainer).getName());
108                output.append(",succZi_" + ((IWekaCompatibleTrainer) trainer).getName());
109                output.append(",succG75_" + ((IWekaCompatibleTrainer) trainer).getName());
110                output.append(",succG60_" + ((IWekaCompatibleTrainer) trainer).getName());
111                output.append(",error_" + ((IWekaCompatibleTrainer) trainer).getName());
112                output.append(",recall_" + ((IWekaCompatibleTrainer) trainer).getName());
113                output.append(",precision_" + ((IWekaCompatibleTrainer) trainer).getName());
114                output.append(",fscore_" + ((IWekaCompatibleTrainer) trainer).getName());
115                output.append(",gscore_" + ((IWekaCompatibleTrainer) trainer).getName());
116                output.append(",mcc_" + ((IWekaCompatibleTrainer) trainer).getName());
117                output.append(",auc_" + ((IWekaCompatibleTrainer) trainer).getName());
118                output.append(",aucec_" + ((IWekaCompatibleTrainer) trainer).getName());
119                output.append(",tpr_" + ((IWekaCompatibleTrainer) trainer).getName());
120                output.append(",tnr_" + ((IWekaCompatibleTrainer) trainer).getName());
121                output.append(",fpr_" + ((IWekaCompatibleTrainer) trainer).getName());
122                output.append(",fnr_" + ((IWekaCompatibleTrainer) trainer).getName());
123                output.append(",tp_" + ((IWekaCompatibleTrainer) trainer).getName());
124                output.append(",fn_" + ((IWekaCompatibleTrainer) trainer).getName());
125                output.append(",tn_" + ((IWekaCompatibleTrainer) trainer).getName());
126                output.append(",fp_" + ((IWekaCompatibleTrainer) trainer).getName());
127            }
128            output.append(StringTools.ENDLINE);
129        }
130
131        output.append(testdata.relationName());
132        output.append("," + testdata.numInstances());
133        output.append("," + traindata.numInstances());
134
135        Evaluation eval = null;
136        //Evaluation evalTrain = null;
137        for (Classifier classifier : classifiers) {
138            eval = createEvaluator(testdata, classifier);
139            //evalTrain = createEvaluator(traindata, classifier);
140
141            double pf =
142                eval.numFalsePositives(1) / (eval.numFalsePositives(1) + eval.numTrueNegatives(1));
143            double gmeasure = 2 * eval.recall(1) * (1.0 - pf) / (eval.recall(1) + (1.0 - pf));
144            double aucec = calculateReviewEffort(testdata, classifier);
145
146            if (eval.recall(1) >= 0.7 && eval.precision(1) >= 0.5) {
147                output.append(",1");
148            }
149            else {
150                output.append(",0");
151            }
152
153            if (eval.recall(1) >= 0.7 && eval.precision(1) >= 0.7) {
154                output.append(",1");
155            }
156            else {
157                output.append(",0");
158            }
159
160            if (gmeasure > 0.75) {
161                output.append(",1");
162            }
163            else {
164                output.append(",0");
165            }
166
167            if (gmeasure > 0.6) {
168                output.append(",1");
169            }
170            else {
171                output.append(",0");
172            }
173
174            output.append("," + eval.errorRate());
175            output.append("," + eval.recall(1));
176            output.append("," + eval.precision(1));
177            output.append("," + eval.fMeasure(1));
178            output.append("," + gmeasure);
179            output.append("," + eval.matthewsCorrelationCoefficient(1));
180            output.append("," + eval.areaUnderROC(1));
181            output.append("," + aucec);
182            output.append("," + eval.truePositiveRate(1));
183            output.append("," + eval.trueNegativeRate(1));
184            output.append("," + eval.falsePositiveRate(1));
185            output.append("," + eval.falseNegativeRate(1));
186            output.append("," + eval.numTruePositives(1));
187            output.append("," + eval.numFalseNegatives(1));
188            output.append("," + eval.numTrueNegatives(1));
189            output.append("," + eval.numFalsePositives(1));
190        }
191
192        output.append(StringTools.ENDLINE);
193        output.flush();
194    }
195
196    private double calculateReviewEffort(Instances testdata, Classifier classifier) {
197
198        final Attribute loc = testdata.attribute("loc");
199        if (loc == null) {
200            return 0.0;
201        }
202
203        final List<Integer> bugPredicted = new ArrayList<>();
204        final List<Integer> nobugPredicted = new ArrayList<>();
205        double totalLoc = 0.0d;
206        int totalBugs = 0;
207        for (int i = 0; i < testdata.numInstances(); i++) {
208            try {
209                if (Double.compare(classifier.classifyInstance(testdata.instance(i)), 0.0d) == 0) {
210                    nobugPredicted.add(i);
211                }
212                else {
213                    bugPredicted.add(i);
214                }
215            }
216            catch (Exception e) {
217                throw new RuntimeException(
218                                           "unexpected error during the evaluation of the review effort",
219                                           e);
220            }
221            if (Double.compare(testdata.instance(i).classValue(), 1.0d) == 0) {
222                totalBugs++;
223            }
224            totalLoc += testdata.instance(i).value(loc);
225        }
226
227        final List<Double> reviewLoc = new ArrayList<>(testdata.numInstances());
228        final List<Double> bugsFound = new ArrayList<>(testdata.numInstances());
229
230        double currentBugsFound = 0;
231
232        while (!bugPredicted.isEmpty()) {
233            double minLoc = Double.MAX_VALUE;
234            int minIndex = -1;
235            for (int i = 0; i < bugPredicted.size(); i++) {
236                double currentLoc = testdata.instance(bugPredicted.get(i)).value(loc);
237                if (currentLoc < minLoc) {
238                    minIndex = i;
239                    minLoc = currentLoc;
240                }
241            }
242            if (minIndex != -1) {
243                reviewLoc.add(minLoc / totalLoc);
244
245                currentBugsFound += testdata.instance(bugPredicted.get(minIndex)).classValue();
246                bugsFound.add(currentBugsFound);
247
248                bugPredicted.remove(minIndex);
249            }
250            else {
251                throw new RuntimeException("Shouldn't happen!");
252            }
253        }
254
255        while (!nobugPredicted.isEmpty()) {
256            double minLoc = Double.MAX_VALUE;
257            int minIndex = -1;
258            for (int i = 0; i < nobugPredicted.size(); i++) {
259                double currentLoc = testdata.instance(nobugPredicted.get(i)).value(loc);
260                if (currentLoc < minLoc) {
261                    minIndex = i;
262                    minLoc = currentLoc;
263                }
264            }
265            if (minIndex != -1) {
266                reviewLoc.add(minLoc / totalLoc);
267
268                currentBugsFound += testdata.instance(nobugPredicted.get(minIndex)).classValue();
269                bugsFound.add(currentBugsFound);
270                nobugPredicted.remove(minIndex);
271            }
272            else {
273                throw new RuntimeException("Shouldn't happen!");
274            }
275        }
276
277        double auc = 0.0;
278        for (int i = 0; i < bugsFound.size(); i++) {
279            auc += reviewLoc.get(i) * bugsFound.get(i) / totalBugs;
280        }
281
282        return auc;
283    }
284
285    /*
286     * (non-Javadoc)
287     *
288     * @see de.ugoe.cs.cpdp.Parameterizable#setParameter(java.lang.String)
289     */
290    @Override
291    public void setParameter(String parameters) {
292        if (output != null && !outputIsSystemOut) {
293            output.close();
294        }
295        if ("system.out".equals(parameters) || "".equals(parameters)) {
296            output = new PrintWriter(System.out);
297            outputIsSystemOut = true;
298        }
299        else {
300            try {
301                output = new PrintWriter(new FileOutputStream(parameters));
302                outputIsSystemOut = false;
303            }
304            catch (FileNotFoundException e) {
305                throw new RuntimeException(e);
306            }
307        }
308    }
309}
Note: See TracBrowser for help on using the repository browser.