package elki.classification;

import elki.Algorithm;
import elki.data.ClassLabel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.Database;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.query.QueryBuilder;
import elki.database.query.knn.KNNSearcher;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.utilities.Priority;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.Collections;

/**
 * Lazy k-nearest-neighbor classifier.
 * <p>
 * No model is trained: {@link #buildClassifier} only prepares a kNN search over
 * the training data, and each query instance is assigned the most frequent
 * class label among its {@code k} nearest neighbors.
 *
 * @param <O> type of the objects being classified
 */
@Title("kNN-classifier")
@Description("Lazy classifier classifies a given instance to the majority class of the k-nearest neighbors.")
@Priority(100)
public class KNNClassifier<O> implements Classifier<O> {
    /** Number of neighbors consulted per query. */
    protected int k;

    /** kNN searcher over the training relation; initialized by {@link #buildClassifier}. */
    protected KNNSearcher<O> knnq;

    /** Class labels of the training objects; initialized by {@link #buildClassifier}. */
    protected Relation<? extends ClassLabel> labelrep;

    /** Distance function used for the neighbor search. */
    protected Distance<? super O> distance;

    /**
     * Parameterization helper for {@link KNNClassifier}.
     *
     * @param <O> type of the objects being classified
     */
    public static class Par<O> implements Parameterizer {
        /** Option ID for the number of neighbors to take into account. */
        public static final OptionID K_ID = new OptionID("knnclassifier.k", "The number of neighbors to take into account for classification.");

        /** Configured distance function (defaults to Euclidean distance). */
        protected Distance<? super O> distanceFunction;

        /** Configured number of neighbors (defaults to 1, must be at least 1). */
        protected int k;

        public void configure(Parameterization config) {
            // Typed parameter so the consumer receives a Distance rather than an Object.
            new ObjectParameter<Distance<? super O>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class) //
                .grab(config, x -> distanceFunction = x);
            new IntParameter(K_ID, 1) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                .grab(config, x -> k = x);
        }

        /**
         * Build the classifier from the configured parameters.
         *
         * @return a new classifier instance
         */
        public KNNClassifier<O> make() {
            return new KNNClassifier<>(distanceFunction, k);
        }

        /**
         * Legacy alias kept for binary/source compatibility.
         *
         * @return a new classifier instance
         * @deprecated decompiler-generated name; use {@link #make()} instead.
         */
        @Deprecated
        public KNNClassifier<O> m6make() {
            return make();
        }
    }

    /**
     * Constructor.
     *
     * @param distance distance function used for the neighbor search
     * @param k        number of neighbors to consult per query
     */
    public KNNClassifier(Distance<? super O> distance, int k) {
        this.distance = distance;
        this.k = k;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    /**
     * Prepare the kNN search over the training data and remember the labels.
     *
     * @param database database holding the training objects
     * @param relation class labels of the training objects
     */
    @Override // elki.classification.Classifier
    public void buildClassifier(Database database, Relation<? extends ClassLabel> relation) {
        // Pick the relation whose type matches what the distance function accepts.
        Relation<O> data = database.getRelation(distance.getInputTypeRestriction());
        this.knnq = new QueryBuilder<>(data, distance).kNNByObject(k);
        this.labelrep = relation;
    }

    /**
     * Assign the majority label among the {@code k} nearest neighbors.
     * Ties are resolved by hash-map iteration order (first maximum wins).
     *
     * @param instance object to classify
     * @return the most frequent neighbor label, or {@code null} if no neighbors were found
     */
    @Override // elki.classification.Classifier
    public ClassLabel classify(O instance) {
        Object2IntOpenHashMap<ClassLabel> votes = new Object2IntOpenHashMap<>();
        for (DoubleDBIDListIter neighbor = knnq.getKNN(instance, k).iter(); neighbor.valid(); neighbor.advance()) {
            ClassLabel label = labelrep.get(neighbor);
            votes.addTo(label, 1);
        }
        int bestCount = Integer.MIN_VALUE;
        ClassLabel bestLabel = null;
        // fastIterator may reuse the entry object, so values are read immediately.
        for (ObjectIterator<Object2IntMap.Entry<ClassLabel>> it = votes.object2IntEntrySet().fastIterator(); it.hasNext();) {
            Object2IntMap.Entry<ClassLabel> entry = it.next();
            if (entry.getIntValue() > bestCount) {
                bestCount = entry.getIntValue();
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /**
     * Estimate class probabilities as neighbor label frequencies.
     * <p>
     * The denominator is the number of neighbors found. Neighbors whose label
     * is not contained in {@code labels} therefore reduce all probabilities,
     * rather than being renormalized away.
     *
     * @param instance object to classify
     * @param labels   candidate labels; must be sorted in natural order, as
     *                 required by {@link Collections#binarySearch}
     * @return probability per entry of {@code labels}. The array is all zeros
     *         if no neighbors were found.
     */
    public double[] classProbabilities(O instance, ArrayList<ClassLabel> labels) {
        int[] counts = new int[labels.size()];
        int neighbors = 0;
        for (DoubleDBIDListIter neighbor = knnq.getKNN(instance, k).iter(); neighbor.valid(); neighbor.advance()) {
            neighbors++;
            ClassLabel label = labelrep.get(neighbor);
            int index = Collections.binarySearch(labels, label);
            if (index >= 0) {
                counts[index]++;
            }
        }
        double[] distribution = new double[counts.length];
        // Guard against an empty neighbor list (would otherwise yield NaN).
        if (neighbors > 0) {
            for (int i = 0; i < counts.length; i++) {
                // Floating-point division: integer division would truncate to 0.
                distribution[i] = counts[i] / (double) neighbors;
            }
        }
        return distribution;
    }

    @Override // elki.classification.Classifier
    public String model() {
        return "lazy learner - provides no model";
    }

    /**
     * Get the distance function used for the neighbor search.
     *
     * @return distance function
     */
    public Distance<? super O> getDistance() {
        return this.distance;
    }
}
