source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/LogisticEnsemble.java @ 78

Last change on this file since 78 was 75, checked in by sherbold, 9 years ago
  • implemented ensemble classifier after Uchigaki et al.
  • Property svn:mime-type set to text/plain
File size: 4.5 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.wekaclassifier;
16
17import java.util.Iterator;
18import java.util.LinkedList;
19import java.util.List;
20
21import weka.classifiers.AbstractClassifier;
22import weka.classifiers.Classifier;
23import weka.classifiers.Evaluation;
24import weka.classifiers.functions.Logistic;
25import weka.core.DenseInstance;
26import weka.core.Instance;
27import weka.core.Instances;
28
29/**
30 * Logistic Ensemble Classifier after Uchigaki et al. with some assumptions. It is unclear if these
31 * assumptions are true.
32 *
33 * @author Steffen Herbold
34 */
35public class LogisticEnsemble extends AbstractClassifier {
36
37    /**
38     * default id
39     */
40    private static final long serialVersionUID = 1L;
41
42    /**
43     * list with classifiers
44     */
45    private List<Classifier> classifiers = null;
46
47    /**
48     * list with weights for each classifier
49     */
50    private List<Double> weights = null;
51
52    /**
53     * local copy of the options to be passed to the ensemble of logistic classifiers
54     */
55    private String[] options;
56
57    /*
58     * (non-Javadoc)
59     *
60     * @see weka.classifiers.AbstractClassifier#setOptions(java.lang.String[])
61     */
62    @Override
63    public void setOptions(String[] options) throws Exception {
64        this.options = options;
65    }
66
67    /*
68     * (non-Javadoc)
69     *
70     * @see weka.classifiers.AbstractClassifier#distributionForInstance(weka.core.Instance)
71     */
72    @Override
73    public double[] distributionForInstance(Instance instance) throws Exception {
74        Iterator<Classifier> classifierIter = classifiers.iterator();
75        Iterator<Double> weightIter = weights.iterator();
76        double[] result = new double[2];
77        while (classifierIter.hasNext()) {
78            for (int j = 0; j < instance.numAttributes(); j++) {
79                if (j != instance.classIndex()) {
80                    Instance copy = new DenseInstance(instance);
81                    for (int k = instance.numAttributes() - 1; k >= 0; k--) {
82                        if (j != k && k != instance.classIndex()) {
83                            copy.deleteAttributeAt(k);
84                        }
85                    }
86                    double[] localResult = classifierIter.next().distributionForInstance(copy);
87                    double currentWeight = weightIter.next();
88                    for (int i = 0; i < localResult.length; i++) {
89                        result[i] = result[i] + localResult[i] * currentWeight;
90                    }
91                }
92            }
93        }
94        return result;
95    }
96
97    /*
98     * (non-Javadoc)
99     *
100     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
101     */
102    @Override
103    public void buildClassifier(Instances traindata) throws Exception {
104        classifiers = new LinkedList<>();
105        weights = new LinkedList<>();
106        List<Double> weightsTmp = new LinkedList<>();
107        double sumWeights = 0.0;
108        for (int j = 0; j < traindata.numAttributes(); j++) {
109            if (j != traindata.classIndex()) {
110                final Logistic classifier = new Logistic();
111                classifier.setOptions(options);
112                final Instances copy = new Instances(traindata);
113                for (int k = traindata.numAttributes() - 1; k >= 0; k--) {
114                    if (j != k && k != traindata.classIndex()) {
115                        copy.deleteAttributeAt(k);
116                    }
117                }
118                classifier.buildClassifier(copy);
119                classifiers.add(classifier);
120                Evaluation eval = new Evaluation(copy);
121                eval.evaluateModel(classifier, copy);
122                weightsTmp.add(eval.matthewsCorrelationCoefficient(1));
123                sumWeights += eval.matthewsCorrelationCoefficient(1);
124            }
125        }
126        for (double tmp : weightsTmp) {
127            weights.add(tmp / sumWeights);
128        }
129    }
130}
Note: See TracBrowser for help on using the repository browser.