/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.classifiers.trees.REPTree;
import weka.core.AdditionalMeasureProducer;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.PartitionGenerator;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class Bagging
extends RandomizableParallelIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
AdditionalMeasureProducer,
TechnicalInformationHandler,
PartitionGenerator,
Aggregateable<Bagging> {
    static final long serialVersionUID = -115879962237199703L;
    protected int m_BagSizePercent = 100;
    protected boolean m_CalcOutOfBag = false;
    protected boolean m_RepresentUsingWeights = false;
    protected double m_OutOfBagError;
    protected Random m_random;
    protected boolean[][] m_inBag;
    protected Instances m_data;
    protected List<Classifier> m_classifiersCache;

    public Bagging() {
        this.m_Classifier = new REPTree();
    }

    public String globalInfo() {
        return "Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. \n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.TITLE, "Bagging predictors");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "24");
        result.setValue(TechnicalInformation.Field.NUMBER, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "123-140");
        return result;
    }

    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.REPTree";
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(3);
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        newVector.addElement(new Option("\tRepresent copies of instances using weights rather than explicitly.", "-represent-copies-using-weights", 0, "-represent-copies-using-weights"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        if (bagSize.length() != 0) {
            this.setBagSizePercent(Integer.parseInt(bagSize));
        } else {
            this.setBagSizePercent(100);
        }
        this.setCalcOutOfBag(Utils.getFlag('O', options));
        this.setRepresentCopiesUsingWeights(Utils.getFlag("represent-copies-using-weights", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-P");
        options.add("" + this.getBagSizePercent());
        if (this.getCalcOutOfBag()) {
            options.add("-O");
        }
        if (this.getRepresentCopiesUsingWeights()) {
            options.add("-represent-copies-using-weights");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    public String bagSizePercentTipText() {
        return "Size of each bag, as a percentage of the training set size.";
    }

    public int getBagSizePercent() {
        return this.m_BagSizePercent;
    }

    public void setBagSizePercent(int newBagSizePercent) {
        this.m_BagSizePercent = newBagSizePercent;
    }

    public String representCopiesUsingWeightsTipText() {
        return "Whether to represent copies of instances using weights rather than explicitly.";
    }

    public void setRepresentCopiesUsingWeights(boolean representUsingWeights) {
        this.m_RepresentUsingWeights = representUsingWeights;
    }

    public boolean getRepresentCopiesUsingWeights() {
        return this.m_RepresentUsingWeights;
    }

    public String calcOutOfBagTipText() {
        return "Whether the out-of-bag error is calculated.";
    }

    public void setCalcOutOfBag(boolean calcOutOfBag) {
        this.m_CalcOutOfBag = calcOutOfBag;
    }

    public boolean getCalcOutOfBag() {
        return this.m_CalcOutOfBag;
    }

    public double measureOutOfBagError() {
        return this.m_OutOfBagError;
    }

    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureOutOfBagError");
        return newVector.elements();
    }

    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return this.measureOutOfBagError();
        }
        throw new IllegalArgumentException(String.valueOf(additionalMeasureName) + " not supported (Bagging)");
    }

    @Override
    protected synchronized Instances getTrainingSet(int iteration) throws Exception {
        int bagSize = this.m_data.numInstances() * this.m_BagSizePercent / 100;
        Instances bagData = null;
        Random r = new Random(this.m_Seed + iteration);
        if (this.m_CalcOutOfBag) {
            this.m_inBag[iteration] = new boolean[this.m_data.numInstances()];
            bagData = this.m_data.resampleWithWeights(r, this.m_inBag[iteration], this.getRepresentCopiesUsingWeights());
        } else {
            bagData = this.m_data.resampleWithWeights(r, this.getRepresentCopiesUsingWeights());
            if (bagSize < this.m_data.numInstances()) {
                Instances newBagData;
                bagData.randomize(r);
                bagData = newBagData = new Instances(bagData, 0, bagSize);
            }
        }
        return bagData;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        if (this.getRepresentCopiesUsingWeights() && !(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new IllegalArgumentException("Cannot represent copies using weights when base learner in bagging does not implement WeightedInstancesHandler.");
        }
        this.m_data = new Instances(data);
        this.m_data.deleteWithMissingClass();
        super.buildClassifier(this.m_data);
        if (this.m_CalcOutOfBag && this.m_BagSizePercent != 100) {
            throw new IllegalArgumentException("Bag size needs to be 100% if out-of-bag error is to be calculated!");
        }
        this.m_random = new Random(this.m_Seed);
        this.m_inBag = null;
        if (this.m_CalcOutOfBag) {
            this.m_inBag = new boolean[this.m_Classifiers.length][];
        }
        int j = 0;
        while (j < this.m_Classifiers.length) {
            if (this.m_Classifier instanceof Randomizable) {
                ((Randomizable)((Object)this.m_Classifiers[j])).setSeed(this.m_random.nextInt());
            }
            ++j;
        }
        this.buildClassifiers();
        if (this.getCalcOutOfBag()) {
            double outOfBagCount = 0.0;
            double errorSum = 0.0;
            boolean numeric = this.m_data.classAttribute().isNumeric();
            int i = 0;
            while (i < this.m_data.numInstances()) {
                double vote;
                double[] votes = numeric ? new double[1] : new double[this.m_data.numClasses()];
                int voteCount = 0;
                int j2 = 0;
                while (j2 < this.m_Classifiers.length) {
                    if (!this.m_inBag[j2][i]) {
                        if (numeric) {
                            double pred = this.m_Classifiers[j2].classifyInstance(this.m_data.instance(i));
                            if (!Utils.isMissingValue(pred)) {
                                votes[0] = votes[0] + pred;
                                ++voteCount;
                            }
                        } else {
                            ++voteCount;
                            double[] newProbs = this.m_Classifiers[j2].distributionForInstance(this.m_data.instance(i));
                            int k = 0;
                            while (k < newProbs.length) {
                                int n = k;
                                votes[n] = votes[n] + newProbs[k];
                                ++k;
                            }
                        }
                    }
                    ++j2;
                }
                if (numeric) {
                    vote = voteCount == 0 ? Utils.missingValue() : votes[0] / (double)voteCount;
                } else if (Utils.eq(Utils.sum(votes), 0.0)) {
                    vote = Utils.missingValue();
                } else {
                    vote = Utils.maxIndex(votes);
                    Utils.normalize(votes);
                }
                if (!Utils.isMissingValue(vote)) {
                    outOfBagCount += this.m_data.instance(i).weight();
                    if (numeric) {
                        errorSum += StrictMath.abs(vote - this.m_data.instance(i).classValue()) * this.m_data.instance(i).weight();
                    } else if (vote != this.m_data.instance(i).classValue()) {
                        errorSum += this.m_data.instance(i).weight();
                    }
                }
                ++i;
            }
            if (outOfBagCount > 0.0) {
                this.m_OutOfBagError = errorSum / outOfBagCount;
            }
        } else {
            this.m_OutOfBagError = 0.0;
        }
        this.m_data = null;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        double numPreds = 0.0;
        int i = 0;
        while (i < this.m_NumIterations) {
            if (instance.classAttribute().isNumeric()) {
                double pred = this.m_Classifiers[i].classifyInstance(instance);
                if (!Utils.isMissingValue(pred)) {
                    sums[0] = sums[0] + pred;
                    numPreds += 1.0;
                }
            } else {
                double[] newProbs = this.m_Classifiers[i].distributionForInstance(instance);
                int j = 0;
                while (j < newProbs.length) {
                    int n = j;
                    sums[n] = sums[n] + newProbs[j];
                    ++j;
                }
            }
            ++i;
        }
        if (instance.classAttribute().isNumeric()) {
            sums[0] = numPreds == 0.0 ? Utils.missingValue() : sums[0] / numPreds;
            return sums;
        }
        if (Utils.eq(Utils.sum(sums), 0.0)) {
            return sums;
        }
        Utils.normalize(sums);
        return sums;
    }

    public String toString() {
        if (this.m_Classifiers == null) {
            return "Bagging: No model built yet.";
        }
        StringBuffer text = new StringBuffer();
        text.append("All the base classifiers: \n\n");
        int i = 0;
        while (i < this.m_Classifiers.length) {
            text.append(String.valueOf(this.m_Classifiers[i].toString()) + "\n\n");
            ++i;
        }
        if (this.m_CalcOutOfBag) {
            text.append("Out of bag error: " + Utils.doubleToString(this.m_OutOfBagError, 4) + "\n\n");
        }
        return text.toString();
    }

    @Override
    public void generatePartition(Instances data) throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
        }
        this.buildClassifier(data);
    }

    @Override
    public double[] getMembershipValues(Instance inst) throws Exception {
        if (this.m_Classifier instanceof PartitionGenerator) {
            ArrayList<double[]> al = new ArrayList<double[]>();
            int size = 0;
            int i = 0;
            while (i < this.m_Classifiers.length) {
                double[] r = ((PartitionGenerator)((Object)this.m_Classifiers[i])).getMembershipValues(inst);
                size += r.length;
                al.add(r);
                ++i;
            }
            double[] values = new double[size];
            int pos = 0;
            for (double[] v : al) {
                System.arraycopy(v, 0, values, pos, v.length);
                pos += v.length;
            }
            return values;
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
    }

    @Override
    public int numElements() throws Exception {
        if (this.m_Classifier instanceof PartitionGenerator) {
            int size = 0;
            int i = 0;
            while (i < this.m_Classifiers.length) {
                size += ((PartitionGenerator)((Object)this.m_Classifiers[i])).numElements();
                ++i;
            }
            return size;
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10470 $");
    }

    public static void main(String[] argv) {
        Bagging.runClassifier(new Bagging(), argv);
    }

    @Override
    public Bagging aggregate(Bagging toAggregate) throws Exception {
        if (!this.m_Classifier.getClass().isAssignableFrom(toAggregate.m_Classifier.getClass())) {
            throw new Exception("Can't aggregate because base classifiers differ");
        }
        if (this.m_classifiersCache == null) {
            this.m_classifiersCache = new ArrayList<Classifier>();
            this.m_classifiersCache.addAll(Arrays.asList(this.m_Classifiers));
        }
        this.m_classifiersCache.addAll(Arrays.asList(toAggregate.m_Classifiers));
        return this;
    }

    @Override
    public void finalizeAggregation() throws Exception {
        this.m_Classifiers = this.m_classifiersCache.toArray(new Classifier[1]);
        this.m_NumIterations = this.m_Classifiers.length;
        this.m_classifiersCache = null;
    }
}

