source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/LogisticEnsemble.java @ 69

Last change on this file since 69 was 53, checked in by sherbold, 9 years ago
  • added logistic ensemble classifier after Uchigaki et al.
  • Property svn:mime-type set to text/plain
File size: 3.6 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.wekaclassifier;
16
17import java.util.HashSet;
18import java.util.LinkedList;
19import java.util.List;
20import java.util.Set;
21
22import weka.classifiers.AbstractClassifier;
23import weka.classifiers.Classifier;
24import weka.classifiers.functions.Logistic;
25import weka.core.DenseInstance;
26import weka.core.Instance;
27import weka.core.Instances;
28
29/**
30 * Logistic Ensemble Classifier after Uchigaki et al.
31 *
32 * TODO comment class
33 * @author Steffen Herbold
34 */
35public class LogisticEnsemble extends AbstractClassifier {
36
37    private static final long serialVersionUID = 1L;
38
39    private List<Instances> trainingData = null;
40
41    private List<Classifier> classifiers = null;
42   
43    private String[] options;
44
45    @Override
46    public void setOptions(String[] options) throws Exception {
47        this.options = options;
48    }
49   
50    @Override
51    public double classifyInstance(Instance instance) {
52        if (classifiers == null) {
53            return 0.0;
54        }
55
56        double classification = 0.0;
57        for (int i = 0; i < classifiers.size(); i++) {
58            Classifier classifier = classifiers.get(i);
59            Instances traindata = trainingData.get(i);
60
61            Set<String> attributeNames = new HashSet<>();
62            for (int j = 0; j < traindata.numAttributes(); j++) {
63                attributeNames.add(traindata.attribute(j).name());
64            }
65
66            double[] values = new double[traindata.numAttributes()];
67            int index = 0;
68            for (int j = 0; j < instance.numAttributes(); j++) {
69                if (attributeNames.contains(instance.attribute(j).name())) {
70                    values[index] = instance.value(j);
71                    index++;
72                }
73            }
74
75            Instances tmp = new Instances(traindata);
76            tmp.clear();
77            Instance instCopy = new DenseInstance(instance.weight(), values);
78            instCopy.setDataset(tmp);
79            try {
80                classification += classifier.classifyInstance(instCopy);
81            }
82            catch (Exception e) {
83                throw new RuntimeException("bagging classifier could not classify an instance", e);
84            }
85        }
86        classification /= classifiers.size();
87        return (classification >= 0.5) ? 1.0 : 0.0;
88    }
89
90    @Override
91    public void buildClassifier(Instances traindata) throws Exception {
92        classifiers = new LinkedList<>();
93        for( int j=0 ; j<traindata.numAttributes() ; j++) {
94            final Logistic classifier = new Logistic();
95            classifier.setOptions(options);
96            final Instances copy = new Instances(traindata);
97            for( int k=traindata.numAttributes()-1; k>=0 ; k-- ) {
98                if( j!=k && traindata.classIndex()!=k ) {
99                    copy.deleteAttributeAt(k);
100                }
101            }
102            classifier.buildClassifier(copy);
103            classifiers.add(classifier);
104        }
105    }
106}
Note: See TracBrowser for help on using the repository browser.