/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.ConjugateGradientOptimization;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

public abstract class RBFModel
extends RandomizableClassifier {
    private static final long serialVersionUID = -7847473336438394611L;
    public static final int USE_GLOBAL_SCALE = 1;
    public static final int USE_SCALE_PER_UNIT = 2;
    public static final int USE_SCALE_PER_UNIT_AND_ATTRIBUTE = 3;
    public static final Tag[] TAGS_SCALE = new Tag[]{new Tag(1, "Use global scale"), new Tag(2, "Use scale per unit"), new Tag(3, "Use scale per unit and attribute")};
    protected int m_scaleOptimizationOption = 2;
    protected int m_numUnits = 2;
    protected int m_classIndex = -1;
    protected Instances m_data = null;
    protected int m_numAttributes = -1;
    protected double[] m_RBFParameters = null;
    protected double m_ridge = 0.01;
    protected boolean m_useCGD = false;
    protected boolean m_useNormalizedBasisFunctions = false;
    protected boolean m_useAttributeWeights = false;
    protected double m_tolerance = 1.0E-6;
    protected int m_numThreads = 1;
    protected int m_poolSize = 1;
    protected Filter m_Filter = null;
    protected int OFFSET_WEIGHTS = -1;
    protected int OFFSET_SCALES = -1;
    protected int OFFSET_CENTERS = -1;
    protected int OFFSET_ATTRIBUTE_WEIGHTS = -1;
    protected RemoveUseless m_AttFilter;
    protected NominalToBinary m_NominalToBinary;
    protected ReplaceMissingValues m_ReplaceMissingValues;
    protected Classifier m_ZeroR;
    protected transient ExecutorService m_Pool = null;
    protected int m_numClasses = -1;
    protected double m_x1 = 1.0;
    protected double m_x0 = 0.0;

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        return result;
    }

    protected Instances initializeClassifier(Instances data) throws Exception {
        int i;
        int index;
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        Random random = new Random(this.m_Seed);
        if (data.numInstances() > 2) {
            random = data.getRandomNumberGenerator((long)this.m_Seed);
        }
        data.randomize(random);
        double y0 = data.instance(0).classValue();
        for (index = 1; index < data.numInstances() && data.instance(index).classValue() == y0; ++index) {
        }
        if (index == data.numInstances()) {
            throw new Exception("All class values are the same. At least two class values should be different");
        }
        double y1 = data.instance(index).classValue();
        this.m_ReplaceMissingValues = new ReplaceMissingValues();
        this.m_ReplaceMissingValues.setInputFormat(data);
        data = Filter.useFilter((Instances)data, (Filter)this.m_ReplaceMissingValues);
        this.m_AttFilter = new RemoveUseless();
        this.m_AttFilter.setInputFormat(data);
        data = Filter.useFilter((Instances)data, (Filter)this.m_AttFilter);
        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data after removing useless attributes!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return data;
        }
        this.m_ZeroR = null;
        this.m_NominalToBinary = new NominalToBinary();
        this.m_NominalToBinary.setInputFormat(data);
        data = Filter.useFilter((Instances)data, (Filter)this.m_NominalToBinary);
        this.m_Filter = new Normalize();
        ((Normalize)this.m_Filter).setIgnoreClass(true);
        this.m_Filter.setInputFormat(data);
        data = Filter.useFilter((Instances)data, (Filter)this.m_Filter);
        double z0 = data.instance(0).classValue();
        double z1 = data.instance(index).classValue();
        this.m_x1 = (y0 - y1) / (z0 - z1);
        this.m_x0 = y0 - this.m_x1 * z0;
        this.m_classIndex = data.classIndex();
        this.m_numClasses = data.numClasses();
        this.m_numAttributes = data.numAttributes();
        SimpleKMeans skm = new SimpleKMeans();
        skm.setMaxIterations(10000);
        skm.setNumClusters(this.m_numUnits);
        Remove rm = new Remove();
        data.setClassIndex(-1);
        rm.setAttributeIndices(this.m_classIndex + 1 + "");
        rm.setInputFormat(data);
        Instances dataRemoved = Filter.useFilter((Instances)data, (Filter)rm);
        data.setClassIndex(this.m_classIndex);
        skm.buildClusterer(dataRemoved);
        Instances centers = skm.getClusterCentroids();
        if (centers.numInstances() < this.m_numUnits) {
            this.m_numUnits = centers.numInstances();
        }
        this.OFFSET_WEIGHTS = 0;
        if (this.m_useAttributeWeights) {
            this.OFFSET_ATTRIBUTE_WEIGHTS = (this.m_numUnits + 1) * this.m_numClasses;
            this.OFFSET_CENTERS = this.OFFSET_ATTRIBUTE_WEIGHTS + this.m_numAttributes;
        } else {
            this.OFFSET_ATTRIBUTE_WEIGHTS = -1;
            this.OFFSET_CENTERS = (this.m_numUnits + 1) * this.m_numClasses;
        }
        this.OFFSET_SCALES = this.OFFSET_CENTERS + this.m_numUnits * this.m_numAttributes;
        switch (this.m_scaleOptimizationOption) {
            case 1: {
                this.m_RBFParameters = new double[this.OFFSET_SCALES + 1];
                break;
            }
            case 3: {
                this.m_RBFParameters = new double[this.OFFSET_SCALES + this.m_numUnits * this.m_numAttributes];
                break;
            }
            default: {
                this.m_RBFParameters = new double[this.OFFSET_SCALES + this.m_numUnits];
            }
        }
        double maxMinDist = -1.0;
        for (i = 0; i < centers.numInstances(); ++i) {
            double minDist = Double.MAX_VALUE;
            for (int j = i + 1; j < centers.numInstances(); ++j) {
                double dist = 0.0;
                for (int k = 0; k < centers.numAttributes(); ++k) {
                    if (k == centers.classIndex()) continue;
                    double diff = centers.instance(i).value(k) - centers.instance(j).value(k);
                    dist += diff * diff;
                }
                if (!(dist < minDist)) continue;
                minDist = dist;
            }
            if (minDist == Double.MAX_VALUE || !(minDist > maxMinDist)) continue;
            maxMinDist = minDist;
        }
        if (this.m_scaleOptimizationOption == 1) {
            this.m_RBFParameters[this.OFFSET_SCALES] = Math.sqrt(maxMinDist);
        }
        for (i = 0; i < this.m_numUnits; ++i) {
            if (this.m_scaleOptimizationOption == 2) {
                this.m_RBFParameters[this.OFFSET_SCALES + i] = Math.sqrt(maxMinDist);
            }
            int k = 0;
            for (int j = 0; j < this.m_numAttributes; ++j) {
                if (k == centers.classIndex()) {
                    ++k;
                }
                if (j == data.classIndex()) continue;
                if (this.m_scaleOptimizationOption == 3) {
                    this.m_RBFParameters[this.OFFSET_SCALES + (i * this.m_numAttributes + j)] = Math.sqrt(maxMinDist);
                }
                this.m_RBFParameters[this.OFFSET_CENTERS + i * this.m_numAttributes + j] = centers.instance(i).value(k);
                ++k;
            }
        }
        if (this.m_useAttributeWeights) {
            for (int j = 0; j < this.m_numAttributes; ++j) {
                if (j == data.classIndex()) continue;
                this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] = 1.0;
            }
        }
        this.initializeOutputLayer(random);
        return data;
    }

    protected abstract void initializeOutputLayer(Random var1);

    public void buildClassifier(Instances data) throws Exception {
        this.m_data = this.initializeClassifier(data);
        if (this.m_ZeroR != null) {
            return;
        }
        this.m_Pool = Executors.newFixedThreadPool(this.m_poolSize);
        Object opt = null;
        opt = !this.m_useCGD ? new OptEng() : new OptEngCGD();
        opt.setDebug(this.m_Debug);
        double[][] b = new double[2][this.m_RBFParameters.length];
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < this.m_RBFParameters.length; ++j) {
                b[i][j] = Double.NaN;
            }
        }
        this.m_RBFParameters = opt.findArgmin(this.m_RBFParameters, b);
        while (this.m_RBFParameters == null) {
            this.m_RBFParameters = opt.getVarbValues();
            if (this.m_Debug) {
                System.out.println("200 iterations finished, not enough!");
            }
            this.m_RBFParameters = opt.findArgmin(this.m_RBFParameters, b);
        }
        if (this.m_Debug) {
            System.out.println("SE (normalized space) after optimization: " + opt.getMinFunction());
        }
        this.m_data = new Instances(this.m_data, 0);
        this.m_Pool.shutdown();
    }

    protected abstract double calculateError(double[] var1, Instance var2);

    protected abstract double postprocessError(double var1);

    protected abstract void postprocessGradient(double[] var1);

    protected double calculateSE() {
        int chunksize = this.m_data.numInstances() / this.m_numThreads;
        HashSet<Future<Double>> results = new HashSet<Future<Double>>();
        for (int j = 0; j < this.m_numThreads; ++j) {
            final int lo = j * chunksize;
            final int hi = j < this.m_numThreads - 1 ? lo + chunksize : this.m_data.numInstances();
            Future<Double> futureSE = this.m_Pool.submit(new Callable<Double>(){

                @Override
                public Double call() {
                    double[] outputs = new double[RBFModel.this.m_numUnits];
                    double SE = 0.0;
                    for (int k = lo; k < hi; ++k) {
                        Instance inst = RBFModel.this.m_data.instance(k);
                        RBFModel.this.calculateOutputs(inst, outputs, null);
                        SE += RBFModel.this.calculateError(outputs, inst);
                    }
                    return SE;
                }
            });
            results.add(futureSE);
        }
        double SE = 0.0;
        try {
            for (Future<Double> futureSE : results) {
                SE += ((Double)futureSE.get()).doubleValue();
            }
        }
        catch (Exception e) {
            System.out.println("Squared error could not be calculated.");
        }
        return this.postprocessError(0.5 * SE);
    }

    protected double[] calculateGradient() {
        int chunksize = this.m_data.numInstances() / this.m_numThreads;
        HashSet<Future<double[]>> results = new HashSet<Future<double[]>>();
        for (int j = 0; j < this.m_numThreads; ++j) {
            final int lo = j * chunksize;
            final int n = j < this.m_numThreads - 1 ? lo + chunksize : this.m_data.numInstances();
            Future<double[]> futureGrad = this.m_Pool.submit(new Callable<double[]>(){

                @Override
                public double[] call() {
                    double[] outputs = new double[RBFModel.this.m_numUnits];
                    double[] deltaHidden = new double[RBFModel.this.m_numUnits];
                    double[] derivativesOutput = new double[1];
                    double[] derivativesHidden = new double[RBFModel.this.m_numUnits];
                    double[] localGrad = new double[RBFModel.this.m_RBFParameters.length];
                    for (int k = lo; k < n; ++k) {
                        Instance inst = RBFModel.this.m_data.instance(k);
                        RBFModel.this.calculateOutputs(inst, outputs, derivativesHidden);
                        RBFModel.this.updateGradient(localGrad, inst, outputs, derivativesOutput, deltaHidden);
                        RBFModel.this.updateGradientForHiddenUnits(localGrad, inst, derivativesHidden, deltaHidden);
                    }
                    return localGrad;
                }
            });
            results.add(futureGrad);
        }
        double[] grad = new double[this.m_RBFParameters.length];
        try {
            for (Future future : results) {
                double[] lg = (double[])future.get();
                for (int i = 0; i < lg.length; ++i) {
                    int n = i;
                    grad[n] = grad[n] + lg[i];
                }
            }
        }
        catch (Exception e) {
            System.out.println("Gradient could not be calculated.");
        }
        this.postprocessGradient(grad);
        return grad;
    }

    protected abstract void updateGradient(double[] var1, Instance var2, double[] var3, double[] var4, double[] var5);

    protected void updateGradientForHiddenUnits(double[] grad, Instance inst, double[] derivativesHidden, double[] deltaHidden) {
        int i;
        for (i = 0; i < this.m_numUnits; ++i) {
            int n = i;
            deltaHidden[n] = deltaHidden[n] * derivativesHidden[i];
        }
        block5: for (i = 0; i < this.m_numUnits; ++i) {
            if (deltaHidden[i] <= this.m_tolerance && deltaHidden[i] >= -this.m_tolerance) continue;
            switch (this.m_scaleOptimizationOption) {
                case 1: {
                    int n = this.OFFSET_SCALES;
                    grad[n] = grad[n] + this.derivativeOneScale(grad, deltaHidden, this.m_RBFParameters[this.OFFSET_SCALES], inst, i);
                    continue block5;
                }
                case 3: {
                    this.derivativeScalePerAttribute(grad, deltaHidden, inst, i);
                    continue block5;
                }
                default: {
                    int n = this.OFFSET_SCALES + i;
                    grad[n] = grad[n] + this.derivativeOneScale(grad, deltaHidden, this.m_RBFParameters[this.OFFSET_SCALES + i], inst, i);
                }
            }
        }
    }

    protected void derivativeScalePerAttribute(double[] grad, double[] deltaHidden, Instance inst, int unitIndex) {
        double scalePart;
        double diff;
        int j;
        double constant = deltaHidden[unitIndex];
        int offsetC = this.OFFSET_CENTERS + unitIndex * this.m_numAttributes;
        int offsetS = this.OFFSET_SCALES + unitIndex * this.m_numAttributes;
        double attWeight = 1.0;
        for (j = 0; j < this.m_classIndex; ++j) {
            diff = inst.value(j) - this.m_RBFParameters[offsetC + j];
            scalePart = this.m_RBFParameters[offsetS + j] * this.m_RBFParameters[offsetS + j];
            if (this.m_useAttributeWeights) {
                attWeight = this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
                int n = this.OFFSET_ATTRIBUTE_WEIGHTS + j;
                grad[n] = grad[n] - this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * constant * diff * diff / scalePart;
            }
            int n = offsetS + j;
            grad[n] = grad[n] + constant * attWeight * diff * diff / (scalePart * this.m_RBFParameters[offsetS + j]);
            int n2 = offsetC + j;
            grad[n2] = grad[n2] + constant * attWeight * diff / scalePart;
        }
        for (j = this.m_classIndex + 1; j < this.m_numAttributes; ++j) {
            diff = inst.value(j) - this.m_RBFParameters[offsetC + j];
            scalePart = this.m_RBFParameters[offsetS + j] * this.m_RBFParameters[offsetS + j];
            if (this.m_useAttributeWeights) {
                attWeight = this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
                int n = this.OFFSET_ATTRIBUTE_WEIGHTS + j;
                grad[n] = grad[n] - this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * constant * diff * diff / scalePart;
            }
            int n = offsetS + j;
            grad[n] = grad[n] + constant * attWeight * diff * diff / (scalePart * this.m_RBFParameters[offsetS + j]);
            int n3 = offsetC + j;
            grad[n3] = grad[n3] + constant * attWeight * diff / scalePart;
        }
    }

    protected double derivativeOneScale(double[] grad, double[] deltaHidden, double scale, Instance inst, int unitIndex) {
        double diffSquared;
        double diff;
        int j;
        double constant = deltaHidden[unitIndex] / (scale * scale);
        double sumDiffSquared = 0.0;
        int offsetC = this.OFFSET_CENTERS + unitIndex * this.m_numAttributes;
        double attWeight = 1.0;
        for (j = 0; j < this.m_classIndex; ++j) {
            diff = inst.value(j) - this.m_RBFParameters[offsetC + j];
            diffSquared = diff * diff;
            if (this.m_useAttributeWeights) {
                attWeight = this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
                int n = this.OFFSET_ATTRIBUTE_WEIGHTS + j;
                grad[n] = grad[n] - this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * constant * diffSquared;
            }
            sumDiffSquared += attWeight * diffSquared;
            int n = offsetC + j;
            grad[n] = grad[n] + constant * attWeight * diff;
        }
        for (j = this.m_classIndex + 1; j < this.m_numAttributes; ++j) {
            diff = inst.value(j) - this.m_RBFParameters[offsetC + j];
            diffSquared = diff * diff;
            if (this.m_useAttributeWeights) {
                attWeight = this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
                int n = this.OFFSET_ATTRIBUTE_WEIGHTS + j;
                grad[n] = grad[n] - this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j] * constant * diffSquared;
            }
            sumDiffSquared += attWeight * diffSquared;
            int n = offsetC + j;
            grad[n] = grad[n] + constant * attWeight * diff;
        }
        return constant * sumDiffSquared / scale;
    }

    protected void calculateOutputs(Instance inst, double[] o, double[] d) {
        for (int i = 0; i < this.m_numUnits; ++i) {
            double sumSquaredDiff = 0.0;
            switch (this.m_scaleOptimizationOption) {
                case 1: {
                    sumSquaredDiff = this.sumSquaredDiffOneScale(this.m_RBFParameters[this.OFFSET_SCALES], inst, i);
                    break;
                }
                case 3: {
                    sumSquaredDiff = this.sumSquaredDiffScalePerAttribute(inst, i);
                    break;
                }
                default: {
                    sumSquaredDiff = this.sumSquaredDiffOneScale(this.m_RBFParameters[this.OFFSET_SCALES + i], inst, i);
                }
            }
            if (!this.m_useNormalizedBasisFunctions) {
                o[i] = Math.exp(-sumSquaredDiff);
                if (d == null) continue;
                d[i] = o[i];
                continue;
            }
            o[i] = -sumSquaredDiff;
        }
        if (this.m_useNormalizedBasisFunctions) {
            int i;
            double max = o[Utils.maxIndex((double[])o)];
            double sum = 0.0;
            for (i = 0; i < o.length; ++i) {
                o[i] = Math.exp(o[i] - max);
                sum += o[i];
            }
            i = 0;
            while (i < o.length) {
                int n = i++;
                o[n] = o[n] / sum;
            }
            if (d != null) {
                for (i = 0; i < o.length; ++i) {
                    d[i] = o[i] * (1.0 - o[i]);
                }
            }
        }
    }

    protected double sumSquaredDiffScalePerAttribute(Instance inst, int unitIndex) {
        double diff;
        int j;
        int offsetS = this.OFFSET_SCALES + unitIndex * this.m_numAttributes;
        int offsetC = this.OFFSET_CENTERS + unitIndex * this.m_numAttributes;
        double sumSquaredDiff = 0.0;
        for (j = 0; j < this.m_classIndex; ++j) {
            diff = this.m_RBFParameters[offsetC + j] - inst.value(j);
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
            }
            sumSquaredDiff += diff * diff / (2.0 * this.m_RBFParameters[offsetS + j] * this.m_RBFParameters[offsetS + j]);
        }
        for (j = this.m_classIndex + 1; j < this.m_numAttributes; ++j) {
            diff = this.m_RBFParameters[offsetC + j] - inst.value(j);
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
            }
            sumSquaredDiff += diff * diff / (2.0 * this.m_RBFParameters[offsetS + j] * this.m_RBFParameters[offsetS + j]);
        }
        return sumSquaredDiff;
    }

    protected double sumSquaredDiffOneScale(double scale, Instance inst, int unitIndex) {
        double diff;
        int j;
        int offsetC = this.OFFSET_CENTERS + unitIndex * this.m_numAttributes;
        double sumSquaredDiff = 0.0;
        for (j = 0; j < this.m_classIndex; ++j) {
            diff = this.m_RBFParameters[offsetC + j] - inst.value(j);
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
            }
            sumSquaredDiff += diff * diff;
        }
        for (j = this.m_classIndex + 1; j < this.m_numAttributes; ++j) {
            diff = this.m_RBFParameters[offsetC + j] - inst.value(j);
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + j];
            }
            sumSquaredDiff += diff * diff;
        }
        return sumSquaredDiff / (2.0 * scale * scale);
    }

    protected abstract double[] getDistribution(double[] var1);

    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_ReplaceMissingValues.input(inst);
        inst = this.m_ReplaceMissingValues.output();
        this.m_AttFilter.input(inst);
        inst = this.m_AttFilter.output();
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(inst);
        }
        this.m_NominalToBinary.input(inst);
        inst = this.m_NominalToBinary.output();
        this.m_Filter.input(inst);
        inst = this.m_Filter.output();
        double[] outputs = new double[this.m_numUnits];
        this.calculateOutputs(inst, outputs, null);
        return this.getDistribution(outputs);
    }

    public String globalInfo() {
        return "Class implementing radial basis function networks for classification, trained in a fully supervised manner using WEKA's Optimization class by minimizing squared error with the BFGS method. Note that all attributes are normalized into the [0,1] scale. The initial centers for the Gaussian radial basis functions are found using WEKA's SimpleKMeans. The initial sigma values are set to the maximum distance between any center and its nearest neighbour in the set of centers. There are several parameters. The ridge parameter is used to penalize the size of the weights in the output layer. The number of basis functions can also be specified. Note that large numbers produce long training times. Another option determines whether one global sigma value is used for all units (fastest), whether one value is used per unit (common practice, it seems, and set as the default), or a different value is learned for every unit/attribute combination. It is also possible to learn attribute weights for the distance function. (The square of the value shown in the output is used.)  Finally, it is possible to use conjugate gradient descent rather than BFGS updates, which can be faster for cases with many parameters, and to use normalized basis functions instead of unnormalized ones. To improve speed, an approximate version of the logistic function is used as the activation function in the output layer. Also, if delta values in the backpropagation step are  within the user-specified tolerance, the gradient is not updated for that particular instance, which saves some additional time. Paralled calculation of squared error and gradient is possible when multiple CPU cores are present. Data is split into batches and processed in separate threads in this case. Note that this only improves runtime for larger datasets. Nominal attributes are processed using the unsupervised  NominalToBinary filter and missing values are replaced globally using ReplaceMissingValues.";
    }

    public String toleranceTipText() {
        return "The tolerance parameter for the delta values.";
    }

    public double getTolerance() {
        return this.m_tolerance;
    }

    public void setTolerance(double newTolerance) {
        this.m_tolerance = newTolerance;
    }

    public String numFunctionsTipText() {
        return "The number of basis functions to use.";
    }

    public int getNumFunctions() {
        return this.m_numUnits;
    }

    public void setNumFunctions(int newNumFunctions) {
        this.m_numUnits = newNumFunctions;
    }

    public String ridgeTipText() {
        return "The ridge penalty factor for the output layer.";
    }

    public double getRidge() {
        return this.m_ridge;
    }

    public void setRidge(double newRidge) {
        this.m_ridge = newRidge;
    }

    public String useCGDTipText() {
        return "Whether to use conjugate gradient descent (recommended for many parameters).";
    }

    public boolean getUseCGD() {
        return this.m_useCGD;
    }

    public void setUseCGD(boolean newUseCGD) {
        this.m_useCGD = newUseCGD;
    }

    public String useAttributeWeightsTipText() {
        return "Whether to use attribute weights.";
    }

    public boolean getUseAttributeWeights() {
        return this.m_useAttributeWeights;
    }

    public void setUseAttributeWeights(boolean newUseAttributeWeights) {
        this.m_useAttributeWeights = newUseAttributeWeights;
    }

    public String useNormalizedBasisFunctionsTipText() {
        return "Whether to use normalized basis functions.";
    }

    public boolean getUseNormalizedBasisFunctions() {
        return this.m_useNormalizedBasisFunctions;
    }

    public void setUseNormalizedBasisFunctions(boolean newUseNormalizedBasisFunctions) {
        this.m_useNormalizedBasisFunctions = newUseNormalizedBasisFunctions;
    }

    public String scaleOptimizationOptionTipText() {
        return "The number of sigma parameters to use.";
    }

    public SelectedTag getScaleOptimizationOption() {
        return new SelectedTag(this.m_scaleOptimizationOption, TAGS_SCALE);
    }

    public void setScaleOptimizationOption(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_SCALE) {
            this.m_scaleOptimizationOption = newMethod.getSelectedTag().getID();
        }
    }

    public String numThreadsTipText() {
        return "The number of threads to use, which should be >= size of thread pool.";
    }

    public int getNumThreads() {
        return this.m_numThreads;
    }

    public void setNumThreads(int nT) {
        this.m_numThreads = nT;
    }

    public String poolSizeTipText() {
        return "The size of the thread pool, for example, the number of cores in the CPU.";
    }

    public int getPoolSize() {
        return this.m_poolSize;
    }

    public void setPoolSize(int nT) {
        this.m_poolSize = nT;
    }

    public Enumeration<Option> listOptions() {
        Vector<Object> newVector = new Vector<Object>(9);
        newVector.addElement(new Option("\tNumber of Gaussian basis functions (default is 2).\n", "N", 1, "-N <int>"));
        newVector.addElement(new Option("\tRidge factor for quadratic penalty on output weights (default is 0.01).\n", "R", 1, "-R <double>"));
        newVector.addElement(new Option("\tTolerance parameter for delta values (default is 1.0e-6).\n", "L", 1, "-L <double>"));
        newVector.addElement(new Option("\tThe scale optimization option: global scale (1), one scale per unit (2), scale per unit and attribute (3) (default is 2).\n", "C", 1, "-C <1|2|3>"));
        newVector.addElement(new Option("\tUse conjugate gradient descent (recommended for many attributes).\n", "G", 0, "-G"));
        newVector.addElement(new Option("\tUse normalized basis functions.\n", "O", 0, "-O"));
        newVector.addElement(new Option("\tUse attribute weights.\n", "A", 0, "-A"));
        newVector.addElement(new Option("\t" + this.poolSizeTipText() + " (default 1)\n", "P", 1, "-P <int>"));
        newVector.addElement(new Option("\t" + this.numThreadsTipText() + " (default 1)\n", "E", 1, "-E <int>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String numFunctions = Utils.getOption((char)'N', (String[])options);
        if (numFunctions.length() != 0) {
            this.setNumFunctions(Integer.parseInt(numFunctions));
        } else {
            this.setNumFunctions(2);
        }
        String Ridge = Utils.getOption((char)'R', (String[])options);
        if (Ridge.length() != 0) {
            this.setRidge(Double.parseDouble(Ridge));
        } else {
            this.setRidge(0.01);
        }
        String scale = Utils.getOption((char)'C', (String[])options);
        if (scale.length() != 0) {
            this.setScaleOptimizationOption(new SelectedTag(Integer.parseInt(scale), TAGS_SCALE));
        } else {
            this.setScaleOptimizationOption(new SelectedTag(2, TAGS_SCALE));
        }
        String Tolerance = Utils.getOption((char)'L', (String[])options);
        if (Tolerance.length() != 0) {
            this.setTolerance(Double.parseDouble(Tolerance));
        } else {
            this.setTolerance(1.0E-6);
        }
        this.m_useCGD = Utils.getFlag((char)'G', (String[])options);
        this.m_useNormalizedBasisFunctions = Utils.getFlag((char)'O', (String[])options);
        this.m_useAttributeWeights = Utils.getFlag((char)'A', (String[])options);
        String PoolSize = Utils.getOption((char)'P', (String[])options);
        if (PoolSize.length() != 0) {
            this.setPoolSize(Integer.parseInt(PoolSize));
        } else {
            this.setPoolSize(1);
        }
        String NumThreads = Utils.getOption((char)'E', (String[])options);
        if (NumThreads.length() != 0) {
            this.setNumThreads(Integer.parseInt(NumThreads));
        } else {
            this.setNumThreads(1);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions((String[])options);
    }

    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-N");
        options.add("" + this.getNumFunctions());
        options.add("-R");
        options.add("" + this.getRidge());
        options.add("-L");
        options.add("" + this.getTolerance());
        options.add("-C");
        options.add("" + this.getScaleOptimizationOption().getSelectedTag().getID());
        if (this.m_useCGD) {
            options.add("-G");
        }
        if (this.m_useNormalizedBasisFunctions) {
            options.add("-O");
        }
        if (this.m_useAttributeWeights) {
            options.add("-A");
        }
        options.add("-P");
        options.add("" + this.getPoolSize());
        options.add("-E");
        options.add("" + this.getNumThreads());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    protected class OptEngCGD
    extends ConjugateGradientOptimization {
        protected OptEngCGD() {
        }

        protected double objectiveFunction(double[] x) {
            RBFModel.this.m_RBFParameters = x;
            return RBFModel.this.calculateSE();
        }

        protected double[] evaluateGradient(double[] x) {
            RBFModel.this.m_RBFParameters = x;
            return RBFModel.this.calculateGradient();
        }

        public String getRevision() {
            return RevisionUtils.extract((String)"$Revision: 8966 $");
        }
    }

    protected class OptEng
    extends Optimization {
        protected OptEng() {
        }

        protected double objectiveFunction(double[] x) {
            RBFModel.this.m_RBFParameters = x;
            return RBFModel.this.calculateSE();
        }

        protected double[] evaluateGradient(double[] x) {
            RBFModel.this.m_RBFParameters = x;
            return RBFModel.this.calculateGradient();
        }

        public String getRevision() {
            return RevisionUtils.extract((String)"$Revision: 8966 $");
        }
    }
}

