package org.datacleaner.components.machinelearning;

import com.google.common.io.Files;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.commons.lang.SerializationUtils;
import org.apache.metamodel.util.CollectionUtils;
import org.apache.metamodel.util.HasNameMapper;
import org.datacleaner.api.Categorized;
import org.datacleaner.api.Configured;
import org.datacleaner.api.Description;
import org.datacleaner.api.Initialize;
import org.datacleaner.api.InputColumn;
import org.datacleaner.api.InputRow;
import org.datacleaner.api.NumberProperty;
import org.datacleaner.components.machinelearning.api.MLFeatureModifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierBuilder;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierBuilderFactory;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierType;
import org.datacleaner.components.machinelearning.api.MLRegressionRecord;
import org.datacleaner.components.machinelearning.api.MLRegressor;
import org.datacleaner.components.machinelearning.api.MLRegressorTrainer;
import org.datacleaner.components.machinelearning.api.MLTrainerCallback;
import org.datacleaner.components.machinelearning.api.MLTrainingConstraints;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import org.datacleaner.components.machinelearning.impl.MLFeatureModifierBuilderFactoryImpl;
import org.datacleaner.components.machinelearning.impl.MLFeatureUtils;
import org.datacleaner.components.machinelearning.impl.MLRegressionRecordImpl;
import org.datacleaner.util.Percentage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for analyzers that train a regression model from the rows they consume.
 *
 * <p>Every consumed row is converted into an {@link MLRegressionRecord}; its values are fed to
 * the feature modifier builders and the record is placed in either the training set or the
 * cross-validation set. When the result is requested, the feature modifiers are built, the
 * trainer supplied by {@link #createTrainer(MLTrainingOptions)} is run on the training set and,
 * if {@code saveModelToFile} is configured, the trained regressor is serialized to that file.
 *
 * <p>NOTE(review): the cross-validation records are collected but never read inside this class.
 */
@Categorized({MachineLearningCategory.class})
/* loaded from: input_file:org/datacleaner/components/machinelearning/MLRegressionTrainingAnalyzer.class */
public abstract class MLRegressionTrainingAnalyzer extends MLTrainingAnalyzer<MLRegressionAnalyzerResult> {
    private static final Logger logger = LoggerFactory.getLogger(MLRegressionTrainingAnalyzer.class);
    private static final MLFeatureModifierBuilderFactory featureModifierBuilderFactory = new MLFeatureModifierBuilderFactoryImpl();

    /** Column holding the numeric value the model is trained to predict. */
    @Configured
    InputColumn<Number> regressionOutput;

    /** Share of the consumed records (per 100) that is held back from training. */
    @NumberProperty(negative = false)
    @Configured
    @Description("Determine how much (if any) of the records should be used for cross-validation.")
    Percentage crossValidationSampleRate = new Percentage(10);

    // Counter and queues are concurrency-safe types; presumably run(...) may be invoked from
    // several threads - confirm against the framework's threading contract.
    private AtomicInteger recordCounter;
    private Collection<MLRegressionRecord> trainingRecords;
    private Collection<MLRegressionRecord> crossValidationRecords;

    /** One builder per configured feature modifier type, in the same order as the types. */
    private List<MLFeatureModifierBuilder> featureModifierBuilders;

    /**
     * Resets the record counter and record collections and creates one feature modifier builder
     * for each configured feature modifier type.
     */
    @Initialize
    public void init() {
        this.recordCounter = new AtomicInteger();
        this.trainingRecords = new ConcurrentLinkedQueue<>();
        this.crossValidationRecords = new ConcurrentLinkedQueue<>();
        this.featureModifierBuilders = new ArrayList<>(this.featureModifierTypes.length);

        // -1 is passed when no maximum is configured (presumably meaning "no limit" - confirm
        // against MLTrainingConstraints).
        final int maxFeatures =
                this.maxFeaturesGeneratedPerColumn == null ? -1 : this.maxFeaturesGeneratedPerColumn.intValue();
        final MLTrainingConstraints constraints =
                new MLTrainingConstraints(maxFeatures, this.includeUniqueValueFeatures);
        for (MLFeatureModifierType type : this.featureModifierTypes) {
            this.featureModifierBuilders.add(featureModifierBuilderFactory.create(type, constraints));
        }
    }

    /**
     * Consumes a single row: registers its values with the feature modifier builders and stores
     * the resulting record in the training or the cross-validation collection.
     *
     * <p>Rows for which no training record can be created ({@code forTraining} returns
     * {@code null}) are skipped and not counted.
     *
     * @param inputRow the row to consume
     * @param i unused by this implementation
     */
    public void run(InputRow inputRow, int i) {
        final MLRegressionRecord record =
                MLRegressionRecordImpl.forTraining(inputRow, this.regressionOutput, this.featureColumns);
        if (record == null) {
            return;
        }

        // The value at index n is handed to the builder at index n.
        // NOTE(review): this relies on there being at least as many builders as record values.
        final Object[] recordValues = record.getRecordValues();
        for (int index = 0; index < recordValues.length; index++) {
            this.featureModifierBuilders.get(index).addRecordValue(recordValues[index]);
        }

        // Within every run of 100 counted records the bucket takes each value 0..99 exactly once.
        // Buckets below the nominator are held back, so exactly "nominator" records per 100 go to
        // cross-validation. The previous '>' comparison held back nominator + 1 records (11% for a
        // rate of 10, and 1% even for a rate of 0).
        // Assumes the nominator lies in 0..100 - TODO confirm against Percentage.
        final int bucket = this.recordCounter.incrementAndGet() % 100;
        if (bucket >= this.crossValidationSampleRate.getNominator()) {
            this.trainingRecords.add(record);
        } else {
            this.crossValidationRecords.add(record);
        }
    }

    /**
     * Builds the feature modifiers, trains the regressor on the collected training records and
     * optionally saves the trained model to {@code saveModelToFile}.
     *
     * <p>The decompiler had renamed this method to {@code m5getResult}; it is restored here under
     * its original name. The old name is kept as a delegating alias.
     *
     * @return a result wrapping the trained regressor
     * @throws UncheckedIOException if the model cannot be written to the configured file
     */
    @Override
    public MLRegressionAnalyzerResult getResult() {
        final List<MLFeatureModifier> featureModifiers = this.featureModifierBuilders.stream()
                .map(MLFeatureModifierBuilder::build)
                .collect(Collectors.toList());
        final List<String> columnNames = CollectionUtils.map(this.featureColumns, new HasNameMapper());
        final MLRegressorTrainer trainer =
                createTrainer(new MLTrainingOptions(Double.class, columnNames, featureModifiers));

        log("Training model starting. Records=" + this.trainingRecords.size() + ", Columns=" + columnNames.size()
                + ", Features=" + MLFeatureUtils.getFeatureCount(featureModifiers) + ".");

        final MLRegressor regressor = trainer.train(this.trainingRecords, featureModifiers, new MLTrainerCallback() {
            @Override // org.datacleaner.components.machinelearning.api.MLTrainerCallback
            public void epochDone(int epoch, int epochCount) {
                // Progress is only reported when the second argument is greater than one.
                if (epochCount > 1) {
                    MLRegressionTrainingAnalyzer.this.log(
                            "Training progress: Epoch " + epoch + " of " + epochCount + " done.");
                }
            }
        });

        if (this.saveModelToFile != null) {
            logger.info("Saving model to file: {}", this.saveModelToFile);
            try {
                Files.write(SerializationUtils.serialize(regressor), this.saveModelToFile);
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to save model to file: " + this.saveModelToFile, e);
            }
        }

        // NOTE(review): the message mentions evaluation matrices, but nothing is evaluated here
        // and the cross-validation records are not consumed.
        log("Trained model. Creating evaluation matrices.");
        return new MLRegressionAnalyzerResult(regressor);
    }

    /**
     * Decompiler-generated name for {@link #getResult()}, kept so existing callers still compile.
     *
     * @return the same result as {@link #getResult()}
     * @deprecated use {@link #getResult()}
     */
    @Deprecated
    public MLRegressionAnalyzerResult m5getResult() {
        return getResult();
    }

    /**
     * Creates the trainer used to fit the regression model.
     *
     * @param mLTrainingOptions options holding the output type, the feature column names and the
     *        built feature modifiers
     * @return the trainer to run on the collected training records
     */
    protected abstract MLRegressorTrainer createTrainer(MLTrainingOptions mLTrainingOptions);
}
