source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataselection/DecisionTreeSelection.java

Last change on this file was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 5.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataselection;
16
17import java.util.ArrayList;
18
19import org.apache.commons.collections4.list.SetUniqueList;
20
21import de.ugoe.cs.util.console.Console;
22import weka.classifiers.Classifier;
23import weka.classifiers.Evaluation;
24import weka.classifiers.trees.J48;
25import weka.classifiers.trees.REPTree;
26import weka.core.Attribute;
27import weka.core.DenseInstance;
28import weka.core.Instances;
29
30/**
31 * <p>
32 * Training data selection as a combination of Zimmermann et al. 2009
33 * </p>
34 *
35 * @author Steffen Herbold
36 */
37public class DecisionTreeSelection extends AbstractCharacteristicSelection {
38
39    /*
40     * @see de.ugoe.cs.cpdp.dataselection.SetWiseDataselectionStrategy#apply(weka.core.Instances,
41     * org.apache.commons.collections4.list.SetUniqueList)
42     */
43    @Override
44    public void apply(Instances testdata, SetUniqueList<Instances> traindataSet) {
45        final Instances data = characteristicInstances(testdata, traindataSet);
46
47        final ArrayList<String> attVals = new ArrayList<String>();
48        attVals.add("same");
49        attVals.add("more");
50        attVals.add("less");
51        final ArrayList<Attribute> atts = new ArrayList<Attribute>();
52        for (int j = 0; j < data.numAttributes(); j++) {
53            atts.add(new Attribute(data.attribute(j).name(), attVals));
54        }
55        atts.add(new Attribute("score"));
56        Instances similarityData = new Instances("similarity", atts, 0);
57        similarityData.setClassIndex(similarityData.numAttributes() - 1);
58
59        try {
60            Classifier classifier = new J48();
61            for (int i = 0; i < traindataSet.size(); i++) {
62                classifier.buildClassifier(traindataSet.get(i));
63                for (int j = 0; j < traindataSet.size(); j++) {
64                    if (i != j) {
65                        double[] similarity = new double[data.numAttributes() + 1];
66                        for (int k = 0; k < data.numAttributes(); k++) {
67                            if (0.9 * data.get(i + 1).value(k) > data.get(j + 1).value(k)) {
68                                similarity[k] = 2.0;
69                            }
70                            else if (1.1 * data.get(i + 1).value(k) < data.get(j + 1).value(k)) {
71                                similarity[k] = 1.0;
72                            }
73                            else {
74                                similarity[k] = 0.0;
75                            }
76                        }
77
78                        Evaluation eval = new Evaluation(traindataSet.get(j));
79                        eval.evaluateModel(classifier, traindataSet.get(j));
80                        similarity[data.numAttributes()] = eval.fMeasure(1);
81                        similarityData.add(new DenseInstance(1.0, similarity));
82                    }
83                }
84            }
85            REPTree repTree = new REPTree();
86            if (repTree.getNumFolds() > similarityData.size()) {
87                repTree.setNumFolds(similarityData.size());
88            }
89            repTree.setNumFolds(2);
90            repTree.buildClassifier(similarityData);
91
92            Instances testTrainSimilarity = new Instances(similarityData);
93            testTrainSimilarity.clear();
94            for (int i = 0; i < traindataSet.size(); i++) {
95                double[] similarity = new double[data.numAttributes() + 1];
96                for (int k = 0; k < data.numAttributes(); k++) {
97                    if (0.9 * data.get(0).value(k) > data.get(i + 1).value(k)) {
98                        similarity[k] = 2.0;
99                    }
100                    else if (1.1 * data.get(0).value(k) < data.get(i + 1).value(k)) {
101                        similarity[k] = 1.0;
102                    }
103                    else {
104                        similarity[k] = 0.0;
105                    }
106                }
107                testTrainSimilarity.add(new DenseInstance(1.0, similarity));
108            }
109
110            int bestScoringProductIndex = -1;
111            double maxScore = Double.MIN_VALUE;
112            for (int i = 0; i < traindataSet.size(); i++) {
113                double score = repTree.classifyInstance(testTrainSimilarity.get(i));
114                if (score > maxScore) {
115                    maxScore = score;
116                    bestScoringProductIndex = i;
117                }
118            }
119            Instances bestScoringProduct = traindataSet.get(bestScoringProductIndex);
120            traindataSet.clear();
121            traindataSet.add(bestScoringProduct);
122        }
123        catch (Exception e) {
124            Console.printerr("failure during DecisionTreeSelection: " + e.getMessage());
125            throw new RuntimeException(e);
126        }
127    }
128}
Note: See TracBrowser for help on using the repository browser.