source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLASERTraining.java @ 115

Last change on this file since 115 was 91, checked in by sherbold, 9 years ago
  • fixed bug in WekaLASERTraining
  • Property svn:mime-type set to text/plain
File size: 5.4 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.LinkedList;
18import java.util.List;
19
20import de.ugoe.cs.cpdp.util.WekaUtils;
21import weka.classifiers.AbstractClassifier;
22import weka.classifiers.Classifier;
23import weka.core.Instance;
24import weka.core.Instances;
25
26
27/**
28 * <p>
29 * TODO comment
30 * </p>
31 *
32 * @author Steffen Herbold
33 */
34public class WekaLASERTraining extends WekaBaseTraining implements ITrainingStrategy {
35
36    private final LASERClassifier internalClassifier = new LASERClassifier();
37
38    @Override
39    public Classifier getClassifier() {
40        return internalClassifier;
41    }
42
43    @Override
44    public void apply(Instances traindata) {
45        try {
46            internalClassifier.buildClassifier(traindata);
47        }
48        catch (Exception e) {
49            throw new RuntimeException(e);
50        }
51    }
52
53    public class LASERClassifier extends AbstractClassifier {
54
55        private static final long serialVersionUID = 1L;
56       
57        private Classifier laserClassifier = null;
58        private Instances traindata = null;
59
60        @Override
61        public double classifyInstance(Instance instance) throws Exception {
62            List<Integer> closestInstances = new LinkedList<>();
63            double minDistance = Double.MAX_VALUE;
64            for( int i=0; i<traindata.size(); i++ ) {
65                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
66                if( distance<minDistance) {
67                    minDistance = distance;
68                }
69            }
70            for( int i=0; i<traindata.size(); i++ ) {
71                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
72                if( distance<=minDistance ) {
73                    closestInstances.add(i);
74                }
75            }
76            if( closestInstances.size()==1 ) {
77                int closestIndex = closestInstances.get(0);
78                Instance closestTrainingInstance = traindata.get(closestIndex);
79                List<Integer> closestToTrainingInstance = new LinkedList<>();
80                double minTrainingDistance = Double.MAX_VALUE;
81                for( int i=0; i<traindata.size(); i++ ) {
82                    if( closestIndex!=i ) {
83                        double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
84                        if( distance<minTrainingDistance ) {
85                            minTrainingDistance = distance;
86                        }
87                    }
88                }
89                for( int i=0; i<traindata.size(); i++ ) {
90                    if( closestIndex!=i ) {
91                        double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
92                        if( distance<=minTrainingDistance ) {
93                            closestToTrainingInstance.add(i);
94                        }
95                    }
96                }
97                if( closestToTrainingInstance.size()==1 ) {
98                    return laserClassifier.classifyInstance(instance);
99                }
100                else {
101                    double label = Double.NaN;
102                    boolean allEqual = true;
103                    for( Integer index : closestToTrainingInstance ) {
104                        if( Double.isNaN(label) ) {
105                            label = traindata.get(index).classValue();
106                        }
107                        else if( label!=traindata.get(index).classValue() ) {
108                            allEqual = false;
109                            break;
110                        }
111                    }
112                    if( allEqual ) {
113                        return label;
114                    }
115                    else {
116                        return laserClassifier.classifyInstance(instance);
117                    }
118                }
119            } else {
120                double label = Double.NaN;
121                boolean allEqual = true;
122                for( Integer index : closestInstances ) {
123                    if( Double.isNaN(label) ) {
124                        label = traindata.get(index).classValue();
125                    }
126                    else if( label!=traindata.get(index).classValue() ) {
127                        allEqual = false;
128                        break;
129                    }
130                }
131                if( allEqual ) {
132                    return label;
133                }
134                else {
135                    return laserClassifier.classifyInstance(instance);
136                }
137            }
138        }
139
140        @Override
141        public void buildClassifier(Instances traindata) throws Exception {
142            this.traindata = new Instances(traindata);
143            laserClassifier = setupClassifier();
144            laserClassifier.buildClassifier(traindata);
145        }
146    }
147}
Note: See TracBrowser for help on using the repository browser.