/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.statistics;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.LabelList;
import de.lmu.ifi.dbs.elki.data.type.AlternativeTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRef;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.math.MeanVarianceMinMax;
import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.result.CollectionResult;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import java.util.ArrayList;
import java.util.Collection;

public class AveragePrecisionAtK<O>
extends AbstractDistanceBasedAlgorithm<O, CollectionResult<DoubleVector>> {
    private static final Logging LOG = Logging.getLogger(AveragePrecisionAtK.class);
    private int k;
    private double sampling = 1.0;
    private RandomFactory random = null;
    private boolean includeSelf;

    public AveragePrecisionAtK(DistanceFunction<? super O> distanceFunction, int n, double d, RandomFactory randomFactory, boolean bl) {
        super(distanceFunction);
        this.k = n;
        this.sampling = d;
        this.random = randomFactory;
        this.includeSelf = bl;
    }

    public CollectionResult<DoubleVector> run(Database database, Relation<O> relation, Relation<?> relation2) {
        Object object;
        Object object2;
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, this.getDistanceFunction(), new Object[0]);
        int n = this.k + (this.includeSelf ? 0 : 1);
        KNNQuery<O> kNNQuery = database.getKNNQuery(distanceQuery, n);
        MeanVarianceMinMax[] meanVarianceMinMaxArray = MeanVarianceMinMax.newArray(this.k);
        DBIDs dBIDs = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        FiniteProgress finiteProgress = LOG.isVerbose() ? new FiniteProgress("Computing nearest neighbors", dBIDs.size(), LOG) : null;
        Object object3 = dBIDs.iter();
        while (object3.valid()) {
            KNNList kNNList = kNNQuery.getKNNForDBID((DBIDRef)object3, n);
            object2 = relation2.get((DBIDRef)object3);
            int n2 = 0;
            int n3 = 0;
            object = kNNList.iter();
            while (n3 < this.k && object.valid()) {
                if (this.includeSelf || !DBIDUtil.equal((DBIDRef)object3, (DBIDRef)object)) {
                    double d = (double)(n2 += AveragePrecisionAtK.match(object2, relation2.get((DBIDRef)object)) ? 1 : 0) / (double)(n3 + 1);
                    meanVarianceMinMaxArray[n3].put(d);
                    ++n3;
                }
                object.advance();
            }
            LOG.incrementProcessed(finiteProgress);
            object3.advance();
        }
        LOG.ensureCompleted(finiteProgress);
        object3 = new ArrayList(this.k);
        for (int i = 0; i < this.k; ++i) {
            object2 = meanVarianceMinMaxArray[i];
            double d = ((MeanVariance)object2).getCount() > 1.0 ? ((MeanVariance)object2).getSampleStddev() : 0.0;
            object = new DoubleVector(new double[]{i + 1, ((MeanVariance)object2).getMean(), d, ((MeanVarianceMinMax)object2).getMin(), ((MeanVarianceMinMax)object2).getMax(), ((MeanVariance)object2).getCount()});
            object3.add(object);
        }
        return new CollectionResult<DoubleVector>("Average Precision", "average-precision", (Collection<DoubleVector>)object3);
    }

    protected static boolean match(Object object, Object object2) {
        if (object == null) {
            return false;
        }
        if (object == object2) {
            return true;
        }
        if (object instanceof LabelList && object2 instanceof LabelList) {
            LabelList labelList = (LabelList)object;
            LabelList labelList2 = (LabelList)object2;
            int n = labelList.size();
            int n2 = labelList2.size();
            if (n == 0 || n2 == 0) {
                return false;
            }
            for (int i = 0; i < n; ++i) {
                String string = labelList.get(i);
                if (string == null) continue;
                for (int j = 0; j < n2; ++j) {
                    if (!string.equals(labelList2.get(j))) continue;
                    return true;
                }
            }
        }
        return object.equals(object2);
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        AlternativeTypeInformation alternativeTypeInformation = new AlternativeTypeInformation(TypeUtil.CLASSLABEL, TypeUtil.LABELLIST);
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction(), alternativeTypeInformation);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        private static final OptionID K_ID = new OptionID("avep.k", "K to compute the average precision at.");
        public static final OptionID SAMPLING_ID = new OptionID("avep.sampling", "Relative amount of object to sample.");
        public static final OptionID SEED_ID = new OptionID("avep.sampling-seed", "Random seed for deterministic sampling.");
        public static final OptionID INCLUDESELF_ID = new OptionID("avep.includeself", "Include the query object in the evaluation.");
        protected int k = 20;
        protected double sampling = 1.0;
        protected RandomFactory seed = null;
        protected boolean includeSelf;

        @Override
        protected void makeOptions(Parameterization parameterization) {
            Flag flag;
            RandomParameter randomParameter;
            super.makeOptions(parameterization);
            IntParameter intParameter = new IntParameter(K_ID);
            intParameter.addConstraint(CommonConstraints.GREATER_THAN_ONE_INT);
            if (parameterization.grab(intParameter)) {
                this.k = (Integer)intParameter.getValue();
            }
            DoubleParameter doubleParameter = new DoubleParameter(SAMPLING_ID);
            doubleParameter.addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            doubleParameter.addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
            doubleParameter.setOptional(true);
            if (parameterization.grab(doubleParameter)) {
                this.sampling = (Double)doubleParameter.getValue();
            }
            if (parameterization.grab(randomParameter = new RandomParameter(SEED_ID))) {
                this.seed = (RandomFactory)randomParameter.getValue();
            }
            if (parameterization.grab(flag = new Flag(INCLUDESELF_ID))) {
                this.includeSelf = flag.isTrue();
            }
        }

        @Override
        protected AveragePrecisionAtK<O> makeInstance() {
            return new AveragePrecisionAtK(this.distanceFunction, this.k, this.sampling, this.seed, this.includeSelf);
        }
    }
}

