/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class SimpleLinearRegression
extends AbstractClassifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = 1679336022895414137L;
    private Attribute m_attribute;
    private int m_attributeIndex;
    private double m_slope;
    private double m_intercept;
    protected boolean m_outputAdditionalStats;
    private int m_df;
    private double m_seSlope = Double.NaN;
    private double m_seIntercept = Double.NaN;
    private double m_tstatSlope = Double.NaN;
    private double m_tstatIntercept = Double.NaN;
    private double m_rsquared = Double.NaN;
    private double m_rsquaredAdj = Double.NaN;
    private double m_fstat = Double.NaN;
    private boolean m_suppressErrorMessage = false;

    public String globalInfo() {
        return "Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Missing values are not allowed. Can only deal with numeric attributes.";
    }

    public SimpleLinearRegression() {
    }

    public SimpleLinearRegression(Instances data, int attIndex, double slope, double intercept) {
        this.m_attributeIndex = attIndex;
        this.m_slope = slope;
        this.m_intercept = intercept;
        this.m_attribute = data.attribute(attIndex);
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tOutput additional statistics.", "additional-stats", 0, "-additional-stats"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        this.setOutputAdditionalStats(Utils.getFlag("additional-stats", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getOutputAdditionalStats()) {
            result.add("-additional-stats");
        }
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    public String outputAdditionalStatsTipText() {
        return "Output additional statistics (such as std deviation of coefficients and t-statistics)";
    }

    public void setOutputAdditionalStats(boolean additional) {
        this.m_outputAdditionalStats = additional;
    }

    public boolean getOutputAdditionalStats() {
        return this.m_outputAdditionalStats;
    }

    public void addModel(SimpleLinearRegression slr) throws Exception {
        if (this.m_attribute == null || slr.m_attributeIndex == this.m_attributeIndex) {
            this.m_attributeIndex = slr.m_attributeIndex;
            this.m_attribute = slr.m_attribute;
            this.m_slope += slr.m_slope;
            this.m_intercept += slr.m_intercept;
        } else {
            throw new Exception("Could not add models. " + this.m_attributeIndex + " " + slr.m_attributeIndex + " " + this.m_attribute + " " + slr.m_attribute + " " + this.m_slope + " " + slr.m_slope + " " + this.m_intercept + " " + slr.m_intercept);
        }
    }

    @Override
    public double classifyInstance(Instance inst) throws Exception {
        if (this.m_attribute == null) {
            return this.m_intercept;
        }
        if (inst.isMissing(this.m_attribute.index())) {
            throw new Exception("SimpleLinearRegression: No missing values!");
        }
        return this.m_intercept + this.m_slope * inst.value(this.m_attribute.index());
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    protected double[] computeMeans(Instances insts) {
        double[] means = new double[insts.numAttributes()];
        double[] counts = new double[insts.numAttributes()];
        int j = 0;
        while (j < insts.numInstances()) {
            Instance inst = insts.instance(j);
            if (!inst.classIsMissing()) {
                int i = 0;
                while (i < insts.numAttributes()) {
                    int n = i;
                    means[n] = means[n] + inst.weight() * inst.value(i);
                    int n2 = i++;
                    counts[n2] = counts[n2] + inst.weight();
                }
            }
            ++j;
        }
        int i = 0;
        while (i < insts.numAttributes()) {
            if (counts[i] > 0.0) {
                int n = i;
                means[n] = means[n] / counts[i];
            } else {
                means[i] = 0.0;
            }
            ++i;
        }
        return means;
    }

    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        double[] means = this.computeMeans(insts);
        double[] slopes = new double[insts.numAttributes()];
        double[] sumWeightedDiffsSquared = new double[insts.numAttributes()];
        int classIndex = insts.classIndex();
        int j = 0;
        while (j < insts.numInstances()) {
            Instance inst = insts.instance(j);
            if (!inst.classIsMissing()) {
                double yDiff = inst.value(classIndex) - means[classIndex];
                double weightedYDiff = inst.weight() * yDiff;
                int i = 0;
                while (i < insts.numAttributes()) {
                    double diff = inst.value(i) - means[i];
                    double weightedDiff = inst.weight() * diff;
                    int n = i;
                    slopes[n] = slopes[n] + weightedYDiff * diff;
                    int n2 = i++;
                    sumWeightedDiffsSquared[n2] = sumWeightedDiffsSquared[n2] + weightedDiff * diff;
                }
            }
            ++j;
        }
        double minSSE = Double.MAX_VALUE;
        this.m_attribute = null;
        int chosen = -1;
        double chosenSlope = Double.NaN;
        double chosenIntercept = Double.NaN;
        int i = 0;
        while (i < insts.numAttributes()) {
            if (i != classIndex && sumWeightedDiffsSquared[i] != 0.0) {
                double numerator = slopes[i];
                int n = i;
                slopes[n] = slopes[n] / sumWeightedDiffsSquared[i];
                double intercept = means[classIndex] - slopes[i] * means[i];
                double sse = sumWeightedDiffsSquared[classIndex] - slopes[i] * numerator;
                if (sse < minSSE) {
                    minSSE = sse;
                    chosen = i;
                    chosenSlope = slopes[i];
                    chosenIntercept = intercept;
                }
            }
            ++i;
        }
        if (chosen == -1) {
            if (!this.m_suppressErrorMessage) {
                System.err.println("----- no useful attribute found");
            }
            this.m_attribute = null;
            this.m_attributeIndex = 0;
            this.m_slope = 0.0;
            this.m_intercept = means[classIndex];
        } else {
            this.m_attribute = insts.attribute(chosen);
            this.m_attributeIndex = chosen;
            this.m_slope = chosenSlope;
            this.m_intercept = chosenIntercept;
            if (this.m_outputAdditionalStats) {
                this.m_df = insts.numInstances() - 2;
                double[] stdErrors = RegressionAnalysis.calculateStdErrorOfCoef(insts, this.m_attribute, this.m_slope, this.m_intercept, this.m_df);
                this.m_seSlope = stdErrors[0];
                this.m_seIntercept = stdErrors[1];
                double[] coef = new double[]{this.m_slope, this.m_intercept};
                double[] tStats = RegressionAnalysis.calculateTStats(coef, stdErrors, 2);
                this.m_tstatSlope = tStats[0];
                this.m_tstatIntercept = tStats[1];
                double ssr = RegressionAnalysis.calculateSSR(insts, this.m_attribute, this.m_slope, this.m_intercept);
                this.m_rsquared = RegressionAnalysis.calculateRSquared(insts, ssr);
                this.m_rsquaredAdj = RegressionAnalysis.calculateAdjRSquared(this.m_rsquared, insts.numInstances(), 2);
                this.m_fstat = RegressionAnalysis.calculateFStat(this.m_rsquared, insts.numInstances(), 2);
            }
        }
    }

    public boolean foundUsefulAttribute() {
        return this.m_attribute != null;
    }

    public int getAttributeIndex() {
        return this.m_attributeIndex;
    }

    public double getSlope() {
        return this.m_slope;
    }

    public double getIntercept() {
        return this.m_intercept;
    }

    public void setSuppressErrorMessage(boolean s) {
        this.m_suppressErrorMessage = s;
    }

    public String toString() {
        StringBuffer text = new StringBuffer();
        if (this.m_attribute == null) {
            text.append("Predicting constant " + this.m_intercept);
        } else {
            text.append("Linear regression on " + this.m_attribute.name() + "\n\n");
            text.append(String.valueOf(Utils.doubleToString(this.m_slope, 2)) + " * " + this.m_attribute.name());
            if (this.m_intercept > 0.0) {
                text.append(" + " + Utils.doubleToString(this.m_intercept, 2));
            } else {
                text.append(" - " + Utils.doubleToString(-this.m_intercept, 2));
            }
            if (this.m_outputAdditionalStats) {
                int attNameLength = this.m_attribute.name().length() + 3;
                if (attNameLength < "Variable".length() + 3) {
                    attNameLength = "Variable".length() + 3;
                }
                text.append("\n\nRegression Analysis:\n\n" + Utils.padRight("Variable", attNameLength) + "  Coefficient     SE of Coef        t-Stat");
                text.append("\n" + Utils.padRight(this.m_attribute.name(), attNameLength));
                text.append(Utils.doubleToString(this.m_slope, 12, 4));
                text.append("   " + Utils.doubleToString(this.m_seSlope, 12, 5));
                text.append("   " + Utils.doubleToString(this.m_tstatSlope, 12, 5));
                text.append(String.valueOf(Utils.padRight("\nconst", attNameLength + 1)) + Utils.doubleToString(this.m_intercept, 12, 4));
                text.append("   " + Utils.doubleToString(this.m_seIntercept, 12, 5));
                text.append("   " + Utils.doubleToString(this.m_tstatIntercept, 12, 5));
                text.append("\n\nDegrees of freedom = " + Integer.toString(this.m_df));
                text.append("\nR^2 value = " + Utils.doubleToString(this.m_rsquared, 5));
                text.append("\nAdjusted R^2 = " + Utils.doubleToString(this.m_rsquaredAdj, 5));
                text.append("\nF-statistic = " + Utils.doubleToString(this.m_fstat, 5));
            }
        }
        text.append("\n");
        return text.toString();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10169 $");
    }

    public static void main(String[] argv) {
        SimpleLinearRegression.runClassifier(new SimpleLinearRegression(), argv);
    }
}

