package org.datacleaner.components.machinelearning;

import javax.inject.Named;
import org.datacleaner.api.Configured;
import org.datacleaner.api.Description;
import org.datacleaner.api.NumberProperty;
import org.datacleaner.components.machinelearning.api.MLClassificationTrainer;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import org.datacleaner.components.machinelearning.impl.NeuralNetTrainer;
import smile.classification.NeuralNetwork;

/**
 * Analyzer component that trains a classifier of the 'Neural Net' type.
 *
 * <p>Each {@code @Configured} field below is a user-adjustable setting; the
 * initializers are the defaults used when nothing else is configured. All
 * settings are handed, unmodified, to a {@link NeuralNetTrainer} in
 * {@link #createTrainer(MLTrainingOptions)}.
 */
@Named("Train Neural Net classifier")
@Description("Train a classifier of the 'Neural Net' type.")
public class NeuralNetTrainingAnalyzer extends MLClassificationTrainingAnalyzer {

    /** Epoch count passed to the trainer; annotated as non-negative and non-zero. */
    @Configured
    @NumberProperty(negative = false, zero = false)
    int epochs = 10;

    /** Error function passed to the trainer; defaults to cross entropy. */
    @Configured
    NeuralNetwork.ErrorFunction errorFunction = NeuralNetwork.ErrorFunction.CROSS_ENTROPY;

    /** Activation function passed to the trainer; defaults to softmax. */
    @Configured
    NeuralNetwork.ActivationFunction activationFunction = NeuralNetwork.ActivationFunction.SOFTMAX;

    /**
     * Unit counts, configured under the name "Hidden layers"; each value is
     * annotated as non-negative and non-zero. Presumably one entry per hidden
     * layer - confirm against {@link NeuralNetTrainer}.
     */
    @Configured("Hidden layers")
    @NumberProperty(negative = false, zero = false)
    int[] numUnitsPerLayer = new int[] { 64 };

    /** Learning rate passed to the trainer. */
    @Configured
    double learningRate = 0.1;

    /** Momentum passed to the trainer. */
    @Configured
    double momentum = 0.1;

    /**
     * Builds the trainer for this analyzer from the current configuration.
     *
     * @param trainingOptions the training options supplied by the base analyzer
     * @return a new {@link NeuralNetTrainer} carrying the given options and
     *         every configured setting of this component
     */
    @Override
    protected MLClassificationTrainer createTrainer(MLTrainingOptions trainingOptions) {
        return new NeuralNetTrainer(
                trainingOptions,
                epochs,
                errorFunction,
                activationFunction,
                numUnitsPerLayer,
                learningRate,
                momentum);
    }
}
