package org.datacleaner.components.machinelearning;

import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.result.Crosstab;
import org.datacleaner.result.CrosstabDimension;
import org.datacleaner.result.CrosstabNavigator;

/* loaded from: input_file:org/datacleaner/components/machinelearning/MLConfusionMatrixBuilder.class */
/**
 * Accumulates a confusion matrix for an {@link MLClassifier}.
 *
 * <p>The matrix is held as a two-dimensional {@link Crosstab} of {@link Integer} counts whose
 * dimensions are named {@code "Expected"} and {@code "Actual"}. The categories of both dimensions
 * are the string forms of the classifications reported by the classifier's metadata.
 */
public class MLConfusionMatrixBuilder {
    /** Classifier whose predictions are compared against the expected classifications. */
    private final MLClassifier classifier;

    /** Backing matrix; cell values are occurrence counts. */
    private final Crosstab<Integer> crosstab = new Crosstab<>(Integer.class, new String[]{"Expected", "Actual"});

    /** First dimension of the crosstab ("Expected"). */
    private final CrosstabDimension expectedDimension = this.crosstab.getDimension(0);

    /** Second dimension of the crosstab ("Actual"). */
    private final CrosstabDimension actualDimension = this.crosstab.getDimension(1);

    /**
     * Creates a builder for the given classifier and seeds every expected/actual combination
     * with a count of zero.
     *
     * @param mLClassifier the classifier to evaluate; its metadata supplies the classifications
     *                     that become the categories of both dimensions
     */
    public MLConfusionMatrixBuilder(MLClassifier mLClassifier) {
        this.classifier = mLClassifier;

        final List<String> labels = mLClassifier.getMetadata().getClassifications().stream()
                .map(this::getClassificationLabel)
                .collect(Collectors.toList());

        this.expectedDimension.addCategories(labels);
        this.actualDimension.addCategories(labels);

        // Pre-fill the full grid so that combinations which are never observed hold 0.
        for (final String expectedLabel : labels) {
            final CrosstabNavigator<Integer> row = this.crosstab.where(this.expectedDimension, expectedLabel);
            for (final String actualLabel : labels) {
                row.where(this.actualDimension, actualLabel).put(0);
            }
        }
    }

    /**
     * Classifies the given record and increments the cell addressed by the record's own
     * classification (expected) and the classifier's best classification (actual).
     *
     * @param mLClassificationRecord the record to classify; its classification is used as the
     *                               expected value
     * @throws NullPointerException if either the expected or the predicted classification is
     *                              {@code null}, since labels are obtained via {@code toString()}
     */
    public void append(MLClassificationRecord mLClassificationRecord) {
        // Position on the expected category first, then resolve the prediction.
        final CrosstabNavigator<Integer> expectedRow = this.crosstab.navigate()
                .where(this.expectedDimension, getClassificationLabel(mLClassificationRecord.getClassification()));

        final Object predicted = this.classifier.getMetadata()
                .getClassification(this.classifier.classify(mLClassificationRecord).getBestClassificationIndex());

        final CrosstabNavigator<Integer> cell =
                expectedRow.where(this.actualDimension, getClassificationLabel(predicted));

        // The constructor seeds known cells with 0; the null branch is kept as a safeguard for
        // any cell that was not seeded.
        final Integer count = cell.get();
        cell.put(count == null ? 1 : count + 1);
    }

    /**
     * Converts a classification value to the category label used in the crosstab.
     *
     * @param obj the classification value; must not be {@code null}
     * @return {@code obj.toString()}
     */
    private String getClassificationLabel(Object obj) {
        return obj.toString();
    }

    /**
     * Returns the confusion matrix accumulated so far.
     *
     * <p>The returned object is the builder's own crosstab instance, not a copy.
     *
     * @return the crosstab holding the counts
     */
    public Crosstab<Integer> build() {
        return this.crosstab;
    }
}
