package org.datacleaner.components.machinelearning.impl;

import java.util.List;
import org.datacleaner.components.machinelearning.api.MLClassificationMetadata;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassificationTrainer;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifier;
import org.datacleaner.components.machinelearning.api.MLTrainerCallback;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import smile.classification.NeuralNetwork;

/* loaded from: input_file:org/datacleaner/components/machinelearning/impl/NeuralNetTrainer.class */
/**
 * {@link MLClassificationTrainer} that trains a Smile {@link NeuralNetwork} and wraps the result in a
 * {@link SmileClassifier}.
 *
 * <p>The network has one input layer sized by the feature count of the supplied feature modifiers, the
 * configured hidden layers, and one output layer sized by the number of distinct classifications found
 * in the training data.
 */
public class NeuralNetTrainer implements MLClassificationTrainer {
    private final MLTrainingOptions trainingOptions;
    private final int epochs;
    private final NeuralNetwork.ErrorFunction errorFunction;
    private final NeuralNetwork.ActivationFunction activationFunction;
    private final int[] hiddenNeuronPerLayer;
    private final double learningRate;
    private final double momentum;

    /**
     * Creates a trainer with a fixed network configuration.
     *
     * @param trainingOptions options providing the classification type and column names stored in the
     *        resulting classifier's metadata
     * @param epochs number of times the complete training set is fed to the network
     * @param errorFunction error function handed to the {@link NeuralNetwork} constructor
     * @param activationFunction activation function handed to the {@link NeuralNetwork} constructor
     * @param hiddenNeuronPerLayer number of neurons for each hidden layer, in order; {@code null} is
     *        treated as "no hidden layers". The array is copied, so later changes made by the caller
     *        do not affect this trainer.
     * @param learningRate learning rate applied to the network before training
     * @param momentum momentum applied to the network before training
     */
    public NeuralNetTrainer(MLTrainingOptions trainingOptions, int epochs, NeuralNetwork.ErrorFunction errorFunction,
            NeuralNetwork.ActivationFunction activationFunction, int[] hiddenNeuronPerLayer, double learningRate,
            double momentum) {
        this.trainingOptions = trainingOptions;
        this.epochs = epochs;
        this.errorFunction = errorFunction;
        this.activationFunction = activationFunction;
        // Defensive copy: the layer layout must not change if the caller mutates its array afterwards.
        this.hiddenNeuronPerLayer = hiddenNeuronPerLayer == null ? new int[0] : hiddenNeuronPerLayer.clone();
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /**
     * Trains a neural network on the given records.
     *
     * <p>The training data is passed to several {@link MLFeatureUtils} helpers in turn, so it presumably
     * has to support being iterated more than once (to be confirmed against those helpers).
     *
     * @param data the training records
     * @param featureModifiers the feature modifiers used to build the feature vectors; also stored in the
     *        returned classifier's metadata
     * @param callback notified after every completed epoch with the 1-based epoch number and the total
     *        number of epochs
     * @return a {@link SmileClassifier} wrapping the trained network and its metadata
     */
    @Override
    public MLClassifier train(Iterable<MLClassificationRecord> data, List<MLFeatureModifier> featureModifiers,
            MLTrainerCallback callback) {
        final List<Object> classifications = MLFeatureUtils.toClassifications(data);
        final double[][] features = MLFeatureUtils.toFeatureVector(data, featureModifiers);
        final int[] labels = MLFeatureUtils.toClassificationVector(data);

        final int[] layerSizes =
                buildLayerSizes(MLFeatureUtils.getFeatureCount(featureModifiers), classifications.size());

        final NeuralNetwork network = new NeuralNetwork(this.errorFunction, this.activationFunction, layerSizes);
        network.setLearningRate(this.learningRate);
        network.setMomentum(this.momentum);

        for (int epoch = 1; epoch <= this.epochs; epoch++) {
            // Each call feeds the complete training set to the network once.
            network.learn(features, labels);
            callback.epochDone(epoch, this.epochs);
        }

        final MLClassificationMetadata metadata =
                new MLClassificationMetadata(this.trainingOptions.getClassificationType(), classifications,
                        this.trainingOptions.getColumnNames(), featureModifiers);
        return new SmileClassifier(network, metadata);
    }

    /**
     * Assembles the per-layer unit counts: input layer, configured hidden layers, output layer.
     */
    private int[] buildLayerSizes(int inputCount, int outputCount) {
        final int hiddenLayers = this.hiddenNeuronPerLayer.length;
        final int[] layerSizes = new int[hiddenLayers + 2];
        layerSizes[0] = inputCount;
        System.arraycopy(this.hiddenNeuronPerLayer, 0, layerSizes, 1, hiddenLayers);
        layerSizes[layerSizes.length - 1] = outputCount;
        return layerSizes;
    }
}
