package org.apache.ignite.ml.knn;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.knn.ann.KNNModelFormat;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:org/apache/ignite/ml/knn/NNClassificationModel.class */
/**
 * Common base for nearest-neighbour classification models.
 * <p>
 * Holds the shared configuration (the neighbour count {@code k}, the {@link DistanceMeasure} and the weighting flag)
 * and provides helpers for computing distances, selecting the closest training rows and tallying class votes.
 */
public abstract class NNClassificationModel implements IgniteModel<Vector, Double>, Exportable<KNNModelFormat>, DeployableObject {
    /** Number of nearest neighbours to take into account; defaults to 5. */
    protected int k = 5;

    /** Measure used to compute the distance between vectors; defaults to {@link EuclideanDistance}. */
    protected DistanceMeasure distanceMeasure = new EuclideanDistance();

    /** Whether votes are distance-weighted; see {@link #getClassVoteForVector(boolean, double)}. */
    protected boolean weighted;

    /**
     * Sets the number of nearest neighbours to consider.
     *
     * @param k Number of neighbours.
     * @return This model, for call chaining.
     */
    public NNClassificationModel withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Sets whether class votes are weighted by distance.
     *
     * @param weighted {@code true} to use distance-weighted votes, {@code false} for one vote per neighbour.
     * @return This model, for call chaining.
     */
    public NNClassificationModel withWeighted(boolean weighted) {
        this.weighted = weighted;
        return this;
    }

    /**
     * Sets the distance measure used to compare vectors.
     *
     * @param distanceMeasure Distance measure.
     * @return This model, for call chaining.
     */
    public NNClassificationModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        return this;
    }

    /**
     * Wraps a list of labeled vectors into a {@link LabeledVectorSet}, preserving list order.
     *
     * @param list Labeled vectors.
     * @return New labeled vector set backed by a copy of the list contents.
     */
    protected LabeledVectorSet<LabeledVector> buildLabeledDatasetOnListOfVectors(List<LabeledVector> list) {
        return new LabeledVectorSet<>(list.toArray(new LabeledVector[0]));
    }

    /**
     * Selects the rows used as nearest neighbours.
     * <p>
     * If the set holds no more than {@code k} rows, every row is returned in row order and {@code treeMap} is not
     * consulted. Otherwise row indices are taken from {@code treeMap} in ascending key (distance) order until
     * {@code k} rows have been collected; indices sharing one distance are taken in their set's iteration order.
     *
     * @param labeledVectorSet Training rows.
     * @param treeMap Distance to row-index mapping, as produced by {@link #getDistances(Vector, LabeledVectorSet)}.
     * @return The selected rows. The array is shorter than {@code k} when the set has more than {@code k} rows but
     *     {@code treeMap} indexes fewer than {@code k} of them (for example because {@code null} rows were skipped
     *     when computing distances).
     */
    @NotNull
    protected LabeledVector[] getKClosestVectors(LabeledVectorSet<LabeledVector> labeledVectorSet, TreeMap<Double, Set<Integer>> treeMap) {
        int rowCnt = labeledVectorSet.rowSize();

        if (rowCnt <= k) {
            LabeledVector[] all = new LabeledVector[rowCnt];
            for (int i = 0; i < rowCnt; i++) {
                all[i] = (LabeledVector) labeledVectorSet.getRow(i);
            }
            return all;
        }

        LabeledVector[] closest = new LabeledVector[k];
        int filled = 0;

        // TreeMap.values() iterates in ascending key order, i.e. nearest distances first.
        for (Set<Integer> idxs : treeMap.values()) {
            for (Integer idx : idxs) {
                if (filled == closest.length) {
                    return closest;
                }
                closest[filled++] = (LabeledVector) labeledVectorSet.getRow(idx);
            }
        }

        // The map may index fewer than k rows (null rows are skipped by getDistances); do not run past its end.
        return filled == closest.length ? closest : Arrays.copyOf(closest, filled);
    }

    /**
     * Computes the distance from {@code vector} to every non-{@code null} row of the set.
     *
     * @param vector Vector to measure from.
     * @param labeledVectorSet Training rows; {@code null} rows are skipped.
     * @return Map from distance to the set of row indices at that distance, sorted by distance.
     */
    @NotNull
    protected TreeMap<Double, Set<Integer>> getDistances(Vector vector, LabeledVectorSet<LabeledVector> labeledVectorSet) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
        for (int i = 0; i < labeledVectorSet.rowSize(); i++) {
            LabeledVector row = (LabeledVector) labeledVectorSet.getRow(i);
            if (row != null) {
                putDistanceIdxPair(distanceIdxPairs, i, distanceMeasure.compute(vector, row.features()));
            }
        }
        return distanceIdxPairs;
    }

    /**
     * Records that row {@code i} lies at distance {@code d}, grouping indices that share a distance.
     *
     * @param map Distance to row-index mapping to update.
     * @param i Row index.
     * @param d Distance of that row.
     */
    public void putDistanceIdxPair(Map<Double, Set<Integer>> map, int i, double d) {
        map.computeIfAbsent(d, key -> new HashSet<>()).add(i);
    }

    /**
     * Returns the class label that received the most votes.
     *
     * @param map Class label to accumulated vote mapping.
     * @return Label with the largest vote value.
     * @throws java.util.NoSuchElementException If {@code map} is empty.
     */
    public double getClassWithMaxVotes(Map<Double, Double> map) {
        return Collections.max(map.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    /**
     * Computes the vote a single neighbour contributes to its class.
     *
     * @param isWeighted Whether votes are distance-weighted.
     * @param distance Distance from the query vector to the neighbour.
     * @return {@code 1 / distance} when weighted (positive infinity for a zero distance), otherwise {@code 1}.
     */
    public double getClassVoteForVector(boolean isWeighted, double distance) {
        return isWeighted ? 1.0d / distance : 1.0d;
    }

    /**
     * Returns the distance measure used by this model.
     *
     * @return Distance measure.
     */
    public DistanceMeasure getDistanceMeasure() {
        return this.distanceMeasure;
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + k;
        res = res * 37 + distanceMeasure.hashCode();
        res = res * 37 + Boolean.hashCode(weighted);
        return res;
    }

    /** {@inheritDoc} */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        NNClassificationModel that = (NNClassificationModel) obj;
        return k == that.k && distanceMeasure.equals(that.distanceMeasure) && weighted == that.weighted;
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toString(false);
    }

    /** {@inheritDoc} */
    @Override // org.apache.ignite.ml.IgniteModel
    public String toString(boolean z) {
        return ModelTrace.builder("KNNClassificationModel", z)
            .addField("k", String.valueOf(this.k))
            .addField("measure", this.distanceMeasure.getClass().getSimpleName())
            .addField("weighted", String.valueOf(this.weighted))
            .toString();
    }

    /**
     * Copies {@code k}, the distance measure and the weighting flag from another model.
     * The distance measure reference is shared, not cloned.
     *
     * @param nNClassificationModel Model to copy parameters from.
     */
    protected void copyParametersFrom(NNClassificationModel nNClassificationModel) {
        this.k = nNClassificationModel.k;
        this.distanceMeasure = nNClassificationModel.distanceMeasure;
        this.weighted = nNClassificationModel.weighted;
    }

    /** {@inheritDoc} */
    @Override // org.apache.ignite.ml.Exportable
    public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P p);

    /**
     * Returns the objects this model depends on for deployment: the distance measure.
     *
     * @return Singleton list holding the distance measure.
     */
    @Override // org.apache.ignite.ml.environment.deploy.DeployableObject
    public List<Object> getDependencies() {
        return Collections.singletonList(this.distanceMeasure);
    }
}
