source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/BayesNetWrapper.java @ 130

Last change on this file since 130 was 130, checked in by sherbold, 8 years ago
  • added wrapper classes for BayesNet and DecisionTable training that can upscale attributes in case Discretize fails due to differences between buckets being smaller than 0.000001
  • Property svn:mime-type set to text/plain
File size: 4.2 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.wekaclassifier;
16
17import java.util.HashSet;
18import java.util.Set;
19import java.util.logging.Level;
20import java.util.regex.Matcher;
21import java.util.regex.Pattern;
22
23import de.ugoe.cs.cpdp.util.WekaUtils;
24import de.ugoe.cs.util.console.Console;
25import weka.classifiers.bayes.BayesNet;
26import weka.core.DenseInstance;
27import weka.core.Instance;
28import weka.core.Instances;
29
30/**
31 * <p>
32 * Wrapper to max BayesNet to deal with a problem with Discretize
33 * </p>
34 *
35 * @author Steffen Herbold
36 */
37public class BayesNetWrapper extends BayesNet {
38
39    /**
40     * generated ID
41     */
42    /**  */
43    private static final long serialVersionUID = -4835134612921456157L;
44
45    /**
46     * Map that store attributes for upscaling for each classifier
47     */
48    private Set<Integer> upscaleIndex = new HashSet<>();
49
50    /*
51     * (non-Javadoc)
52     *
53     * @see weka.classifiers.bayes.BayesNet#buildClassifier(weka.core.Instances)
54     */
55    @Override
56    public void buildClassifier(Instances traindata) throws Exception {
57        boolean trainingSuccessfull = false;
58        boolean secondAttempt = false;
59        Instances traindataCopy = null;
60        do {
61            try {
62                if (secondAttempt) {
63                    super.buildClassifier(traindataCopy);
64                    trainingSuccessfull = true;
65                }
66                else {
67                    super.buildClassifier(traindata);
68                    trainingSuccessfull = true;
69                }
70            }
71            catch (IllegalArgumentException e) {
72                String regex = "A nominal attribute \\((.*)\\) cannot have duplicate labels.*";
73                Pattern p = Pattern.compile(regex);
74                Matcher m = p.matcher(e.getMessage());
75                if (!m.find()) {
76                    // cannot treat problem, rethrow exception
77                    throw e;
78                }
79                String attributeName = m.group(1);
80                int attrIndex = traindata.attribute(attributeName).index();
81                if (secondAttempt) {
82                    throw new RuntimeException("cannot be handled correctly yet, because upscaleIndex is a Map");
83                    // traindataCopy = upscaleAttribute(traindataCopy, attrIndex);
84                }
85                else {
86                    traindataCopy = WekaUtils.upscaleAttribute(traindata, attrIndex);
87                }
88
89                upscaleIndex.add(attrIndex);
90                Console.traceln(Level.FINE, "upscaled attribute " + attributeName +
91                    "; restarting training of BayesNet");
92                secondAttempt = true;
93                continue;
94            }
95        }
96        while (!trainingSuccessfull); // dummy loop for internal continue
97    }
98
99    /*
100     * (non-Javadoc)
101     *
102     * @see weka.classifiers.bayes.BayesNet#distributionForInstance(weka.core.Instance)
103     */
104    @Override
105    public double[] distributionForInstance(Instance instance) throws Exception {
106        Instances traindataCopy;
107        for (int attrIndex : upscaleIndex) {
108            // instance value must be upscaled
109            double upscaledVal = instance.value(attrIndex) * WekaUtils.SCALER;
110            traindataCopy = new Instances(instance.dataset());
111            instance = new DenseInstance(instance.weight(), instance.toDoubleArray());
112            instance.setValue(attrIndex, upscaledVal);
113            traindataCopy.add(instance);
114            instance.setDataset(traindataCopy);
115        }
116        return super.distributionForInstance(instance);
117    }
118}
Note: See TracBrowser for help on using the repository browser.