Ignore:
Timestamp:
05/04/16 13:29:37 (9 years ago)
Author:
sherbold
Message:
  • implemented ensemble classifier after Uchigaki et al.
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/LogisticEnsemble.java

    r53 r75  
    1515package de.ugoe.cs.cpdp.wekaclassifier; 
    1616 
    17 import java.util.HashSet; 
     17import java.util.Iterator; 
    1818import java.util.LinkedList; 
    1919import java.util.List; 
    20 import java.util.Set; 
    2120 
    2221import weka.classifiers.AbstractClassifier; 
    2322import weka.classifiers.Classifier; 
     23import weka.classifiers.Evaluation; 
    2424import weka.classifiers.functions.Logistic; 
    2525import weka.core.DenseInstance; 
     
    2828 
    2929/** 
    30  * Logistic Ensemble Classifier after Uchigaki et al.  
     30 * Logistic Ensemble Classifier after Uchigaki et al. with some assumptions. It is unclear if these 
     31 * assumptions are true. 
    3132 * 
    32  * TODO comment class 
    3333 * @author Steffen Herbold 
    3434 */ 
    3535public class LogisticEnsemble extends AbstractClassifier { 
    3636 
     37    /** 
     38     * default id 
     39     */ 
    3740    private static final long serialVersionUID = 1L; 
    3841 
    39     private List<Instances> trainingData = null; 
     42    /** 
     43     * list with classifiers 
     44     */ 
     45    private List<Classifier> classifiers = null; 
    4046 
    41     private List<Classifier> classifiers = null; 
    42      
    43     private String[] options;  
     47    /** 
     48     * list with weights for each classifier 
     49     */ 
     50    private List<Double> weights = null; 
    4451 
     52    /** 
     53     * local copy of the options to be passed to the ensemble of logistic classifiers 
     54     */ 
     55    private String[] options; 
     56 
     57    /* 
     58     * (non-Javadoc) 
     59     *  
     60     * @see weka.classifiers.AbstractClassifier#setOptions(java.lang.String[]) 
     61     */ 
    4562    @Override 
    4663    public void setOptions(String[] options) throws Exception { 
    4764        this.options = options; 
    4865    } 
    49      
     66 
     67    /* 
     68     * (non-Javadoc) 
     69     *  
     70     * @see weka.classifiers.AbstractClassifier#distributionForInstance(weka.core.Instance) 
     71     */ 
    5072    @Override 
    51     public double classifyInstance(Instance instance) { 
    52         if (classifiers == null) { 
    53             return 0.0; 
    54         } 
    55  
    56         double classification = 0.0; 
    57         for (int i = 0; i < classifiers.size(); i++) { 
    58             Classifier classifier = classifiers.get(i); 
    59             Instances traindata = trainingData.get(i); 
    60  
    61             Set<String> attributeNames = new HashSet<>(); 
    62             for (int j = 0; j < traindata.numAttributes(); j++) { 
    63                 attributeNames.add(traindata.attribute(j).name()); 
    64             } 
    65  
    66             double[] values = new double[traindata.numAttributes()]; 
    67             int index = 0; 
     73    public double[] distributionForInstance(Instance instance) throws Exception { 
     74        Iterator<Classifier> classifierIter = classifiers.iterator(); 
     75        Iterator<Double> weightIter = weights.iterator(); 
     76        double[] result = new double[2]; 
     77        while (classifierIter.hasNext()) { 
    6878            for (int j = 0; j < instance.numAttributes(); j++) { 
    69                 if (attributeNames.contains(instance.attribute(j).name())) { 
    70                     values[index] = instance.value(j); 
    71                     index++; 
     79                if (j != instance.classIndex()) { 
     80                    Instance copy = new DenseInstance(instance); 
     81                    for (int k = instance.numAttributes() - 1; k >= 0; k--) { 
     82                        if (j != k && k != instance.classIndex()) { 
     83                            copy.deleteAttributeAt(k); 
     84                        } 
     85                    } 
     86                    double[] localResult = classifierIter.next().distributionForInstance(copy); 
     87                    double currentWeight = weightIter.next(); 
     88                    for (int i = 0; i < localResult.length; i++) { 
     89                        result[i] = result[i] + localResult[i] * currentWeight; 
     90                    } 
    7291                } 
    7392            } 
    74  
    75             Instances tmp = new Instances(traindata); 
    76             tmp.clear(); 
    77             Instance instCopy = new DenseInstance(instance.weight(), values); 
    78             instCopy.setDataset(tmp); 
    79             try { 
    80                 classification += classifier.classifyInstance(instCopy); 
    81             } 
    82             catch (Exception e) { 
    83                 throw new RuntimeException("bagging classifier could not classify an instance", e); 
    84             } 
    8593        } 
    86         classification /= classifiers.size(); 
    87         return (classification >= 0.5) ? 1.0 : 0.0; 
     94        return result; 
    8895    } 
    8996 
     97    /* 
     98     * (non-Javadoc) 
     99     *  
     100     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances) 
     101     */ 
    90102    @Override 
    91103    public void buildClassifier(Instances traindata) throws Exception { 
    92104        classifiers = new LinkedList<>(); 
    93         for( int j=0 ; j<traindata.numAttributes() ; j++) { 
    94             final Logistic classifier = new Logistic(); 
    95             classifier.setOptions(options); 
    96             final Instances copy = new Instances(traindata); 
    97             for( int k=traindata.numAttributes()-1; k>=0 ; k-- ) { 
    98                 if( j!=k && traindata.classIndex()!=k ) { 
    99                     copy.deleteAttributeAt(k); 
     105        weights = new LinkedList<>(); 
     106        List<Double> weightsTmp = new LinkedList<>(); 
     107        double sumWeights = 0.0; 
     108        for (int j = 0; j < traindata.numAttributes(); j++) { 
     109            if (j != traindata.classIndex()) { 
     110                final Logistic classifier = new Logistic(); 
     111                classifier.setOptions(options); 
     112                final Instances copy = new Instances(traindata); 
     113                for (int k = traindata.numAttributes() - 1; k >= 0; k--) { 
     114                    if (j != k && k != traindata.classIndex()) { 
     115                        copy.deleteAttributeAt(k); 
     116                    } 
    100117                } 
     118                classifier.buildClassifier(copy); 
     119                classifiers.add(classifier); 
     120                Evaluation eval = new Evaluation(copy); 
     121                eval.evaluateModel(classifier, copy); 
     122                weightsTmp.add(eval.matthewsCorrelationCoefficient(1)); 
     123                sumWeights += eval.matthewsCorrelationCoefficient(1); 
    101124            } 
    102             classifier.buildClassifier(copy); 
    103             classifiers.add(classifier); 
     125        } 
     126        for (double tmp : weightsTmp) { 
     127            weights.add(tmp / sumWeights); 
    104128        } 
    105129    } 
Note: See TracChangeset for help on using the changeset viewer.