/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.statistics;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.LabelList;
import de.lmu.ifi.dbs.elki.data.type.AlternativeTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDList;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListMIter;
import de.lmu.ifi.dbs.elki.database.ids.HashSetModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDoubleDBIDList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.evaluation.scores.AveragePrecisionEvaluation;
import de.lmu.ifi.dbs.elki.evaluation.scores.ROCEvaluation;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic;
import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.result.textwriter.TextWriteable;
import de.lmu.ifi.dbs.elki.result.textwriter.TextWriterStream;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TObjectIntHashMap;

public class EvaluateRetrievalPerformance<O>
extends AbstractDistanceBasedAlgorithm<O, RetrievalPerformanceResult> {
    private static final Logging LOG = Logging.getLogger(EvaluateRetrievalPerformance.class);
    protected double sampling = 1.0;
    protected RandomFactory random = null;
    protected boolean includeSelf;
    private final String PREFIX = this.getClass().getName();
    protected int maxk = 100;

    public EvaluateRetrievalPerformance(DistanceFunction<? super O> distanceFunction, double d, RandomFactory randomFactory, boolean bl, int n) {
        super(distanceFunction);
        this.sampling = d;
        this.random = randomFactory;
        this.includeSelf = bl;
        this.maxk = n;
    }

    public RetrievalPerformanceResult run(Database database, Relation<O> relation, Relation<?> relation2) {
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, this.getDistanceFunction(), new Object[0]);
        DBIDs dBIDs = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        HashSetModifiableDBIDs hashSetModifiableDBIDs = DBIDUtil.newHashSet();
        ModifiableDoubleDBIDList modifiableDoubleDBIDList = DBIDUtil.newDistanceDBIDList(relation.size());
        TObjectIntHashMap<Object> tObjectIntHashMap = new TObjectIntHashMap<Object>();
        double d = 0.0;
        double d2 = 0.0;
        double[] dArray = new double[this.maxk];
        int n = 0;
        FiniteProgress finiteProgress = LOG.isVerbose() ? new FiniteProgress("Processing query objects", dBIDs.size(), LOG) : null;
        DBIDIter dBIDIter = dBIDs.iter();
        while (dBIDIter.valid()) {
            Object obj = relation2.get(dBIDIter);
            this.findMatches(hashSetModifiableDBIDs, relation2, obj);
            if (hashSetModifiableDBIDs.size() > 0) {
                this.computeDistances(modifiableDoubleDBIDList, dBIDIter, distanceQuery, relation);
                if (modifiableDoubleDBIDList.size() != relation.size() - (this.includeSelf ? 0 : 1)) {
                    LOG.warning("Neighbor list does not have the desired size: " + modifiableDoubleDBIDList.size());
                }
                d += AveragePrecisionEvaluation.STATIC.evaluate((DBIDs)hashSetModifiableDBIDs, (DoubleDBIDList)modifiableDoubleDBIDList);
                d2 += ROCEvaluation.STATIC.evaluate((DBIDs)hashSetModifiableDBIDs, (DoubleDBIDList)modifiableDoubleDBIDList);
                KNNEvaluator.STATIC.evaluateKNN(dArray, modifiableDoubleDBIDList, relation2, tObjectIntHashMap, obj);
                ++n;
            }
            LOG.incrementProcessed(finiteProgress);
            dBIDIter.advance();
        }
        LOG.ensureCompleted(finiteProgress);
        if (n < 1) {
            throw new AbortException("No object matched - are labels parsed correctly?");
        }
        if (!(d >= 0.0) || !(d2 >= 0.0)) {
            throw new AbortException("NaN in MAP/ROC.");
        }
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".map", d /= (double)n));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".rocauc", d2 /= (double)n));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".samples", n));
        for (int i = 0; i < this.maxk; ++i) {
            dArray[i] = dArray[i] / (double)n;
            LOG.statistics(new DoubleStatistic(this.PREFIX + ".knn-" + (i + 1), dArray[i]));
        }
        return new RetrievalPerformanceResult(n, d, d2, dArray);
    }

    protected static boolean match(Object object, Object object2) {
        if (object == null) {
            return false;
        }
        if (object == object2) {
            return true;
        }
        if (object instanceof LabelList && object2 instanceof LabelList) {
            LabelList labelList = (LabelList)object;
            LabelList labelList2 = (LabelList)object2;
            int n = labelList.size();
            int n2 = labelList2.size();
            if (n == 0 || n2 == 0) {
                return false;
            }
            for (int i = 0; i < n; ++i) {
                String string = labelList.get(i);
                if (string == null) continue;
                for (int j = 0; j < n2; ++j) {
                    if (!string.equals(labelList2.get(j))) continue;
                    return true;
                }
            }
        }
        return object.equals(object2);
    }

    private void findMatches(ModifiableDBIDs modifiableDBIDs, Relation<?> relation, Object object) {
        modifiableDBIDs.clear();
        DBIDIter dBIDIter = relation.iterDBIDs();
        while (dBIDIter.valid()) {
            if (EvaluateRetrievalPerformance.match(object, relation.get(dBIDIter))) {
                modifiableDBIDs.add(dBIDIter);
            }
            dBIDIter.advance();
        }
    }

    private void computeDistances(ModifiableDoubleDBIDList modifiableDoubleDBIDList, DBIDIter dBIDIter, DistanceQuery<O> distanceQuery, Relation<O> relation) {
        modifiableDoubleDBIDList.clear();
        O o = relation.get(dBIDIter);
        DBIDIter dBIDIter2 = relation.iterDBIDs();
        while (dBIDIter2.valid()) {
            if (this.includeSelf || !DBIDUtil.equal(dBIDIter2, dBIDIter)) {
                double d = distanceQuery.distance((DBIDIter)o, dBIDIter2);
                if (d != d) {
                    d = Double.POSITIVE_INFINITY;
                }
                modifiableDoubleDBIDList.add(d, dBIDIter2);
            }
            dBIDIter2.advance();
        }
        modifiableDoubleDBIDList.sort();
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        AlternativeTypeInformation alternativeTypeInformation = new AlternativeTypeInformation(TypeUtil.CLASSLABEL, TypeUtil.LABELLIST);
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction(), alternativeTypeInformation);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        public static final OptionID SAMPLING_ID = new OptionID("map.sampling", "Relative amount of object to sample.");
        public static final OptionID SEED_ID = new OptionID("map.sampling-seed", "Random seed for deterministic sampling.");
        public static final OptionID INCLUDESELF_ID = new OptionID("map.includeself", "Include the query object in the evaluation.");
        public static final OptionID MAXK_ID = new OptionID("map.maxk", "Maximum value of k for kNN evaluation.");
        protected double sampling = 1.0;
        protected RandomFactory seed = null;
        protected boolean includeSelf;
        protected int maxk = 0;

        @Override
        protected void makeOptions(Parameterization parameterization) {
            IntParameter intParameter;
            Flag flag;
            super.makeOptions(parameterization);
            DoubleParameter doubleParameter = new DoubleParameter(SAMPLING_ID);
            doubleParameter.addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            doubleParameter.addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
            doubleParameter.setOptional(true);
            if (parameterization.grab(doubleParameter)) {
                this.sampling = (Double)doubleParameter.getValue();
            }
            RandomParameter randomParameter = new RandomParameter(SEED_ID);
            randomParameter.setOptional(true);
            if (parameterization.grab(randomParameter)) {
                this.seed = (RandomFactory)randomParameter.getValue();
            }
            if (parameterization.grab(flag = new Flag(INCLUDESELF_ID))) {
                this.includeSelf = flag.isTrue();
            }
            if (parameterization.grab(intParameter = (IntParameter)new IntParameter(MAXK_ID).setOptional(true))) {
                this.maxk = intParameter.intValue();
            }
        }

        @Override
        protected EvaluateRetrievalPerformance<O> makeInstance() {
            return new EvaluateRetrievalPerformance(this.distanceFunction, this.sampling, this.seed, this.includeSelf, this.maxk);
        }
    }

    public static class RetrievalPerformanceResult
    implements Result,
    TextWriteable {
        private int samplesize;
        private double map;
        private double rocauc;
        private double[] knnperf;

        public RetrievalPerformanceResult(int n, double d, double d2, double[] dArray) {
            this.map = d;
            this.rocauc = d2;
            this.samplesize = n;
            this.knnperf = dArray;
        }

        public double getROCAUC() {
            return this.rocauc;
        }

        public double getMAP() {
            return this.map;
        }

        @Override
        public String getLongName() {
            return "Distance function retrieval evaluation.";
        }

        @Override
        public String getShortName() {
            return "distance-retrieval-evaluation";
        }

        @Override
        public void writeToText(TextWriterStream textWriterStream, String string) {
            textWriterStream.inlinePrintNoQuotes("MAP");
            textWriterStream.inlinePrint(this.map);
            textWriterStream.flush();
            textWriterStream.inlinePrintNoQuotes("ROCAUC");
            textWriterStream.inlinePrint(this.rocauc);
            textWriterStream.flush();
            textWriterStream.inlinePrintNoQuotes("Samplesize");
            textWriterStream.inlinePrint(this.samplesize);
            textWriterStream.flush();
            for (int i = 0; i < this.knnperf.length; ++i) {
                textWriterStream.inlinePrintNoQuotes("knn-" + (i + 1));
                textWriterStream.inlinePrint(this.knnperf[i]);
                textWriterStream.flush();
            }
        }
    }

    public static class KNNEvaluator {
        public static final KNNEvaluator STATIC = new KNNEvaluator();

        public void evaluateKNN(double[] dArray, ModifiableDoubleDBIDList modifiableDoubleDBIDList, Relation<?> relation, TObjectIntHashMap<Object> tObjectIntHashMap, Object object) {
            int n = dArray.length;
            int n2 = 1;
            int n3 = 0;
            int n4 = 0;
            tObjectIntHashMap.clear();
            DoubleDBIDListMIter doubleDBIDListMIter = modifiableDoubleDBIDList.iter();
            while (doubleDBIDListMIter.valid() && n3 < n) {
                double d = doubleDBIDListMIter.doubleValue();
                Object obj = relation.get(doubleDBIDListMIter);
                n4 = Math.max(n4, this.countkNN(tObjectIntHashMap, obj));
                doubleDBIDListMIter.advance();
                ++n2;
                if (doubleDBIDListMIter.valid() && !(doubleDBIDListMIter.doubleValue() > d)) continue;
                int n5 = 0;
                int n6 = 0;
                TObjectIntIterator<Object> tObjectIntIterator = tObjectIntHashMap.iterator();
                block1: while (tObjectIntIterator.hasNext()) {
                    tObjectIntIterator.advance();
                    if (tObjectIntIterator.value() < n4) continue;
                    ++n6;
                    if (tObjectIntIterator.key() == null) continue;
                    if (tObjectIntIterator.key().equals(object)) {
                        ++n5;
                        continue;
                    }
                    if (!(object instanceof LabelList)) continue;
                    LabelList labelList = (LabelList)object;
                    int n7 = labelList.size();
                    for (int i = 0; i < n7; ++i) {
                        if (!tObjectIntIterator.key().equals(labelList.get(i))) continue;
                        ++n5;
                        continue block1;
                    }
                }
                while (n3 < n2 && n3 < n) {
                    int n8 = n3++;
                    dArray[n8] = dArray[n8] + (double)n5 / (double)n6;
                }
            }
        }

        public int countkNN(TObjectIntHashMap<Object> tObjectIntHashMap, Object object) {
            if (object instanceof LabelList) {
                LabelList labelList = (LabelList)object;
                int n = 0;
                int n2 = labelList.size();
                for (int i = 0; i < n2; ++i) {
                    n = Math.max(n, tObjectIntHashMap.adjustOrPutValue(labelList.get(i), 1, 1));
                }
                return n;
            }
            return tObjectIntHashMap.adjustOrPutValue(object, 1, 1);
        }
    }
}

