source: trunk/CrossPare/src/de/ugoe/cs/cpdp/wekaclassifier/WHICH.java

Last change on this file was 136, checked in by sherbold, 8 years ago
  • more code documentation
  • Property svn:mime-type set to text/plain
File size: 18.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.wekaclassifier;
16
17import java.util.ArrayList;
18import java.util.Arrays;
19import java.util.Collections;
20import java.util.LinkedList;
21import java.util.List;
22import java.util.Random;
23import java.util.logging.Level;
24
25import de.ugoe.cs.util.console.Console;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.filters.Filter;
31import weka.filters.unsupervised.attribute.Discretize;
32
33/**
34 * <p>
35 * WHICH classifier after Menzies et al.
36 * </p>
37 *
38 * @author Steffen Herbold
39 */
40public class WHICH extends AbstractClassifier {
41
42    /**
43     * default id.
44     */
45    private static final long serialVersionUID = 1L;
46
47    /**
48     * number of bins used for discretization of data
49     */
50    private int numBins = 7;
51
52    /**
53     * number of new rules generate within each rule generation iteration
54     */
55    private final int numNewRules = 5;
56
57    /**
58     * number of rule generation iterations
59     */
60    private final int newRuleIterations = 20;
61
62    /**
63     * maximal number of tries to improve the best score
64     */
65    private final int maxIter = 100;
66
67    /**
68     * best rule determined by the training, i.e., the classifier
69     */
70    private WhichRule bestRule = null;
71
72    /*
73     * (non-Javadoc)
74     *
75     * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
76     */
77    @Override
78    public void buildClassifier(Instances traindata) throws Exception {
79        WhichStack whichStack = new WhichStack();
80        Discretize discretize = new Discretize();
81        discretize.setBins(numBins);
82        discretize.setIgnoreClass(true);
83        discretize.setInputFormat(traindata);
84        Instances discretizedData = Filter.useFilter(traindata, discretize);
85        // init WHICH stack
86        for (int j = 0; j < discretizedData.numAttributes(); j++) {
87            Attribute attr = discretizedData.attribute(j);
88            for (int k = 0; k < attr.numValues(); k++) {
89                // create rules for single variables
90                WhichRule rule = new WhichRule(Arrays.asList(new Integer[]
91                    { j }), Arrays.asList(new Double[]
92                    { (double) k }), Arrays.asList(new String[]
93                    { attr.value(k) }));
94                rule.scoreRule(discretizedData);
95                whichStack.push(rule);
96            }
97        }
98        double curBestScore = whichStack.bestScore;
99        int iter = 0;
100        do {
101            // generate new rules
102            for (int i = 0; i < newRuleIterations; i++) {
103                whichStack.generateRules(numNewRules, discretizedData);
104            }
105            if (curBestScore >= whichStack.bestScore) {
106                // no improvement, terminate
107                break;
108            }
109            curBestScore = whichStack.bestScore;
110            iter++;
111        }
112        while (iter < maxIter);
113
114        bestRule = whichStack.bestRule();
115    }
116
117    /*
118     * (non-Javadoc)
119     *
120     * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
121     */
122    @Override
123    public double classifyInstance(Instance instance) {
124        if (bestRule == null) {
125            throw new RuntimeException("you have to build the classifier first!");
126        }
127        return bestRule.applyRule(instance, false) ? 0.0 : 1.0;
128    }
129
130    /**
131     * <p>
132     * Internal helper class to handle WHICH rules. The compareTo method is NOT consistent with the
133     * equals method!
134     * </p>
135     *
136     * @author Steffen Herbold
137     */
138    private class WhichRule implements Comparable<WhichRule> {
139        /**
140         * indizes of the attributes in the data to which the rule is applied
141         */
142        final List<Integer> attributeIndizes;
143
144        /**
145         * index of the range for internal optimization during training
146         */
147        final List<Double> rangeIndizes;
148
149        /**
150         * String of the range as created by Discretize.
151         */
152        final List<String> ranges;
153
154        /**
155         * support of the rule
156         */
157        double support;
158
159        /**
160         * percentage of the defective matches where the rule applies
161         */
162        double e1;
163
164        /**
165         * percentage of the non-defective matches where the rule does not apply
166         */
167        double e2;
168
169        /**
170         * score of the rule
171         */
172        double score;
173
174        /**
175         * <p>
176         * Creates a new WhichRule.
177         * </p>
178         *
179         * @param attributeIndizes
180         *            attribute indizes
181         * @param rangeIndizes
182         *            range indizes
183         * @param ranges
184         *            range strings
185         */
186        public WhichRule(List<Integer> attributeIndizes,
187                         List<Double> rangeIndizes,
188                         List<String> ranges)
189        {
190            this.attributeIndizes = attributeIndizes;
191            this.rangeIndizes = rangeIndizes;
192            this.ranges = ranges;
193        }
194
195        /**
196         * <p>
197         * Combines two rules into a new rule
198         * </p>
199         *
200         * @param rule1
201         *            first rule in combination
202         * @param rule2
203         *            second rule in combination
204         */
205        public WhichRule(WhichRule rule1, WhichRule rule2) {
206            attributeIndizes = new ArrayList<>(rule1.attributeIndizes);
207            rangeIndizes = new ArrayList<>(rule1.rangeIndizes);
208            ranges = new ArrayList<>(rule1.ranges);
209            for (int k = 0; k < rule2.attributeIndizes.size(); k++) {
210                if (!attributeIndizes.contains(rule2.attributeIndizes.get(k))) {
211                    attributeIndizes.add(rule2.attributeIndizes.get(k));
212                    rangeIndizes.add(rule2.rangeIndizes.get(k));
213                    ranges.add(rule2.ranges.get(k));
214                }
215            }
216        }
217
218        /**
219         * <p>
220         * Determines the score of a rule.
221         * </p>
222         *
223         * @param traindata
224         */
225        public void scoreRule(Instances traindata) {
226            int numMatches = 0;
227            int numMatchDefective = 0;
228            int numMatchNondefective = 0;
229            @SuppressWarnings("unused")
230            int numNoMatchDefective = 0;
231            @SuppressWarnings("unused")
232            int numNoMatchNondefective = 0;
233            for (int i = 0; i < traindata.size(); i++) {
234                // check if rule applies
235                if (applyRule(traindata.get(i), true)) {
236                    // to something
237                    numMatches++;
238                    if (traindata.get(i).classValue() == 1.0) {
239                        numMatchDefective++;
240                    }
241                    else {
242                        numMatchNondefective++;
243                    }
244                }
245                else {
246                    if (traindata.get(i).classValue() == 1.0) {
247                        numNoMatchDefective++;
248                    }
249                    else {
250                        numNoMatchNondefective++;
251                    }
252                }
253            }
254            support = numMatches / ((double) traindata.size());
255            if (numMatches > 0) {
256                e1 = numMatchNondefective / ((double) numMatches);
257                e2 = numMatchDefective / ((double) numMatches);
258                if (e2 > 0) {
259                    score = e1 / e2 * support;
260                }
261                else {
262                    score = 0;
263                }
264            }
265            else {
266                e1 = 0;
267                e2 = 0;
268                score = 0;
269            }
270            if (score == 0) {
271                score = 0.000000001; // to disallow 0 total score
272            }
273        }
274
275        /**
276         * <p>
277         * Checks if a rule applies to an instance.
278         * </p>
279         *
280         * @param instance
281         *            the instance
282         * @param isTraining
283         *            if true, the data is discretized training data and rangeIndizes are used;
284         *            otherwise the data is numeric and the range string is used.
285         * @return true if the rule applies
286         */
287        public boolean applyRule(Instance instance, boolean isTraining) {
288            boolean result = true;
289            for (int k = 0; k < attributeIndizes.size(); k++) {
290                int attrIndex = attributeIndizes.get(k);
291                if (isTraining) {
292                    double rangeIndex = rangeIndizes.get(k);
293                    double instanceValue = instance.value(attrIndex);
294                    result &= (instanceValue == rangeIndex);
295                }
296                else {
297                    String range = ranges.get(k);
298                    if ("'All'".equals(range)) {
299                        result = true;
300                    }
301                    else {
302                        double instanceValue = instance.value(attrIndex);
303                        double lowerBound;
304                        double upperBound;
305                        String[] splitResult = range.split("--");
306                        if (splitResult.length > 1) {
307                            // second value is negative
308                            throw new RuntimeException("negative second value cannot be handled by WHICH yet");
309                        }
310                        else {
311                            splitResult = range.split("-");
312                            if (splitResult.length > 2) {
313                                // first value is negative
314                                if ("inf".equals(splitResult[1])) {
315                                    lowerBound = Double.NEGATIVE_INFINITY;
316                                }
317                                else {
318                                    lowerBound = -Double.parseDouble(splitResult[1]);
319                                }
320                                if (splitResult[2].startsWith("inf")) {
321                                    upperBound = Double.POSITIVE_INFINITY;
322                                }
323                                else {
324                                    upperBound = Double.parseDouble(splitResult[2]
325                                        .substring(0, splitResult[2].length() - 2));
326                                }
327                            }
328                            else {
329                                // first value is positive
330                                if (splitResult[0].substring(2, splitResult[0].length())
331                                    .equals("ll'"))
332                                {
333                                    System.out.println("foo");
334                                }
335                                lowerBound = Double.parseDouble(splitResult[0]
336                                    .substring(2, splitResult[0].length()));
337                                if (splitResult[1].startsWith("inf")) {
338                                    upperBound = Double.POSITIVE_INFINITY;
339                                }
340                                else {
341                                    upperBound = Double.parseDouble(splitResult[1]
342                                        .substring(0, splitResult[1].length() - 2));
343                                }
344                            }
345                        }
346                        boolean lowerBoundMatch =
347                            (range.charAt(1) == '(' && instanceValue > lowerBound) ||
348                                (range.charAt(1) == '[' && instanceValue >= lowerBound);
349                        boolean upperBoundMatch = (range.charAt(range.length() - 2) == ')' &&
350                            instanceValue < upperBound) ||
351                            (range.charAt(range.length() - 2) == ']' &&
352                                instanceValue <= upperBound);
353                        result = lowerBoundMatch && upperBoundMatch;
354                    }
355                }
356            }
357            return result;
358        }
359
360        /**
361         * <p>
362         * returns the score of the rule
363         * </p>
364         *
365         * @return the score
366         */
367        public double getScore() {
368            return score;
369        }
370
371        /*
372         * (non-Javadoc)
373         *
374         * @see java.lang.Comparable#compareTo(java.lang.Object)
375         */
376        @Override
377        public int compareTo(WhichRule other) {
378            // !!this compareTo is NOT consistent with equals!!
379            if (other == null) {
380                return -1;
381            }
382            if (other.score < this.score) {
383                return -1;
384            }
385            else if (other.score > this.score) {
386                return 1;
387            }
388            else {
389                return 0;
390            }
391        }
392
393        /*
394         * (non-Javadoc)
395         *
396         * @see java.lang.Object#equals(java.lang.Object)
397         */
398        @Override
399        public boolean equals(Object other) {
400            if (other == null) {
401                return false;
402            }
403            if (!(other instanceof WhichRule)) {
404                return false;
405            }
406            WhichRule otherRule = (WhichRule) other;
407            return attributeIndizes.equals(otherRule.attributeIndizes) &&
408                rangeIndizes.equals(otherRule.rangeIndizes) && ranges.equals(otherRule.ranges);
409        }
410
411        /*
412         * (non-Javadoc)
413         *
414         * @see java.lang.Object#hashCode()
415         */
416        @Override
417        public int hashCode() {
418            return 117 + attributeIndizes.hashCode() + rangeIndizes.hashCode() + ranges.hashCode();
419        }
420
421        /*
422         * (non-Javadoc)
423         *
424         * @see java.lang.Object#toString()
425         */
426        @Override
427        public String toString() {
428            return "indizes: " + attributeIndizes + "\tranges: " + ranges + "\t score: " + score;
429        }
430    }
431
432    /**
433     * <p>
434     * Internal helper class that handles the WHICH stack during training. Please not that this is
435     * not really a stack, we just stick to the name given in the publication.
436     * </p>
437     *
438     * @author Steffen Herbold
439     */
440    private class WhichStack {
441
442        /**
443         * rules on the WhichStack
444         */
445        List<WhichRule> rules;
446
447        /**
448         * Currently sum of rule scores.
449         */
450        double scoreSum;
451
452        /**
453         * Best rule score.
454         */
455        double bestScore;
456
457        /**
458         * checks if a rule was added after the last sorting
459         */
460        boolean pushAfterSort;
461
462        /**
463         * Internally used random number generator for creating new rules.
464         */
465        Random rand = new Random();
466
467        /**
468         * <p>
469         * Creates a new WhichStack.
470         * </p>
471         *
472         */
473        public WhichStack() {
474            rules = new LinkedList<>();
475            scoreSum = 0.0;
476            bestScore = 0.0;
477            pushAfterSort = false;
478        }
479
480        /**
481         * <p>
482         * Adds a rule to the WhichStack
483         * </p>
484         *
485         * @param rule
486         *            that is added.
487         */
488        public void push(WhichRule rule) {
489            rules.add(rule);
490            scoreSum += rule.getScore();
491            if (rule.getScore() > bestScore) {
492                bestScore = rule.getScore();
493            }
494            pushAfterSort = true;
495        }
496
497        /**
498         * <p>
499         * Generates a new rule as a random combination of two other rules. The two rules are drawn
500         * according to their scoring.
501         * </p>
502         *
503         * @param numRules
504         * @param traindata
505         */
506        public void generateRules(int numRules, Instances traindata) {
507            List<WhichRule> newRules = new LinkedList<>();
508
509            for (int i = 0; i < numRules; i++) {
510                WhichRule newRule;
511                do {
512                    WhichRule rule1 = drawRule();
513                    WhichRule rule2;
514                    do {
515                        rule2 = drawRule();
516                    }
517                    while (rule2.equals(rule1));
518                    newRule = new WhichRule(rule1, rule2);
519                }
520                while (newRules.contains(newRule));
521                newRules.add(newRule);
522            }
523            for (WhichRule newRule : newRules) {
524                newRule.scoreRule(traindata);
525                push(newRule);
526            }
527        }
528
529        /**
530         * <p>
531         * Randomly draws a rule weighted by the score.
532         * </p>
533         *
534         * @return drawn rule
535         */
536        public WhichRule drawRule() {
537            double randVal = rand.nextDouble() * scoreSum;
538            double curSum = 0.0;
539            for (WhichRule rule : rules) {
540                curSum += rule.getScore();
541                if (curSum >= randVal) {
542                    return rule;
543                }
544            }
545            Console.traceln(Level.SEVERE, "could not draw rule; bug in WhichStack.drawRule()");
546            return null;
547        }
548
549        /**
550         * <p>
551         * Returns the best rule.
552         * </p>
553         *
554         * @return best rule
555         */
556        public WhichRule bestRule() {
557            if (rules.isEmpty()) {
558                return null;
559            }
560            if (pushAfterSort) {
561                Collections.sort(rules);
562            }
563            return rules.get(0);
564        }
565    }
566}
Note: See TracBrowser for help on using the repository browser.