source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLASERTraining.java

Last change on this file was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 6.4 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.LinkedList;
18import java.util.List;
19
20import de.ugoe.cs.cpdp.util.WekaUtils;
21import weka.classifiers.AbstractClassifier;
22import weka.classifiers.Classifier;
23import weka.core.Instance;
24import weka.core.Instances;
25
26/**
27 * <p>
28 * Implements training following the LASER classification scheme.
29 * </p>
30 *
31 * @author Steffen Herbold
32 */
33public class WekaLASERTraining extends WekaBaseTraining implements ITrainingStrategy {
34
35    /**
36     * Internal classifier used for LASER.
37     */
38    private final LASERClassifier internalClassifier = new LASERClassifier();
39
40    /*
41     * (non-Javadoc)
42     *
43     * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier()
44     */
45    @Override
46    public Classifier getClassifier() {
47        return internalClassifier;
48    }
49
50    /*
51     * (non-Javadoc)
52     *
53     * @see de.ugoe.cs.cpdp.training.ITrainingStrategy#apply(weka.core.Instances)
54     */
55    @Override
56    public void apply(Instances traindata) {
57        try {
58            internalClassifier.buildClassifier(traindata);
59        }
60        catch (Exception e) {
61            throw new RuntimeException(e);
62        }
63    }
64
65    /**
66     * <p>
67     * Internal helper class that defines the laser classifier.
68     * </p>
69     *
70     * @author Steffen Herbold
71     */
72    public class LASERClassifier extends AbstractClassifier {
73
74        /**
75         * Default serial ID.
76         */
77        private static final long serialVersionUID = 1L;
78
79        /**
80         * Internal reference to the classifier.
81         */
82        private Classifier laserClassifier = null;
83
84        /**
85         * Internal storage of the training data required for NN analysis.
86         */
87        private Instances traindata = null;
88
89        /*
90         * (non-Javadoc)
91         *
92         * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
93         */
94        @Override
95        public double classifyInstance(Instance instance) throws Exception {
96            List<Integer> closestInstances = new LinkedList<>();
97            double minDistance = Double.MAX_VALUE;
98            for (int i = 0; i < traindata.size(); i++) {
99                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
100                if (distance < minDistance) {
101                    minDistance = distance;
102                }
103            }
104            for (int i = 0; i < traindata.size(); i++) {
105                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
106                if (distance <= minDistance) {
107                    closestInstances.add(i);
108                }
109            }
110            if (closestInstances.size() == 1) {
111                int closestIndex = closestInstances.get(0);
112                Instance closestTrainingInstance = traindata.get(closestIndex);
113                List<Integer> closestToTrainingInstance = new LinkedList<>();
114                double minTrainingDistance = Double.MAX_VALUE;
115                for (int i = 0; i < traindata.size(); i++) {
116                    if (closestIndex != i) {
117                        double distance =
118                            WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
119                        if (distance < minTrainingDistance) {
120                            minTrainingDistance = distance;
121                        }
122                    }
123                }
124                for (int i = 0; i < traindata.size(); i++) {
125                    if (closestIndex != i) {
126                        double distance =
127                            WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
128                        if (distance <= minTrainingDistance) {
129                            closestToTrainingInstance.add(i);
130                        }
131                    }
132                }
133                if (closestToTrainingInstance.size() == 1) {
134                    return laserClassifier.classifyInstance(instance);
135                }
136                else {
137                    double label = Double.NaN;
138                    boolean allEqual = true;
139                    for (Integer index : closestToTrainingInstance) {
140                        if (Double.isNaN(label)) {
141                            label = traindata.get(index).classValue();
142                        }
143                        else if (label != traindata.get(index).classValue()) {
144                            allEqual = false;
145                            break;
146                        }
147                    }
148                    if (allEqual) {
149                        return label;
150                    }
151                    else {
152                        return laserClassifier.classifyInstance(instance);
153                    }
154                }
155            }
156            else {
157                double label = Double.NaN;
158                boolean allEqual = true;
159                for (Integer index : closestInstances) {
160                    if (Double.isNaN(label)) {
161                        label = traindata.get(index).classValue();
162                    }
163                    else if (label != traindata.get(index).classValue()) {
164                        allEqual = false;
165                        break;
166                    }
167                }
168                if (allEqual) {
169                    return label;
170                }
171                else {
172                    return laserClassifier.classifyInstance(instance);
173                }
174            }
175        }
176
177        /*
178         * (non-Javadoc)
179         *
180         * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
181         */
182        @Override
183        public void buildClassifier(Instances traindata) throws Exception {
184            this.traindata = new Instances(traindata);
185            laserClassifier = setupClassifier();
186            laserClassifier.buildClassifier(traindata);
187        }
188    }
189}
Note: See TracBrowser for help on using the repository browser.