source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLASERTraining.java @ 73

Last change on this file since 73 was 64, checked in by sherbold, 9 years ago
  • added some new approaches
  • Property svn:mime-type set to text/plain
File size: 5.7 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.io.PrintStream;
18import java.util.LinkedList;
19import java.util.List;
20
21import org.apache.commons.io.output.NullOutputStream;
22
23import de.ugoe.cs.cpdp.util.WekaUtils;
24import weka.classifiers.AbstractClassifier;
25import weka.classifiers.Classifier;
26import weka.core.Instance;
27import weka.core.Instances;
28
29
30/**
31 * <p>
32 * TODO comment
33 * </p>
34 *
35 * @author Steffen Herbold
36 */
37public class WekaLASERTraining extends WekaBaseTraining implements ITrainingStrategy {
38
39    private final LASERClassifier internalClassifier = new LASERClassifier();
40
41    @Override
42    public Classifier getClassifier() {
43        return internalClassifier;
44    }
45
46    @Override
47    public void apply(Instances traindata) {
48        PrintStream errStr = System.err;
49        System.setErr(new PrintStream(new NullOutputStream()));
50        try {
51            internalClassifier.buildClassifier(traindata);
52        }
53        catch (Exception e) {
54            throw new RuntimeException(e);
55        }
56        finally {
57            System.setErr(errStr);
58        }
59    }
60
61    public class LASERClassifier extends AbstractClassifier {
62
63        private static final long serialVersionUID = 1L;
64       
65        private Classifier laserClassifier = null;
66        private Instances traindata = null;
67
68        @Override
69        public double classifyInstance(Instance instance) throws Exception {
70            List<Integer> closestInstances = new LinkedList<>();
71            double minDistance = Double.MAX_VALUE;
72            for( int i=0; i<traindata.size(); i++ ) {
73                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
74                if( distance<minDistance) {
75                    minDistance = distance;
76                }
77            }
78            for( int i=0; i<traindata.size(); i++ ) {
79                double distance = WekaUtils.hammingDistance(instance, traindata.get(i));
80                if( distance<=minDistance ) {
81                    closestInstances.add(i);
82                }
83            }
84            if( closestInstances.size()==1 ) {
85                int closestIndex = closestInstances.get(0);
86                Instance closestTrainingInstance = traindata.get(closestIndex);
87                List<Integer> closestToTrainingInstance = new LinkedList<>();
88                double minTrainingDistance = Double.MAX_VALUE;
89                for( int i=0; i<traindata.size(); i++ ) {
90                    if( closestIndex!=i ) {
91                        double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
92                        if( distance<minTrainingDistance ) {
93                            minTrainingDistance = distance;
94                        }
95                    }
96                }
97                for( int i=0; i<traindata.size(); i++ ) {
98                    if( closestIndex!=i ) {
99                        double distance = WekaUtils.hammingDistance(closestTrainingInstance, traindata.get(i));
100                        if( distance<=minTrainingDistance ) {
101                            closestToTrainingInstance.add(i);
102                        }
103                    }
104                }
105                if( closestToTrainingInstance.size()==1 ) {
106                    return laserClassifier.classifyInstance(instance);
107                }
108                else {
109                    double label = Double.NaN;
110                    boolean allEqual = true;
111                    for( Integer index : closestToTrainingInstance ) {
112                        if( label == Double.NaN ) {
113                            label = traindata.get(closestToTrainingInstance.get(index)).classValue();
114                        }
115                        else if( label!=traindata.get(closestToTrainingInstance.get(index)).classValue() ) {
116                            allEqual = false;
117                            break;
118                        }
119                    }
120                    if( allEqual ) {
121                        return label;
122                    }
123                    else {
124                        return laserClassifier.classifyInstance(instance);
125                    }
126                }
127            } else {
128                double label = Double.NaN;
129                boolean allEqual = true;
130                for( Integer index : closestInstances ) {
131                    if( label == Double.NaN ) {
132                        label = traindata.get(closestInstances.get(index)).classValue();
133                    }
134                    else if( label!=traindata.get(closestInstances.get(index)).classValue() ) {
135                        allEqual = false;
136                        break;
137                    }
138                }
139                if( allEqual ) {
140                    return label;
141                }
142                else {
143                    return laserClassifier.classifyInstance(instance);
144                }
145            }
146        }
147
148        @Override
149        public void buildClassifier(Instances traindata) throws Exception {
150            this.traindata = new Instances(traindata);
151            laserClassifier = setupClassifier();
152            laserClassifier.buildClassifier(traindata);
153        }
154    }
155}
Note: See TracBrowser for help on using the repository browser.