/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.evaluation.classification.holdout;

import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.AbstractHoldout;
import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Collections;

public class StratifiedCrossValidation
extends AbstractHoldout {
    protected int nfold;
    protected int fold;
    protected int[] assignment;
    protected int[] sizes;

    public StratifiedCrossValidation(int n) {
        this.nfold = n;
    }

    @Override
    public int numberOfPartitions() {
        return this.nfold;
    }

    @Override
    public void initialize(MultipleObjectsBundle multipleObjectsBundle) {
        int n;
        super.initialize(multipleObjectsBundle);
        this.fold = 0;
        TIntArrayList[] tIntArrayListArray = new TIntArrayList[this.labels.size()];
        for (n = 0; n < this.labels.size(); ++n) {
            tIntArrayListArray[n] = new TIntArrayList();
        }
        int n2 = multipleObjectsBundle.dataLength();
        for (n = 0; n < n2; ++n) {
            ClassLabel classLabel = (ClassLabel)multipleObjectsBundle.data(n, this.labelcol);
            if (classLabel == null) {
                throw new AbortException("Unlabeled instances currently not supported.");
            }
            int n3 = Collections.binarySearch(this.labels, classLabel);
            if (n3 < 0) {
                throw new AbortException("Label not in label list: " + classLabel);
            }
            tIntArrayListArray[n3].add(n);
        }
        this.sizes = new int[this.nfold];
        this.assignment = new int[multipleObjectsBundle.dataLength()];
        for (TIntArrayList tIntArrayList : tIntArrayListArray) {
            for (int i = 0; i < tIntArrayList.size(); ++i) {
                this.assignment[tIntArrayList.get((int)i)] = i % this.nfold;
            }
        }
    }

    @Override
    public TrainingAndTestSet nextPartitioning() {
        if (this.fold >= this.nfold) {
            return null;
        }
        int n = this.sizes[this.fold];
        int n2 = this.bundle.dataLength() - n;
        MultipleObjectsBundle multipleObjectsBundle = new MultipleObjectsBundle();
        MultipleObjectsBundle multipleObjectsBundle2 = new MultipleObjectsBundle();
        int n3 = this.bundle.metaLength();
        for (int i = 0; i < n3; ++i) {
            ArrayList<Object> arrayList = new ArrayList<Object>(n2);
            ArrayList arrayList2 = new ArrayList(n);
            for (int j = 0; j < this.bundle.dataLength(); ++j) {
                (this.assignment[j] != this.fold ? arrayList : arrayList2).add(this.bundle.data(j, i));
            }
            multipleObjectsBundle.appendColumn(this.bundle.meta(i), arrayList);
            multipleObjectsBundle2.appendColumn(this.bundle.meta(i), arrayList2);
        }
        ++this.fold;
        return new TrainingAndTestSet(multipleObjectsBundle, multipleObjectsBundle2, this.labels);
    }

    public static class Parameterizer
    extends AbstractParameterizer {
        public static final int N_DEFAULT = 10;
        public static final OptionID NFOLD_ID = new OptionID("nfold", "Number of folds for cross-validation");
        protected int nfold;

        @Override
        protected void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            IntParameter intParameter = (IntParameter)new IntParameter(NFOLD_ID, 10).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (parameterization.grab(intParameter)) {
                this.nfold = intParameter.intValue();
            }
        }

        @Override
        protected StratifiedCrossValidation makeInstance() {
            return new StratifiedCrossValidation(this.nfold);
        }
    }
}

