package de.ugoe.cs.cpdp.eval;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import de.ugoe.cs.cpdp.training.ITrainer;
import de.ugoe.cs.cpdp.training.IWekaCompatibleTrainer;
import de.ugoe.cs.util.StringTools;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instances;

/**
 * Base class for the evaluation of results of classifiers compatible with the {@link Classifier}
 * interface. For each classifier, the following metrics are calculated and written as one group of
 * CSV columns (class-specific metrics refer to the class with index 1):
 * <ul>
 * <li>succHe: 1 if recall &gt;= 0.7 and precision &gt;= 0.5 on the test data, else 0</li>
 * <li>succZi: 1 if recall &gt;= 0.7 and precision &gt;= 0.7 on the test data, else 0</li>
 * <li>succG75 / succG60: 1 if the G-measure is greater than 0.75 / 0.6, else 0</li>
 * <li>error, recall, precision, fscore, gscore, mcc, auc, aucec, tpr, tnr: metrics on the test
 * data</li>
 * <li>tp, fn, tn, fp: confusion matrix on the test data</li>
 * <li>trainerror, trainrecall, trainprecision, trainsuccHe: the corresponding metrics on the
 * training data</li>
 * </ul>
 *
 * @author Steffen Herbold
 */
public abstract class AbstractWekaEvaluation implements IEvaluationStrategy {

    /**
     * writer for the evaluation results
     */
    private PrintWriter output = new PrintWriter(System.out);

    /**
     * true while {@link #output} is the initial writer on {@code System.out}; presumably updated
     * when the output is redirected elsewhere (not visible here) - TODO confirm
     */
    private boolean outputIsSystemOut = true;

    /**
     * Prefixes of the CSV columns written for each trainer, in exactly the order in which
     * {@link #appendClassifierResults(Instances, Instances, Classifier)} writes the values. The
     * full column name is the prefix, an underscore, and the trainer name.
     */
    private static final String[] RESULT_COLUMNS =
        { "succHe", "succZi", "succG75", "succG60", "error", "recall", "precision", "fscore",
            "gscore", "mcc", "auc", "aucec", "tpr", "tnr", "tp", "fn", "tn", "fp", "trainerror",
            "trainrecall", "trainprecision", "trainsuccHe" };

    /**
     * Creates the weka evaluator. Allows the creation of the evaluator in different ways, e.g., for
     * cross-validation or evaluation on the test data.
     * 
     * @param testdata
     *            test data
     * @param classifier
     *            classifier used
     * @return evaluator
     */
    protected abstract Evaluation createEvaluator(Instances testdata, Classifier classifier);

    /**
     * Evaluates the classifier of every trainer and appends one CSV result row to the output.
     * <p>
     * The row starts with the relation name of the test data, the number of test instances and the
     * number of training instances, followed by one group of {@link #RESULT_COLUMNS} values per
     * trainer, in the order of {@code trainers}. Test-data metrics use the evaluator created by
     * {@link #createEvaluator(Instances, Classifier)} for {@code testdata}; the train* columns use
     * the evaluator created for {@code traindata}.
     * <p>
     * pf (used for gscore), gscore and mcc are computed directly from the confusion matrix and are
     * written as {@code NaN} when their denominators are zero (e.g. mcc when the classifier
     * predicts only one class).
     * 
     * @param testdata
     *            data used for the test-data metrics
     * @param traindata
     *            data used for the training-data metrics
     * @param trainers
     *            trainers whose classifiers are evaluated; all must implement
     *            {@link IWekaCompatibleTrainer}
     * @param writeHeader
     *            if true, a CSV header line is written before the result row
     * @throws RuntimeException
     *             if a trainer is not an {@link IWekaCompatibleTrainer}; thrown before anything is
     *             written
     */
    @Override
    public void apply(Instances testdata,
                      Instances traindata,
                      List<ITrainer> trainers,
                      boolean writeHeader)
    {
        // validate all trainers up front so that no partial row is written
        final List<Classifier> classifiers = new ArrayList<>(trainers.size());
        for (ITrainer trainer : trainers) {
            if (!(trainer instanceof IWekaCompatibleTrainer)) {
                throw new RuntimeException("The selected evaluator only support Weka classifiers");
            }
            classifiers.add(((IWekaCompatibleTrainer) trainer).getClassifier());
        }

        if (writeHeader) {
            output.append("version,size_test,size_training");
            for (ITrainer trainer : trainers) {
                // cast is safe: every trainer was checked above
                final String name = ((IWekaCompatibleTrainer) trainer).getName();
                for (String column : RESULT_COLUMNS) {
                    output.append(',').append(column).append('_').append(name);
                }
            }
            output.append(StringTools.ENDLINE);
        }

        output.append(testdata.relationName());
        output.append("," + testdata.numInstances());
        output.append("," + traindata.numInstances());
        // classifiers are in the same order as trainers, so values line up with the header
        for (Classifier classifier : classifiers) {
            appendClassifierResults(testdata, traindata, classifier);
        }
        output.append(StringTools.ENDLINE);
        output.flush();
    }

    /**
     * Appends the values of all {@link #RESULT_COLUMNS} for one classifier to the current output
     * row, each preceded by a comma.
     * 
     * @param testdata
     *            data used for the test-data metrics
     * @param traindata
     *            data used for the training-data metrics
     * @param classifier
     *            classifier that is evaluated
     */
    private void appendClassifierResults(Instances testdata,
                                         Instances traindata,
                                         Classifier classifier)
    {
        final Evaluation eval = createEvaluator(testdata, classifier);
        final Evaluation evalTrain = createEvaluator(traindata, classifier);

        final double tp = eval.numTruePositives(1);
        final double fp = eval.numFalsePositives(1);
        final double tn = eval.numTrueNegatives(1);
        final double fn = eval.numFalseNegatives(1);
        final double recall = eval.recall(1);
        final double precision = eval.precision(1);

        // probability of false alarm; NaN if fp + tn is zero
        final double pf = fp / (fp + tn);
        // harmonic mean of recall and 1 - pf
        final double gmeasure = 2 * recall * (1.0 - pf) / (recall + (1.0 - pf));
        // Matthews correlation coefficient; NaN if any marginal sum of the confusion matrix is zero
        final double mcc = (tp * tn - fp * fn)
            / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        final double aucec = calculateReviewEffort(testdata, classifier);

        appendCsvFlag(recall >= 0.7 && precision >= 0.5); // succHe
        appendCsvFlag(recall >= 0.7 && precision >= 0.7); // succZi
        appendCsvFlag(gmeasure > 0.75); // succG75
        appendCsvFlag(gmeasure > 0.6); // succG60
        appendCsvValue(eval.errorRate());
        appendCsvValue(recall);
        appendCsvValue(precision);
        appendCsvValue(eval.fMeasure(1));
        appendCsvValue(gmeasure);
        appendCsvValue(mcc);
        appendCsvValue(eval.areaUnderROC(1));
        appendCsvValue(aucec);
        appendCsvValue(eval.truePositiveRate(1));
        appendCsvValue(eval.trueNegativeRate(1));
        appendCsvValue(tp);
        appendCsvValue(fn);
        appendCsvValue(tn);
        appendCsvValue(fp);
        appendCsvValue(evalTrain.errorRate());
        appendCsvValue(evalTrain.recall(1));
        appendCsvValue(evalTrain.precision(1));
        appendCsvFlag(evalTrain.recall(1) >= 0.7 && evalTrain.precision(1) >= 0.5); // trainsuccHe
    }

    /**
     * Appends {@code ,1} if the condition holds and {@code ,0} otherwise.
     * 
     * @param condition
     *            success condition to write
     */
    private void appendCsvFlag(boolean condition) {
        output.append(condition ? ",1" : ",0");
    }

    /**
     * Appends a comma followed by the default string form of the value.
     * 
     * @param value
     *            metric value to write
     */
    private void appendCsvValue(double value) {
        output.append("," + value);
    }

    /**
     * Computes the value that {@link #apply(Instances, Instances, List, boolean)} reports in the
     * aucec column for the given classifier. Returns 0.0 if the test data has no attribute named
     * {@code loc}; presumably the area under the cost-effectiveness curve with loc as effort - TODO
     * confirm.
     */
    private double calculateReviewEffort(Instances testdata, Classifier classifier) { final Attribute loc = testdata.attribute("loc"); if( loc==null ) { return 0.0; } final List bugPredicted = new ArrayList<>(); final List nobugPredicted = new ArrayList<>(); double totalLoc = 0.0d; int totalBugs = 0; for( int i=0 ; i reviewLoc = new ArrayList<>(testdata.numInstances()); final List bugsFound = new ArrayList<>(testdata.numInstances()); double currentBugsFound = 0; while( !bugPredicted.isEmpty() ) { double minLoc = Double.MAX_VALUE; int minIndex = -1; for( int i=0 ; i