source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/WHICH.java @ 130

Last change on this file since 130 was 127, checked in by sherbold, 8 years ago
  • fixed bug in WHICH scoring function
  • Property svn:mime-type set to text/plain
File size: 18.1 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.wekaclassifier;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.LinkedList;
21import java.util.List;
22import java.util.Random;
23import java.util.logging.Level;
24
25import de.ugoe.cs.util.console.Console;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.filters.Filter;
31import weka.filters.unsupervised.attribute.Discretize;
32
33/**
34 * <p>
35 * WHICH classifier after Menzies et al.
36 * </p>
37 *
38 * @author Steffen Herbold
39 */
40public class WHICH extends AbstractClassifier {
41
42    /**
43     * default id.
44     */
45    private static final long serialVersionUID = 1L;
46
47    /**
48     * number of bins used for discretization of data
49     */
50    private int numBins = 7;
51
52    /**
53     * number of new rules generate within each rule generation iteration
54     */
55    private final int numNewRules = 5;
56
57    /**
58     * number of rule generation iterations
59     */
60    private final int newRuleIterations = 20;
61
62    /**
63     * maximal number of tries to improve the best score
64     */
65    private final int maxIter = 100;
66
67    /**
68     * best rule determined by the training, i.e., the classifier
69     */
70    private WhichRule bestRule = null;
71
72    /*
73     * (non-Javadoc)
74     *
75     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
76     */
77    @Override
78    public void buildClassifier(Instances traindata) throws Exception {
79        WhichStack whichStack = new WhichStack();
80        Discretize discretize = new Discretize();
81        discretize.setBins(numBins);
82        discretize.setIgnoreClass(true);
83        discretize.setInputFormat(traindata);
84        Instances discretizedData = Filter.useFilter(traindata, discretize);
85        // init WHICH stack
86        for (int j = 0; j < discretizedData.numAttributes(); j++) {
87            Attribute attr = discretizedData.attribute(j);
88            for (int k = 0; k < attr.numValues(); k++) {
89                // create rules for single variables
90                WhichRule rule = new WhichRule(Arrays.asList(new Integer[]
91                    { j }), Arrays.asList(new Double[]
92                    { (double) k }), Arrays.asList(new String[]
93                    { attr.value(k) }));
94                rule.scoreRule(discretizedData);
95                whichStack.push(rule);
96            }
97        }
98        double curBestScore = whichStack.bestScore;
99        int iter = 0;
100        do {
101            // generate new rules
102            for (int i = 0; i < newRuleIterations; i++) {
103                whichStack.generateRules(numNewRules, discretizedData);
104            }
105            if (curBestScore >= whichStack.bestScore) {
106                // no improvement, terminate
107                break;
108            }
109            curBestScore = whichStack.bestScore;
110            iter++;
111        }
112        while (iter < maxIter);
113
114        bestRule = whichStack.bestRule();
115    }
116
117    /*
118     * (non-Javadoc)
119     *
120     * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
121     */
122    @Override
123    public double classifyInstance(Instance instance) {
124        if (bestRule == null) {
125            throw new RuntimeException("you have to build the classifier first!");
126        }
127        return bestRule.applyRule(instance, false) ? 0.0 : 1.0;
128    }
129
130    /**
131     * <p>
132     * Internal helper class to handle WHICH rules. The compareTo method is NOT consistent with the
133     * equals method!
134     * </p>
135     *
136     * @author Steffen Herbold
137     */
138    private class WhichRule implements Comparable<WhichRule> {
139        /**
140         * indizes of the attributes in the data to which the rule is applied
141         */
142        final List<Integer> attributeIndizes;
143
144        /**
145         * index of the range for internal optimization during training
146         */
147        final List<Double> rangeIndizes;
148
149        /**
150         * String of the range as created by Discretize.
151         */
152        final List<String> ranges;
153
154        /**
155         * support of the rule
156         */
157        double support;
158
159        /**
160         * percentage of the defective matches where the rule applies
161         */
162        double e1;
163
164        /**
165         * percentage of the non-defective matches where the rule does not apply
166         */
167        double e2;
168
169        /**
170         * score of the rule
171         */
172        double score;
173
174        /**
175         * <p>
176         * Creates a new WhichRule.
177         * </p>
178         *
179         * @param attributeIndizes
180         *            attribute indizes
181         * @param rangeIndizes
182         *            range indizes
183         * @param ranges
184         *            range strings
185         */
186        public WhichRule(List<Integer> attributeIndizes,
187                         List<Double> rangeIndizes,
188                         List<String> ranges)
189        {
190            this.attributeIndizes = attributeIndizes;
191            this.rangeIndizes = rangeIndizes;
192            this.ranges = ranges;
193        }
194
195        /**
196         * <p>
197         * Combines two rules into a new rule
198         * </p>
199         *
200         * @param rule1
201         *            first rule in combination
202         * @param rule2
203         *            second rule in combination
204         */
205        public WhichRule(WhichRule rule1, WhichRule rule2) {
206            attributeIndizes = new ArrayList<>(rule1.attributeIndizes);
207            rangeIndizes = new ArrayList<>(rule1.rangeIndizes);
208            ranges = new ArrayList<>(rule1.ranges);
209            for (int k = 0; k < rule2.attributeIndizes.size(); k++) {
210                if (!attributeIndizes.contains(rule2.attributeIndizes.get(k))) {
211                    attributeIndizes.add(rule2.attributeIndizes.get(k));
212                    rangeIndizes.add(rule2.rangeIndizes.get(k));
213                    ranges.add(rule2.ranges.get(k));
214                }
215            }
216        }
217
218        /**
219         * <p>
220         * Determines the score of a rule.
221         * </p>
222         *
223         * @param traindata
224         */
225        public void scoreRule(Instances traindata) {
226            int numMatches = 0;
227            int numMatchDefective = 0;
228            int numMatchNondefective = 0;
229            @SuppressWarnings("unused")
230            int numNoMatchDefective = 0;
231            @SuppressWarnings("unused")
232            int numNoMatchNondefective = 0;
233            for (int i = 0; i < traindata.size(); i++) {
234                // check if rule applies
235                if (applyRule(traindata.get(i), true)) {
236                    // to something
237                    numMatches++;
238                    if (traindata.get(i).classValue() == 1.0) {
239                        numMatchDefective++;
240                    }
241                    else {
242                        numMatchNondefective++;
243                    }
244                }
245                else {
246                    if (traindata.get(i).classValue() == 1.0) {
247                        numNoMatchDefective++;
248                    }
249                    else {
250                        numNoMatchNondefective++;
251                    }
252                }
253            }
254            support = numMatches / ((double) traindata.size());
255            if (numMatches > 0) {
256                e1 = numMatchNondefective / ((double) numMatches);
257                e2 = numMatchDefective / ((double) numMatches);
258                if (e2 > 0) {
259                    score = e1 / e2 * support;
260                }
261                else {
262                    score = 0;
263                }
264            }
265            else {
266                e1 = 0;
267                e2 = 0;
268                score = 0;
269            }
270            if( score==0 ) {
271                score = 0.000000001; // to disallow 0 total score
272            }
273        }
274
275        /**
276         * <p>
277         * Checks if a rule applies to an instance.
278         * </p>
279         *
280         * @param instance
281         *            the instance
282         * @param isTraining
283         *            if true, the data is discretized training data and rangeIndizes are used;
284         *            otherwise the data is numeric and the range string is used.
285         * @return true if the rule applies
286         */
287        public boolean applyRule(Instance instance, boolean isTraining) {
288            boolean result = true;
289            for (int k = 0; k < attributeIndizes.size(); k++) {
290                int attrIndex = attributeIndizes.get(k);
291                if (isTraining) {
292                    double rangeIndex = rangeIndizes.get(k);
293                    double instanceValue = instance.value(attrIndex);
294                    result &= (instanceValue == rangeIndex);
295                }
296                else {
297                    String range = ranges.get(k);
298                    if( "'All'".equals(range) ) {
299                        result = true;
300                    } else {
301                        double instanceValue = instance.value(attrIndex);
302                        double lowerBound;
303                        double upperBound;
304                        String[] splitResult = range.split("--");
305                        if (splitResult.length > 1) {
306                            // second value is negative
307                            throw new RuntimeException("negative second value cannot be handled by WHICH yet");
308                        }
309                        else {
310                            splitResult = range.split("-");
311                            if (splitResult.length > 2) {
312                                // first value is negative
313                                if ("inf".equals(splitResult[1])) {
314                                    lowerBound = Double.NEGATIVE_INFINITY;
315                                }
316                                else {
317                                    lowerBound = -Double.parseDouble(splitResult[1]);
318                                }
319                                if (splitResult[2].startsWith("inf")) {
320                                    upperBound = Double.POSITIVE_INFINITY;
321                                }
322                                else {
323                                    upperBound = Double.parseDouble(splitResult[2]
324                                        .substring(0, splitResult[2].length() - 2));
325                                }
326                            }
327                            else {
328                                // first value is positive
329                                if( splitResult[0].substring(2, splitResult[0].length()).equals("ll'")) {
330                                    System.out.println("foo");
331                                }
332                                lowerBound = Double
333                                    .parseDouble(splitResult[0].substring(2, splitResult[0].length()));
334                                if (splitResult[1].startsWith("inf")) {
335                                    upperBound = Double.POSITIVE_INFINITY;
336                                }
337                                else {
338                                    upperBound = Double.parseDouble(splitResult[1]
339                                        .substring(0, splitResult[1].length() - 2));
340                                }
341                            }
342                        }
343                        boolean lowerBoundMatch =
344                            (range.charAt(1) == '(' && instanceValue > lowerBound) ||
345                                (range.charAt(1) == '[' && instanceValue >= lowerBound);
346                        boolean upperBoundMatch = (range.charAt(range.length() - 2) == ')' &&
347                            instanceValue < upperBound) ||
348                            (range.charAt(range.length() - 2) == ']' && instanceValue <= upperBound);
349                        result = lowerBoundMatch && upperBoundMatch;
350                    }
351                }
352            }
353            return result;
354        }
355
356        /**
357         * <p>
358         * returns the score of the rule
359         * </p>
360         *
361         * @return
362         */
363        public double getScore() {
364            return score;
365        }
366
367        /*
368         * (non-Javadoc)
369         *
370         * @see java.lang.Comparable#compareTo(java.lang.Object)
371         */
372        @Override
373        public int compareTo(WhichRule other) {
374            // !!this compareTo is NOT consistent with equals!!
375            if (other == null) {
376                return -1;
377            }
378            if (other.score < this.score) {
379                return -1;
380            }
381            else if (other.score > this.score) {
382                return 1;
383            }
384            else {
385                return 0;
386            }
387        }
388
389        /*
390         * (non-Javadoc)
391         *
392         * @see java.lang.Object#equals(java.lang.Object)
393         */
394        @Override
395        public boolean equals(Object other) {
396            if (other == null) {
397                return false;
398            }
399            if (!(other instanceof WhichRule)) {
400                return false;
401            }
402            WhichRule otherRule = (WhichRule) other;
403            return attributeIndizes.equals(otherRule.attributeIndizes) &&
404                rangeIndizes.equals(otherRule.rangeIndizes) && ranges.equals(otherRule.ranges);
405        }
406
407        /*
408         * (non-Javadoc)
409         *
410         * @see java.lang.Object#hashCode()
411         */
412        @Override
413        public int hashCode() {
414            return 117 + attributeIndizes.hashCode() + rangeIndizes.hashCode() + ranges.hashCode();
415        }
416
417        /*
418         * (non-Javadoc)
419         *
420         * @see java.lang.Object#toString()
421         */
422        @Override
423        public String toString() {
424            return "indizes: " + attributeIndizes + "\tranges: " + ranges + "\t score: " + score;
425        }
426    }
427
428    /**
429     * <p>
430     * Internal helper class that handles the WHICH stack during training. Please not that this is
431     * not really a stack, we just stick to the name given in the publication.
432     * </p>
433     *
434     * @author Steffen Herbold
435     */
436    private class WhichStack {
437
438        /**
439         * rules on the WhichStack
440         */
441        List<WhichRule> rules;
442
443        /**
444         * Currently sum of rule scores.
445         */
446        double scoreSum;
447
448        /**
449         * Best rule score.
450         */
451        double bestScore;
452
453        /**
454         * checks if a rule was added after the last sorting
455         */
456        boolean pushAfterSort;
457
458        /**
459         * Internally used random number generator for creating new rules.
460         */
461        Random rand = new Random();
462
463        /**
464         * <p>
465         * Creates a new WhichStack.
466         * </p>
467         *
468         */
469        public WhichStack() {
470            rules = new LinkedList<>();
471            scoreSum = 0.0;
472            bestScore = 0.0;
473            pushAfterSort = false;
474        }
475
476        /**
477         * <p>
478         * Adds a rule to the WhichStack
479         * </p>
480         *
481         * @param rule
482         *            that is added.
483         */
484        public void push(WhichRule rule) {
485            rules.add(rule);
486            scoreSum += rule.getScore();
487            if (rule.getScore() > bestScore) {
488                bestScore = rule.getScore();
489            }
490            pushAfterSort = true;
491        }
492
493        /**
494         * <p>
495         * Generates a new rule as a random combination of two other rules. The two rules are drawn
496         * according to their scoring.
497         * </p>
498         *
499         * @param numRules
500         * @param traindata
501         */
502        public void generateRules(int numRules, Instances traindata) {
503            List<WhichRule> newRules = new LinkedList<>();
504
505            for (int i = 0; i < numRules; i++) {
506                WhichRule newRule;
507                do {
508                    WhichRule rule1 = drawRule();
509                    WhichRule rule2;
510                    do {
511                        rule2 = drawRule();
512                    }
513                    while (rule2.equals(rule1));
514                    newRule = new WhichRule(rule1, rule2);
515                }
516                while (newRules.contains(newRule));
517                newRules.add(newRule);
518            }
519            for (WhichRule newRule : newRules) {
520                newRule.scoreRule(traindata);
521                push(newRule);
522            }
523        }
524
525        /**
526         * <p>
527         * Randomly draws a rule weighted by the score.
528         * </p>
529         *
530         * @return
531         */
532        public WhichRule drawRule() {
533            double randVal = rand.nextDouble() * scoreSum;
534            double curSum = 0.0;
535            for (WhichRule rule : rules) {
536                curSum += rule.getScore();
537                if (curSum >= randVal) {
538                    return rule;
539                }
540            }
541            Console.traceln(Level.SEVERE, "could not draw rule; bug in WhichStack.drawRule()");
542            return null;
543        }
544
545        /**
546         * <p>
547         * Returns the best rule.
548         * </p>
549         *
550         * @return best rule
551         */
552        public WhichRule bestRule() {
553            if (rules.isEmpty()) {
554                return null;
555            }
556            if (pushAfterSort) {
557                Collections.sort(rules);
558            }
559            return rules.get(0);
560        }
561    }
562}
Note: See TracBrowser for help on using the repository browser.