package dev.langchain4j.guardrail;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.ValidationUtils;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/guardrail/JsonExtractorOutputGuardrail.class */
/**
 * Output guardrail that checks whether an LLM response contains JSON that can be deserialized
 * into {@code T}.
 *
 * <p>The response text is first deserialized as-is. If that fails, the text is trimmed to the
 * JSON object or array it appears to contain (see {@link #trimNonJson(String)}) and deserialized
 * again. If both attempts fail, {@link #invokeInvalidJson(AiMessage, String)} is called, which by
 * default requests a reprompt using {@link #DEFAULT_REPROMPT_MESSAGE} and
 * {@link #DEFAULT_REPROMPT_PROMPT}.
 *
 * @param <T> the type the extracted JSON is deserialized into
 */
public class JsonExtractorOutputGuardrail<T> implements OutputGuardrail {
    /** Default message passed to {@code reprompt} when no valid JSON could be extracted. */
    public static final String DEFAULT_REPROMPT_MESSAGE = "Invalid JSON";

    /** Default reprompt instruction passed to {@code reprompt} when no valid JSON could be extracted. */
    public static final String DEFAULT_REPROMPT_PROMPT = "Make sure you return a valid JSON object following the specified format";

    private static final Logger LOGGER = LoggerFactory.getLogger(JsonExtractorOutputGuardrail.class);

    private final ObjectMapper objectMapper;

    // Exactly one of these is non-null, depending on which constructor was used.
    private final Class<T> outputClass;
    private final TypeReference<T> outputType;

    /**
     * Creates a guardrail that deserializes into a concrete class.
     *
     * @param objectMapper the mapper used for deserialization; must not be null
     * @param outputClass the target class; must not be null
     */
    public JsonExtractorOutputGuardrail(ObjectMapper objectMapper, Class<T> outputClass) {
        this.objectMapper = ValidationUtils.ensureNotNull(objectMapper, "objectMapper");
        this.outputClass = ValidationUtils.ensureNotNull(outputClass, "outputClass");
        this.outputType = null;
    }

    /**
     * Creates a guardrail that deserializes into a (possibly generic) type described by a
     * {@link TypeReference}.
     *
     * @param objectMapper the mapper used for deserialization; must not be null
     * @param outputType the target type; must not be null
     */
    public JsonExtractorOutputGuardrail(ObjectMapper objectMapper, TypeReference<T> outputType) {
        this.objectMapper = ValidationUtils.ensureNotNull(objectMapper, "objectMapper");
        this.outputClass = null;
        this.outputType = ValidationUtils.ensureNotNull(outputType, "outputType");
    }

    /**
     * Creates a guardrail for {@code outputClass} using a new default {@link ObjectMapper}.
     *
     * @param outputClass the target class; must not be null
     */
    public JsonExtractorOutputGuardrail(Class<T> outputClass) {
        this(new ObjectMapper(), outputClass);
    }

    /**
     * Creates a guardrail for {@code outputType} using a new default {@link ObjectMapper}.
     *
     * @param outputType the target type; must not be null
     */
    public JsonExtractorOutputGuardrail(TypeReference<T> outputType) {
        this(new ObjectMapper(), outputType);
    }

    /**
     * Validates that the message text is, or contains, JSON deserializable into {@code T}.
     *
     * @param aiMessage the LLM response; must not be null (its text may be null)
     * @return {@code successWith(json, value)} for the text that deserialized successfully, or the
     *     result of {@link #invokeInvalidJson(AiMessage, String)} if no attempt succeeded
     */
    @Override // dev.langchain4j.guardrail.OutputGuardrail
    public OutputGuardrailResult validate(AiMessage aiMessage) {
        String text = ValidationUtils.ensureNotNull(aiMessage, "responseFromLLM").text();
        LOGGER.debug("LLM output: {}", text);
        return deserialize(text)
                .map(result -> successWith(text, result))
                .orElseGet(() -> validateTrimmed(aiMessage, text));
    }

    /** Second attempt: deserialize only the JSON-looking part of {@code text}, else report invalid JSON. */
    private OutputGuardrailResult validateTrimmed(AiMessage aiMessage, String text) {
        LOGGER.debug("LLM output contained invalid JSON. Attempting to trim non-JSON");
        String trimmed = trimNonJson(text);
        LOGGER.debug("Attempting to deserialize trimmed JSON: {}", trimmed);
        return deserialize(trimmed)
                .map(result -> successWith(trimmed, result))
                .orElseGet(() -> invokeInvalidJson(aiMessage, trimmed));
    }

    /**
     * Trims surrounding non-JSON text from {@code text}.
     *
     * <p>Whichever of {@code '{'} or {@code '['} occurs first selects the delimiter pair; the
     * result spans from that first opening character to the last matching closing character. If
     * that pair yields nothing (no closing character after the opening one), the other pair is
     * tried instead.
     *
     * @param text the raw LLM output; may be null
     * @return the candidate JSON substring, or {@code ""} if none is found or {@code text} is null
     */
    protected String trimNonJson(String text) {
        if (text == null) {
            return "";
        }
        int firstBrace = text.indexOf('{');
        int firstBracket = text.indexOf('[');
        boolean objectFirst = firstBrace >= 0 && (firstBracket < 0 || firstBrace < firstBracket);

        String primary = objectFirst ? slice(text, '{', '}') : slice(text, '[', ']');
        if (!primary.isEmpty()) {
            return primary;
        }
        // The first delimiter had no closing partner after it; an empty string can never
        // deserialize, so try the other delimiter pair before giving up.
        return objectFirst ? slice(text, '[', ']') : slice(text, '{', '}');
    }

    /** Returns text from the first {@code open} to the last {@code close} (inclusive), or "" if absent/misordered. */
    private static String slice(String text, char open, char close) {
        int start = text.indexOf(open);
        int end = text.lastIndexOf(close);
        return (start < 0 || end <= start) ? "" : text.substring(start, end + 1);
    }

    /**
     * Called when no valid JSON could be extracted; by default requests a reprompt.
     *
     * @param aiMessage the original LLM response
     * @param json the (trimmed) text that failed to deserialize
     * @return the result of {@code reprompt(message, prompt)}
     */
    protected OutputGuardrailResult invokeInvalidJson(AiMessage aiMessage, String json) {
        LOGGER.debug("Found invalid JSON for aiMessage = {} and json = {}", aiMessage, json);
        return reprompt(getInvalidJsonMessage(aiMessage, json), getInvalidJsonReprompt(aiMessage, json));
    }

    /**
     * Returns the message used when JSON is invalid; override to customize.
     *
     * @return {@link #DEFAULT_REPROMPT_MESSAGE} by default
     */
    protected String getInvalidJsonMessage(AiMessage aiMessage, String json) {
        return DEFAULT_REPROMPT_MESSAGE;
    }

    /**
     * Returns the reprompt instruction used when JSON is invalid; override to customize.
     *
     * @return {@link #DEFAULT_REPROMPT_PROMPT} by default
     */
    protected String getInvalidJsonReprompt(AiMessage aiMessage, String json) {
        return DEFAULT_REPROMPT_PROMPT;
    }

    /**
     * Deserializes {@code json} into {@code T} using the configured class or type reference.
     *
     * @param json the text to deserialize; may be null
     * @return the deserialized value, or empty if {@code json} is null, is not valid JSON for
     *     {@code T}, or deserializes to JSON {@code null}
     */
    protected Optional<T> deserialize(String json) {
        if (json == null) {
            // ObjectMapper.readValue rejects null content with an exception that is not a
            // JsonProcessingException; treat it as "no JSON" so the caller can reprompt.
            return Optional.empty();
        }
        try {
            T value = this.outputClass != null
                    ? this.objectMapper.readValue(json, this.outputClass)
                    : this.objectMapper.readValue(json, this.outputType);
            return Optional.ofNullable(value);
        } catch (JsonProcessingException e) {
            return Optional.empty();
        }
    }
}
