source: trunk/CrossPare/src/de/ugoe/cs/cpdp/eval/AbstractWekaEvaluation.java @ 10

Last change on this file since 10 was 6, checked in by sherbold, 10 years ago
  • improved Javadocs and exception handling
  • added getName() to ITrainingStrategy
  • Property svn:mime-type set to text/plain
File size: 9.8 KB
Line 
1package de.ugoe.cs.cpdp.eval;
2
3import java.io.FileNotFoundException;
4import java.io.FileOutputStream;
5import java.io.PrintWriter;
6import java.util.ArrayList;
7import java.util.LinkedList;
8import java.util.List;
9
10import de.ugoe.cs.cpdp.training.ITrainer;
11import de.ugoe.cs.cpdp.training.WekaCompatibleTrainer;
12import de.ugoe.cs.util.StringTools;
13import weka.classifiers.Classifier;
14import weka.classifiers.Evaluation;
15import weka.core.Attribute;
16import weka.core.Instances;
17
18/**
19 * Base class for the evaluation of results of classifiers compatible with the {@link Classifier} interface.
20 * For each classifier, the following metrics are calculated:
21 * <ul>
22 *  <li>Success with recall>0.7, precision>0.5</li>
23 *  <li>Success with recall>0.7, precision>0.7</li>
24 *  <li>Success with gscore>0.75</li>
25 *  <li>Success with gscore>0.6</li>
26 *  <li>error rate</li>
27 *  <li>recall</li>
28 *  <li>precision</li>
29 *  <li>fscore</li>
30 *  <li>gscore</li>
31 *  <li>AUC</li>
32 *  <li>AUCEC (weighted by LOC, if applicable; 0.0 if LOC not available)</li>
33 *  <li>true positive rate</li>
34 *  <li>true negative rate</li>
35 *  <li>true positives</li>
36 *  <li>false positives</li>
37 *  <li>true negatives</li>
38 *  <li>false negatives</li>
39 *  <li>training error</li>
40 *  <li>training recall</li>
41 *  <li>training precision</li>
42 *  <li>training success with recall>0.7 and precision>0.5
43 * </ul>
44 * @author Steffen Herbold
45 */
46public abstract class AbstractWekaEvaluation implements IEvaluationStrategy {
47
48        /**
49         * writer for the evaluation results
50         */
51        private PrintWriter output = new PrintWriter(System.out);
52       
53        private boolean outputIsSystemOut = true;
54       
55        /**
56         * Creates the weka evaluator. Allows the creation of the evaluator in different ways, e.g., for cross-validation
57         * or evaluation on the test data.
58         * @param testdata test data
59         * @param classifier classifier used
60         * @return evaluator
61         */
62        protected abstract Evaluation createEvaluator(Instances testdata, Classifier classifier);
63       
64        /*
65         * (non-Javadoc)
66         * @see de.ugoe.cs.cpdp.eval.EvaluationStrategy#apply(weka.core.Instances, weka.core.Instances, java.util.List, boolean)
67         */
68        @Override
69        public void apply(Instances testdata, Instances traindata, List<ITrainer> trainers,
70                        boolean writeHeader) {
71                final List<Classifier> classifiers = new LinkedList<Classifier>();
72                for( ITrainer trainer : trainers ) {
73                        if( trainer instanceof WekaCompatibleTrainer ) {
74                                classifiers.add(((WekaCompatibleTrainer) trainer).getClassifier());
75                        } else {
76                                throw new RuntimeException("The selected evaluator only support Weka classifiers");
77                        }
78                }
79               
80                if( writeHeader ) {
81                        output.append("version,size_test,size_training");
82                        for( ITrainer trainer : trainers ) {
83                                output.append(",succHe_" + ((WekaCompatibleTrainer) trainer).getName());
84                                output.append(",succZi_" + ((WekaCompatibleTrainer) trainer).getName());
85                                output.append(",succG75_" + ((WekaCompatibleTrainer) trainer).getName());
86                                output.append(",succG60_" + ((WekaCompatibleTrainer) trainer).getName());
87                                output.append(",error_" + ((WekaCompatibleTrainer) trainer).getName());
88                                output.append(",recall_" + ((WekaCompatibleTrainer) trainer).getName());
89                                output.append(",precision_" + ((WekaCompatibleTrainer) trainer).getName());
90                                output.append(",fscore_" + ((WekaCompatibleTrainer) trainer).getName());
91                                output.append(",gscore_" + ((WekaCompatibleTrainer) trainer).getName());
92                                output.append(",mcc_" + ((WekaCompatibleTrainer) trainer).getName());
93                                output.append(",auc_" + ((WekaCompatibleTrainer) trainer).getName());
94                                output.append(",aucec_" + ((WekaCompatibleTrainer) trainer).getName());
95                                output.append(",tpr_" + ((WekaCompatibleTrainer) trainer).getName());
96                                output.append(",tnr_" + ((WekaCompatibleTrainer) trainer).getName());
97                                output.append(",tp_" + ((WekaCompatibleTrainer) trainer).getName());
98                                output.append(",fn_" + ((WekaCompatibleTrainer) trainer).getName());
99                                output.append(",tn_" + ((WekaCompatibleTrainer) trainer).getName());
100                                output.append(",fp_" + ((WekaCompatibleTrainer) trainer).getName());
101                                output.append(",trainerror_" + ((WekaCompatibleTrainer) trainer).getName());
102                                output.append(",trainrecall_" + ((WekaCompatibleTrainer) trainer).getName());
103                                output.append(",trainprecision_" + ((WekaCompatibleTrainer) trainer).getName());
104                                output.append(",trainsuccHe_" + ((WekaCompatibleTrainer) trainer).getName());
105                        }
106                        output.append(StringTools.ENDLINE);
107                }
108               
109                output.append(testdata.relationName());
110                output.append("," + testdata.numInstances());   
111                output.append("," + traindata.numInstances());
112               
113                Evaluation eval = null;
114                Evaluation evalTrain = null;
115                for( Classifier classifier : classifiers ) {
116                        eval = createEvaluator(testdata, classifier);
117                        evalTrain = createEvaluator(traindata, classifier);
118                       
119                        double pf = eval.numFalsePositives(1)/(eval.numFalsePositives(1)+eval.numTrueNegatives(1));
120                        double gmeasure = 2*eval.recall(1)*(1.0-pf)/(eval.recall(1)+(1.0-pf));
121                        double mcc = (eval.numTruePositives(1)*eval.numTrueNegatives(1)-eval.numFalsePositives(1)*eval.numFalseNegatives(1))/Math.sqrt((eval.numTruePositives(1)+eval.numFalsePositives(1))*(eval.numTruePositives(1)+eval.numFalseNegatives(1))*(eval.numTrueNegatives(1)+eval.numFalsePositives(1))*(eval.numTrueNegatives(1)+eval.numFalseNegatives(1)));
122                        double aucec = calculateReviewEffort(testdata, classifier);
123                       
124                        if( eval.recall(1)>=0.7 && eval.precision(1) >= 0.5 ) {
125                                output.append(",1");
126                        } else {
127                                output.append(",0");
128                        }
129                       
130                        if( eval.recall(1)>=0.7 && eval.precision(1) >= 0.7 ) {
131                                output.append(",1");
132                        } else {
133                                output.append(",0");
134                        }
135                       
136                        if( gmeasure>0.75 ) {
137                                output.append(",1");
138                        } else {
139                                output.append(",0");
140                        }
141                       
142                        if( gmeasure>0.6 ) {
143                                output.append(",1");
144                        } else {
145                                output.append(",0");
146                        }
147                       
148                        output.append("," + eval.errorRate());
149                        output.append("," + eval.recall(1));
150                        output.append("," + eval.precision(1));
151                        output.append("," + eval.fMeasure(1));
152                        output.append("," + gmeasure);
153                        output.append("," + mcc);
154                        output.append("," + eval.areaUnderROC(1));
155                        output.append("," + aucec);
156                        output.append("," + eval.truePositiveRate(1));
157                        output.append("," + eval.trueNegativeRate(1));
158                        output.append("," + eval.numTruePositives(1));
159                        output.append("," + eval.numFalseNegatives(1));
160                        output.append("," + eval.numTrueNegatives(1));
161                        output.append("," + eval.numFalsePositives(1));
162                        output.append("," + evalTrain.errorRate());
163                        output.append("," + evalTrain.recall(1));
164                        output.append("," + evalTrain.precision(1));
165                        if( evalTrain.recall(1)>=0.7 && evalTrain.precision(1) >= 0.5 ) {
166                                output.append(",1");
167                        } else {
168                                output.append(",0");
169                        }
170                }
171               
172                output.append(StringTools.ENDLINE);
173                output.flush();
174        }
175       
176        private double calculateReviewEffort(Instances testdata, Classifier classifier) {
177               
178                final Attribute loc = testdata.attribute("loc");
179                if( loc==null ) {
180                        return 0.0;
181                }
182                               
183                final List<Integer> bugPredicted = new ArrayList<>();
184                final List<Integer> nobugPredicted = new ArrayList<>();
185                double totalLoc = 0.0d;
186                int totalBugs = 0;
187                for( int i=0 ; i<testdata.numInstances() ; i++ ) {
188                        try {
189                                if( Double.compare(classifier.classifyInstance(testdata.instance(i)),0.0d)==0 ) {
190                                        nobugPredicted.add(i);
191                                } else {
192                                        bugPredicted.add(i);
193                                }
194                        } catch (Exception e) {
195                                throw new RuntimeException("unexpected error during the evaluation of the review effort", e);
196                        }
197                        if(Double.compare(testdata.instance(i).classValue(),1.0d)==0) {
198                                totalBugs++;
199                        }
200                        totalLoc += testdata.instance(i).value(loc);
201                }
202               
203                final List<Double> reviewLoc = new ArrayList<>(testdata.numInstances());
204                final List<Double> bugsFound = new ArrayList<>(testdata.numInstances());
205               
206                double currentBugsFound = 0;
207               
208                while( !bugPredicted.isEmpty() ) {
209                        double minLoc = Double.MAX_VALUE;
210                        int minIndex = -1;
211                        for( int i=0 ; i<bugPredicted.size() ; i++ ) {
212                                double currentLoc = testdata.instance(bugPredicted.get(i)).value(loc);
213                                if( currentLoc<minLoc ) {
214                                        minIndex = i;
215                                        minLoc = currentLoc;
216                                }
217                        }
218                        if( minIndex!=-1 ) {
219                                reviewLoc.add(minLoc/totalLoc);
220                               
221                                currentBugsFound += testdata.instance(bugPredicted.get(minIndex)).classValue();
222                                bugsFound.add(currentBugsFound);
223                               
224                                bugPredicted.remove(minIndex);
225                        } else {
226                                throw new RuntimeException("Shouldn't happen!");
227                        }
228                }
229               
230                while( !nobugPredicted.isEmpty() ) {
231                        double minLoc = Double.MAX_VALUE;
232                        int minIndex = -1;
233                        for( int i=0 ; i<nobugPredicted.size() ; i++ ) {
234                                double currentLoc = testdata.instance(nobugPredicted.get(i)).value(loc);
235                                if( currentLoc<minLoc ) {
236                                        minIndex = i;
237                                        minLoc = currentLoc;
238                                }
239                        }
240                        if( minIndex!=-1 ) {                           
241                                reviewLoc.add(minLoc/totalLoc);
242                               
243                                currentBugsFound += testdata.instance(nobugPredicted.get(minIndex)).classValue();
244                                bugsFound.add(currentBugsFound);
245                                nobugPredicted.remove(minIndex);
246                        } else {
247                                throw new RuntimeException("Shouldn't happen!");
248                        }
249                }
250               
251                double auc = 0.0;
252                for( int i=0 ; i<bugsFound.size() ; i++ ) {
253                        auc += reviewLoc.get(i)*bugsFound.get(i)/totalBugs;
254                }
255               
256                return auc;
257        }
258
259        /*
260         * (non-Javadoc)
261         * @see de.ugoe.cs.cpdp.Parameterizable#setParameter(java.lang.String)
262         */
263        @Override
264        public void setParameter(String parameters) {
265                if( output!=null && !outputIsSystemOut ) {
266                        output.close();
267                }
268                if( "system.out".equals(parameters) || "".equals(parameters) ) {
269                        output = new PrintWriter(System.out);
270                        outputIsSystemOut = true;
271                } else {
272                        try {
273                                output = new PrintWriter(new FileOutputStream(parameters));
274                                outputIsSystemOut = false;
275                        } catch (FileNotFoundException e) {
276                                throw new RuntimeException(e);
277                        }
278                }
279        }
280}
Note: See TracBrowser for help on using the repository browser.