/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.lmt;

import java.util.Collections;
import java.util.Iterator;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.lmt.CompareNode;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.classifiers.trees.lmt.ResidualModelSelection;
import weka.classifiers.trees.lmt.SimpleLinearRegression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

public class LMTNode
extends LogisticBase {
    static final long serialVersionUID = 1862737145870398755L;
    protected double m_totalInstanceWeight;
    protected int m_id;
    protected int m_leafModelNum;
    public double m_alpha;
    public double m_numIncorrectModel;
    public double m_numIncorrectTree;
    protected int m_minNumInstances;
    protected ModelSelection m_modelSelection;
    protected NominalToBinary m_nominalToBinary;
    protected static int m_numFoldsPruning = 5;
    protected boolean m_fastRegression;
    protected int m_numInstances;
    protected ClassifierSplitModel m_localModel;
    protected LMTNode[] m_sons;
    protected boolean m_isLeaf;

    public LMTNode(ModelSelection modelSelection, int numBoostingIterations, boolean fastRegression, boolean errorOnProbabilities, int minNumInstances, double weightTrimBeta, boolean useAIC, NominalToBinary ntb) {
        this.m_modelSelection = modelSelection;
        this.m_fixedNumIterations = numBoostingIterations;
        this.m_fastRegression = fastRegression;
        this.m_errorOnProbabilities = errorOnProbabilities;
        this.m_minNumInstances = minNumInstances;
        this.m_maxIterations = 200;
        this.setWeightTrimBeta(weightTrimBeta);
        this.setUseAIC(useAIC);
        this.m_nominalToBinary = ntb;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        if (this.m_fastRegression && this.m_fixedNumIterations < 0) {
            this.m_fixedNumIterations = this.tryLogistic(data);
        }
        Instances cvData = new Instances(data);
        cvData.stratify(m_numFoldsPruning);
        double[][] alphas = new double[m_numFoldsPruning][];
        double[][] errors = new double[m_numFoldsPruning][];
        int i = 0;
        while (i < m_numFoldsPruning) {
            Instances train = cvData.trainCV(m_numFoldsPruning, i);
            Instances test = cvData.testCV(m_numFoldsPruning, i);
            this.buildTree(train, null, train.numInstances(), 0.0, null);
            int numNodes = this.getNumInnerNodes();
            alphas[i] = new double[numNodes + 2];
            errors[i] = new double[numNodes + 2];
            this.prune(alphas[i], errors[i], test);
            ++i;
        }
        cvData = null;
        this.buildTree(data, null, data.numInstances(), 0.0, null);
        int numNodes = this.getNumInnerNodes();
        double[] treeAlphas = new double[numNodes + 2];
        int iterations = this.prune(treeAlphas, null, null);
        double[] treeErrors = new double[numNodes + 2];
        int i2 = 0;
        while (i2 <= iterations) {
            double alpha = Math.sqrt(treeAlphas[i2] * treeAlphas[i2 + 1]);
            double error = 0.0;
            int k = 0;
            while (k < m_numFoldsPruning) {
                int l = 0;
                while (alphas[k][l] <= alpha) {
                    ++l;
                }
                error += errors[k][l - 1];
                ++k;
            }
            treeErrors[i2] = error;
            ++i2;
        }
        int best = -1;
        double bestError = Double.MAX_VALUE;
        int i3 = iterations;
        while (i3 >= 0) {
            if (treeErrors[i3] < bestError) {
                bestError = treeErrors[i3];
                best = i3;
            }
            --i3;
        }
        double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);
        this.unprune();
        this.prune(bestAlpha);
    }

    public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, double totalInstanceWeight, double higherNumParameters, Instances numericDataHeader) throws Exception {
        boolean grow;
        this.m_totalInstanceWeight = totalInstanceWeight;
        this.m_train = data;
        this.m_isLeaf = true;
        this.m_sons = null;
        this.m_numInstances = this.m_train.numInstances();
        this.m_numClasses = this.m_train.numClasses();
        this.m_numericDataHeader = numericDataHeader;
        this.m_numericData = this.getNumericData(this.m_train);
        this.m_regressions = higherRegressions == null ? this.initRegressions() : higherRegressions;
        this.m_numParameters = higherNumParameters;
        this.m_numRegressions = 0;
        if (this.m_numInstances >= m_numFoldsBoosting) {
            if (this.m_fixedNumIterations > 0) {
                this.performBoosting(this.m_fixedNumIterations);
            } else if (this.getUseAIC()) {
                this.performBoostingInfCriterion();
            } else {
                this.performBoostingCV();
            }
        }
        this.m_numParameters += (double)this.m_numRegressions;
        Evaluation eval = new Evaluation(this.m_train);
        eval.evaluateModel(this, this.m_train, new Object[0]);
        this.m_numIncorrectModel = eval.incorrect();
        if (this.m_numInstances > this.m_minNumInstances) {
            if (this.m_modelSelection instanceof ResidualModelSelection) {
                double[][] probs = this.getProbs(this.getFs(this.m_numericData));
                double[][] trainYs = this.getYs(this.m_train);
                double[][] dataZs = this.getZs(probs, trainYs);
                double[][] dataWs = this.getWs(probs, trainYs);
                this.m_localModel = ((ResidualModelSelection)this.m_modelSelection).selectModel(this.m_train, dataZs, dataWs);
            } else {
                this.m_localModel = this.m_modelSelection.selectModel(this.m_train);
            }
            grow = this.m_localModel.numSubsets() > 1;
        } else {
            grow = false;
        }
        if (grow) {
            this.m_isLeaf = false;
            Instances[] localInstances = this.m_localModel.split(this.m_train);
            this.cleanup();
            this.m_sons = new LMTNode[this.m_localModel.numSubsets()];
            int i = 0;
            while (i < this.m_sons.length) {
                this.m_sons[i] = new LMTNode(this.m_modelSelection, this.m_fixedNumIterations, this.m_fastRegression, this.m_errorOnProbabilities, this.m_minNumInstances, this.getWeightTrimBeta(), this.getUseAIC(), this.m_nominalToBinary);
                this.m_sons[i].buildTree(localInstances[i], this.copyRegressions(this.m_regressions), this.m_totalInstanceWeight, this.m_numParameters, this.m_numericDataHeader);
                localInstances[i] = null;
                ++i;
            }
        } else {
            this.cleanup();
        }
    }

    public void prune(double alpha) throws Exception {
        CompareNode comparator = new CompareNode();
        this.treeErrors();
        this.calculateAlphas();
        Vector<LMTNode> nodeList = this.getNodes();
        boolean prune = nodeList.size() > 0;
        while (prune) {
            LMTNode nodeToPrune = Collections.min(nodeList, comparator);
            if (nodeToPrune.m_alpha > alpha) break;
            nodeToPrune.m_isLeaf = true;
            nodeToPrune.m_sons = null;
            this.treeErrors();
            this.calculateAlphas();
            nodeList = this.getNodes();
            boolean bl = prune = nodeList.size() > 0;
        }
        Iterator<LMTNode> iterator = this.getNodes().iterator();
        while (iterator.hasNext()) {
            LMTNode node;
            LMTNode lnode = node = iterator.next();
            if (lnode.m_isLeaf) continue;
            this.m_regressions = null;
        }
    }

    public int prune(double[] alphas, double[] errors, Instances test) throws Exception {
        Evaluation eval;
        CompareNode comparator = new CompareNode();
        this.treeErrors();
        this.calculateAlphas();
        Vector<LMTNode> nodeList = this.getNodes();
        boolean prune = nodeList.size() > 0;
        alphas[0] = 0.0;
        if (errors != null) {
            eval = new Evaluation(test);
            eval.evaluateModel(this, test, new Object[0]);
            errors[0] = eval.errorRate();
        }
        int iteration = 0;
        while (prune) {
            LMTNode nodeToPrune = Collections.min(nodeList, comparator);
            nodeToPrune.m_isLeaf = true;
            alphas[++iteration] = nodeToPrune.m_alpha;
            if (errors != null) {
                eval = new Evaluation(test);
                eval.evaluateModel(this, test, new Object[0]);
                errors[iteration] = eval.errorRate();
            }
            this.treeErrors();
            this.calculateAlphas();
            nodeList = this.getNodes();
            boolean bl = prune = nodeList.size() > 0;
        }
        alphas[iteration + 1] = 1.0;
        return iteration;
    }

    protected void unprune() {
        if (this.m_sons != null) {
            this.m_isLeaf = false;
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                m_son.unprune();
                ++n2;
            }
        }
    }

    protected int tryLogistic(Instances data) throws Exception {
        Instances filteredData = Filter.useFilter(data, this.m_nominalToBinary);
        LogisticBase logistic = new LogisticBase(0, true, this.m_errorOnProbabilities);
        logistic.setMaxIterations(200);
        logistic.setWeightTrimBeta(this.getWeightTrimBeta());
        logistic.setUseAIC(this.getUseAIC());
        logistic.buildClassifier(filteredData);
        return logistic.getNumRegressions();
    }

    public int getNumInnerNodes() {
        if (this.m_isLeaf) {
            return 0;
        }
        int numNodes = 1;
        LMTNode[] lMTNodeArray = this.m_sons;
        int n = this.m_sons.length;
        int n2 = 0;
        while (n2 < n) {
            LMTNode m_son = lMTNodeArray[n2];
            numNodes += m_son.getNumInnerNodes();
            ++n2;
        }
        return numNodes;
    }

    public int getNumLeaves() {
        int numLeaves;
        if (!this.m_isLeaf) {
            numLeaves = 0;
            int numEmptyLeaves = 0;
            int i = 0;
            while (i < this.m_sons.length) {
                numLeaves += this.m_sons[i].getNumLeaves();
                if (this.m_sons[i].m_isLeaf && !this.m_sons[i].hasModels()) {
                    ++numEmptyLeaves;
                }
                ++i;
            }
            if (numEmptyLeaves > 1) {
                numLeaves -= numEmptyLeaves - 1;
            }
        } else {
            numLeaves = 1;
        }
        return numLeaves;
    }

    public void treeErrors() {
        if (this.m_isLeaf) {
            this.m_numIncorrectTree = this.m_numIncorrectModel;
        } else {
            this.m_numIncorrectTree = 0.0;
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                m_son.treeErrors();
                this.m_numIncorrectTree += m_son.m_numIncorrectTree;
                ++n2;
            }
        }
    }

    public void calculateAlphas() throws Exception {
        if (!this.m_isLeaf) {
            double errorDiff = this.m_numIncorrectModel - this.m_numIncorrectTree;
            if (errorDiff <= 0.0) {
                this.m_isLeaf = true;
                this.m_sons = null;
                this.m_alpha = Double.MAX_VALUE;
            } else {
                this.m_alpha = (errorDiff /= this.m_totalInstanceWeight) / (double)(this.getNumLeaves() - 1);
                LMTNode[] lMTNodeArray = this.m_sons;
                int n = this.m_sons.length;
                int n2 = 0;
                while (n2 < n) {
                    LMTNode m_son = lMTNodeArray[n2];
                    m_son.calculateAlphas();
                    ++n2;
                }
            }
        } else {
            this.m_alpha = Double.MAX_VALUE;
        }
    }

    public Vector<LMTNode> getNodes() {
        Vector<LMTNode> nodeList = new Vector<LMTNode>();
        this.getNodes(nodeList);
        return nodeList;
    }

    public void getNodes(Vector<LMTNode> nodeList) {
        if (!this.m_isLeaf) {
            nodeList.add(this);
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                m_son.getNodes(nodeList);
                ++n2;
            }
        }
    }

    @Override
    protected Instances getNumericData(Instances train) throws Exception {
        Instances filteredData = Filter.useFilter(train, this.m_nominalToBinary);
        return super.getNumericData(filteredData);
    }

    public boolean hasModels() {
        return this.m_numRegressions > 0;
    }

    public double[] modelDistributionForInstance(Instance instance) throws Exception {
        this.m_nominalToBinary.input(instance);
        instance = this.m_nominalToBinary.output();
        instance.setDataset(this.m_numericDataHeader);
        return this.probs(this.getFs(instance));
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] probs;
        if (this.m_isLeaf) {
            probs = this.modelDistributionForInstance(instance);
        } else {
            int branch = this.m_localModel.whichSubset(instance);
            probs = this.m_sons[branch].distributionForInstance(instance);
        }
        return probs;
    }

    public int numLeaves() {
        if (this.m_isLeaf) {
            return 1;
        }
        int numLeaves = 0;
        LMTNode[] lMTNodeArray = this.m_sons;
        int n = this.m_sons.length;
        int n2 = 0;
        while (n2 < n) {
            LMTNode m_son = lMTNodeArray[n2];
            numLeaves += m_son.numLeaves();
            ++n2;
        }
        return numLeaves;
    }

    public int numNodes() {
        if (this.m_isLeaf) {
            return 1;
        }
        int numNodes = 1;
        LMTNode[] lMTNodeArray = this.m_sons;
        int n = this.m_sons.length;
        int n2 = 0;
        while (n2 < n) {
            LMTNode m_son = lMTNodeArray[n2];
            numNodes += m_son.numNodes();
            ++n2;
        }
        return numNodes;
    }

    @Override
    public String toString() {
        this.assignLeafModelNumbers(0);
        try {
            StringBuffer text = new StringBuffer();
            if (this.m_isLeaf) {
                text.append(": ");
                text.append("LM_" + this.m_leafModelNum + ":" + this.getModelParameters());
            } else {
                this.dumpTree(0, text);
            }
            text.append("\n\nNumber of Leaves  : \t" + this.numLeaves() + "\n");
            text.append("\nSize of the Tree : \t" + this.numNodes() + "\n");
            text.append(this.modelsToString());
            return text.toString();
        }
        catch (Exception e) {
            return "Can't print logistic model tree";
        }
    }

    public String getModelParameters() {
        StringBuffer text = new StringBuffer();
        int numModels = (int)this.m_numParameters;
        text.append(String.valueOf(this.m_numRegressions) + "/" + numModels + " (" + this.m_numInstances + ")");
        return text.toString();
    }

    protected void dumpTree(int depth, StringBuffer text) throws Exception {
        int i = 0;
        while (i < this.m_sons.length) {
            text.append("\n");
            int j = 0;
            while (j < depth) {
                text.append("|   ");
                ++j;
            }
            text.append(this.m_localModel.leftSide(this.m_train));
            text.append(this.m_localModel.rightSide(i, this.m_train));
            if (this.m_sons[i].m_isLeaf) {
                text.append(": ");
                text.append("LM_" + this.m_sons[i].m_leafModelNum + ":" + this.m_sons[i].getModelParameters());
            } else {
                this.m_sons[i].dumpTree(depth + 1, text);
            }
            ++i;
        }
    }

    public int assignIDs(int lastID) {
        int currLastID;
        this.m_id = currLastID = lastID + 1;
        if (this.m_sons != null) {
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                currLastID = m_son.assignIDs(currLastID);
                ++n2;
            }
        }
        return currLastID;
    }

    public int assignLeafModelNumbers(int leafCounter) {
        if (!this.m_isLeaf) {
            this.m_leafModelNum = 0;
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                leafCounter = m_son.assignLeafModelNumbers(leafCounter);
                ++n2;
            }
        } else {
            this.m_leafModelNum = ++leafCounter;
        }
        return leafCounter;
    }

    public String modelsToString() {
        StringBuffer text = new StringBuffer();
        if (this.m_isLeaf) {
            text.append("LM_" + this.m_leafModelNum + ":" + super.toString());
        } else {
            LMTNode[] lMTNodeArray = this.m_sons;
            int n = this.m_sons.length;
            int n2 = 0;
            while (n2 < n) {
                LMTNode m_son = lMTNodeArray[n2];
                text.append("\n" + m_son.modelsToString());
                ++n2;
            }
        }
        return text.toString();
    }

    public String graph() throws Exception {
        StringBuffer text = new StringBuffer();
        this.assignIDs(-1);
        this.assignLeafModelNumbers(0);
        text.append("digraph LMTree {\n");
        if (this.m_isLeaf) {
            text.append("N" + this.m_id + " [label=\"LM_" + this.m_leafModelNum + ":" + this.getModelParameters() + "\" " + "shape=box style=filled");
            text.append("]\n");
        } else {
            text.append("N" + this.m_id + " [label=\"" + Utils.backQuoteChars(this.m_localModel.leftSide(this.m_train)) + "\" ");
            text.append("]\n");
            this.graphTree(text);
        }
        return String.valueOf(text.toString()) + "}\n";
    }

    private void graphTree(StringBuffer text) throws Exception {
        int i = 0;
        while (i < this.m_sons.length) {
            text.append("N" + this.m_id + "->" + "N" + this.m_sons[i].m_id + " [label=\"" + Utils.backQuoteChars(this.m_localModel.rightSide(i, this.m_train).trim()) + "\"]\n");
            if (this.m_sons[i].m_isLeaf) {
                text.append("N" + this.m_sons[i].m_id + " [label=\"LM_" + this.m_sons[i].m_leafModelNum + ":" + this.m_sons[i].getModelParameters() + "\" " + "shape=box style=filled");
                text.append("]\n");
            } else {
                text.append("N" + this.m_sons[i].m_id + " [label=\"" + Utils.backQuoteChars(this.m_sons[i].m_localModel.leftSide(this.m_train)) + "\" ");
                text.append("]\n");
                this.m_sons[i].graphTree(text);
            }
            ++i;
        }
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10401 $");
    }
}

