package org.apache.ignite.ml.selection.scoring.evaluator.aggregator;

import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.context.EmptyContext;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Single-pass accumulator of the statistics needed for regression metrics: mean absolute error,
 * mean squared error, residual sum of squares and the spread of the true labels.
 * <p>
 * All statistics are additive, so aggregators built over different data partitions can be combined
 * with {@link #mergeWith(RegressionMetricStatsAggregator)}. The double accumulators start as
 * {@code NaN}, which the internal {@code sum} helper treats as "nothing accumulated yet".
 */
public class RegressionMetricStatsAggregator implements MetricStatsAggregator<Double, EmptyContext<Double>, RegressionMetricStatsAggregator> {
    /** Serial version uid. */
    private static final long serialVersionUID = -2459352313996869235L;

    /** Number of aggregated examples. */
    private long n;

    /** Sum of {@code |label - prediction|}; {@code NaN} while empty. */
    private double absoluteError;

    /** Residual sum of squares, i.e. sum of {@code (label - prediction)^2}; {@code NaN} while empty. */
    private double rss;

    /** Sum of labels; {@code NaN} while empty. */
    private double sumOfYs;

    /** Sum of squared labels; {@code NaN} while empty. */
    private double sumOfSquaredYs;

    /** Creates an empty aggregator. */
    public RegressionMetricStatsAggregator() {
        this(0L, Double.NaN, Double.NaN, Double.NaN, Double.NaN);
    }

    /**
     * Creates an aggregator from already accumulated statistics.
     *
     * @param n Number of aggregated examples.
     * @param absoluteError Sum of absolute errors ({@code NaN} means "empty").
     * @param rss Residual sum of squares ({@code NaN} means "empty").
     * @param sumOfYs Sum of labels ({@code NaN} means "empty").
     * @param sumOfSquaredYs Sum of squared labels ({@code NaN} means "empty").
     */
    public RegressionMetricStatsAggregator(long n, double absoluteError, double rss, double sumOfYs,
        double sumOfSquaredYs) {
        this.n = n;
        this.absoluteError = absoluteError;
        this.rss = rss;
        this.sumOfYs = sumOfYs;
        this.sumOfSquaredYs = sumOfSquaredYs;
    }

    /**
     * Adds one labeled example to the statistics.
     *
     * @param mdl Model whose predictions are evaluated.
     * @param vector Labeled example.
     * @throws NullPointerException If the label is {@code null} or the model returns {@code null}.
     */
    @Override public void aggregate(IgniteModel<Vector, Double> mdl, LabeledVector<Double> vector) {
        Double label = vector.label();
        // Check the value itself; wrapping "label != null" in a Boolean would never be null.
        A.notNull(label, "Test set mustn't contain null labels");

        Double prediction = mdl.predict(vector.features());
        A.notNull(prediction, "Model mustn't return null answers");

        // Count the example only once it is known to be valid, so rejected examples don't skew n.
        n++;

        double y = label;
        double error = y - prediction;

        // NOTE(review): a NaN prediction yields a NaN error, which sum() treats as "empty" and drops,
        // while n still counts the example — confirm this is the intended handling.
        absoluteError = sum(Math.abs(error), absoluteError);
        rss = sum(error * error, rss);
        sumOfYs = sum(y, sumOfYs);
        sumOfSquaredYs = sum(y * y, sumOfSquaredYs);
    }

    /**
     * Combines the statistics of this aggregator with another one into a new aggregator.
     * Neither input is modified.
     *
     * @param other Aggregator to merge with.
     * @return New aggregator holding the combined statistics.
     */
    @Override public RegressionMetricStatsAggregator mergeWith(RegressionMetricStatsAggregator other) {
        return new RegressionMetricStatsAggregator(
            n + other.n,
            sum(absoluteError, other.absoluteError),
            sum(rss, other.rss),
            sum(sumOfYs, other.sumOfYs),
            sum(sumOfSquaredYs, other.sumOfSquaredYs)
        );
    }

    /** {@inheritDoc} */
    @Override public EmptyContext<Double> createInitializedContext() {
        return new EmptyContext<>();
    }

    /** This aggregator needs no context, so this is a no-op. */
    @Override public void initByContext(EmptyContext<Double> ctx) {
        // No-op: regression statistics don't depend on any precomputed context.
    }

    /**
     * Legacy alias produced by decompilation.
     *
     * @return A fresh empty context.
     * @deprecated Use {@link #createInitializedContext()}.
     */
    @Deprecated
    public EmptyContext<Double> createInitializedContext2() {
        return createInitializedContext();
    }

    /**
     * Legacy alias produced by decompilation; no-op like {@code initByContext}.
     *
     * @param emptyContext Ignored.
     * @deprecated Use {@code initByContext(EmptyContext)}.
     */
    @Deprecated
    public void initByContext2(EmptyContext<?> emptyContext) {
        // No-op, same as initByContext.
    }

    /**
     * Returns the mean absolute error.
     *
     * @return Mean absolute error, or {@code NaN} if nothing was aggregated.
     */
    public double getMAE() {
        // A NaN (empty) accumulator propagates through the division; max(n, 1) avoids dividing by zero.
        return absoluteError / Math.max(n, 1L);
    }

    /**
     * Returns the mean squared error.
     *
     * @return Mean squared error, or {@code NaN} if nothing was aggregated.
     */
    public double getMSE() {
        return rss / Math.max(n, 1L);
    }

    /**
     * Returns the sum of squared deviations of the labels from their mean.
     * It is computed as {@code ysVariance() * max(n, 1)}.
     *
     * @return Label sum of squares around the mean, or {@code NaN} if nothing was aggregated.
     */
    public double ysRss() {
        return ysVariance() * Math.max(n, 1L);
    }

    /**
     * Returns the variance of the labels, computed as {@code E[y^2] - E[y]^2} (divides by n).
     *
     * @return Label variance, or {@code NaN} if nothing was aggregated.
     */
    public double ysVariance() {
        if (Double.isNaN(sumOfSquaredYs))
            return Double.NaN;

        double cnt = Math.max(n, 1L);
        double meanY = sumOfYs / cnt;

        return sumOfSquaredYs / cnt - meanY * meanY;
    }

    /**
     * Returns the residual sum of squares.
     *
     * @return Residual sum of squares, or {@code NaN} if nothing was aggregated.
     */
    public double getRss() {
        return rss;
    }

    /**
     * Adds two accumulator values, treating {@code NaN} as "empty".
     *
     * @param a First value.
     * @param b Second value.
     * @return {@code b} if {@code a} is NaN, {@code a} if {@code b} is NaN, otherwise {@code a + b}.
     */
    private static double sum(double a, double b) {
        if (Double.isNaN(a))
            return b;

        if (Double.isNaN(b))
            return a;

        return a + b;
    }
}
