/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.application;

import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier;
import de.lmu.ifi.dbs.elki.application.AbstractApplication;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.AbstractDatabase;
import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.FileBasedDatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.MultipleObjectsBundleDatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.evaluation.classification.ConfusionMatrix;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.AbstractHoldout;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.StratifiedCrossValidation;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet;
import de.lmu.ifi.dbs.elki.index.IndexFactory;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.utilities.exceptions.UnableToComplyException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectListParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.workflow.AlgorithmStep;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

public class ClassifierHoldoutEvaluationTask<O>
extends AbstractApplication {
    private static final Logging LOG = Logging.getLogger(ClassifierHoldoutEvaluationTask.class);
    protected DatabaseConnection databaseConnection = null;
    protected Collection<IndexFactory<?, ?>> indexFactories;
    protected Classifier<O> algorithm;
    protected Holdout holdout;

    public ClassifierHoldoutEvaluationTask(DatabaseConnection databaseConnection, Collection<IndexFactory<?, ?>> collection, Classifier<O> classifier, Holdout holdout) {
        this.databaseConnection = databaseConnection;
        this.indexFactories = collection;
        this.algorithm = classifier;
        this.holdout = holdout;
    }

    @Override
    public void run() throws UnableToComplyException {
        Duration duration = LOG.newDuration("evaluation.time.load").begin();
        MultipleObjectsBundle multipleObjectsBundle = this.databaseConnection.loadData();
        this.holdout.initialize(multipleObjectsBundle);
        LOG.statistics(duration.end());
        Duration duration2 = LOG.newDuration("evaluation.time.total").begin();
        ArrayList<ClassLabel> arrayList = this.holdout.getLabels();
        int[][] nArray = new int[arrayList.size()][arrayList.size()];
        for (int i = 0; i < this.holdout.numberOfPartitions(); ++i) {
            TrainingAndTestSet trainingAndTestSet = this.holdout.nextPartitioning();
            Duration duration3 = LOG.newDuration(this.getClass().getName() + ".fold-" + (i + 1) + ".init.time").begin();
            StaticArrayDatabase staticArrayDatabase = new StaticArrayDatabase(new MultipleObjectsBundleDatabaseConnection(trainingAndTestSet.getTraining()), this.indexFactories);
            staticArrayDatabase.initialize();
            LOG.statistics(duration3.end());
            duration3 = LOG.newDuration(this.getClass().getName() + ".fold-" + (i + 1) + ".train.time").begin();
            Relation relation = staticArrayDatabase.getRelation(TypeUtil.CLASSLABEL, new Object[0]);
            this.algorithm.buildClassifier(staticArrayDatabase, relation);
            LOG.statistics(duration3.end());
            duration3 = LOG.newDuration(this.getClass().getName() + ".fold-" + (i + 1) + ".evaluation.time").begin();
            MultipleObjectsBundle multipleObjectsBundle2 = trainingAndTestSet.getTest();
            int n = AbstractHoldout.findClassLabelColumn(multipleObjectsBundle2);
            int n2 = n == 0 ? 1 : 0;
            int n3 = multipleObjectsBundle2.dataLength();
            for (int j = 0; j < n3; ++j) {
                Object object = multipleObjectsBundle2.data(j, n2);
                ClassLabel classLabel = (ClassLabel)multipleObjectsBundle2.data(j, n);
                ClassLabel classLabel2 = this.algorithm.classify(object);
                int n4 = Collections.binarySearch(arrayList, classLabel2);
                int n5 = Collections.binarySearch(arrayList, classLabel);
                int[] nArray2 = nArray[n4];
                int n6 = n5;
                nArray2[n6] = nArray2[n6] + 1;
            }
            LOG.statistics(duration3.end());
        }
        LOG.statistics(duration2.end());
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(arrayList, nArray);
        LOG.statistics(confusionMatrix.toString());
    }

    public static void main(String[] stringArray) {
        ClassifierHoldoutEvaluationTask.runCLIApplication(ClassifierHoldoutEvaluationTask.class, stringArray);
    }

    public static class Parameterizer<O>
    extends AbstractApplication.Parameterizer {
        public static final OptionID HOLDOUT_ID = new OptionID("evaluation.holdout", "Holdout class used in evaluation.");
        protected DatabaseConnection databaseConnection = null;
        protected Collection<IndexFactory<?, ?>> indexFactories;
        protected Classifier<O> algorithm;
        protected Holdout holdout;

        @Override
        protected void makeOptions(Parameterization parameterization) {
            ObjectParameter objectParameter;
            ObjectParameter objectParameter2;
            ObjectListParameter objectListParameter;
            super.makeOptions(parameterization);
            ObjectParameter objectParameter3 = new ObjectParameter(AbstractDatabase.Parameterizer.DATABASE_CONNECTION_ID, (Class<?>)DatabaseConnection.class, FileBasedDatabaseConnection.class);
            if (parameterization.grab(objectParameter3)) {
                this.databaseConnection = (DatabaseConnection)objectParameter3.instantiateClass(parameterization);
            }
            if (parameterization.grab(objectListParameter = new ObjectListParameter(AbstractDatabase.Parameterizer.INDEX_ID, IndexFactory.class, true))) {
                this.indexFactories = objectListParameter.instantiateClasses(parameterization);
            }
            if (parameterization.grab(objectParameter2 = new ObjectParameter(AlgorithmStep.Parameterizer.ALGORITHM_ID, Classifier.class))) {
                this.algorithm = (Classifier)objectParameter2.instantiateClass(parameterization);
            }
            if (parameterization.grab(objectParameter = new ObjectParameter(HOLDOUT_ID, (Class<?>)Holdout.class, StratifiedCrossValidation.class))) {
                this.holdout = (Holdout)objectParameter.instantiateClass(parameterization);
            }
        }

        @Override
        protected ClassifierHoldoutEvaluationTask<O> makeInstance() {
            return new ClassifierHoldoutEvaluationTask<O>(this.databaseConnection, this.indexFactories, this.algorithm, this.holdout);
        }
    }
}

