/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.ModerationException;
import dev.langchain4j.service.ServiceOutputParser;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.UserName;
import dev.langchain4j.service.V;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AiServices<T> {
    private final Logger log = LoggerFactory.getLogger(AiServices.class);
    private static final String DEFAULT = "default";
    private final AiServiceContext context = new AiServiceContext();

    private AiServices(Class<T> aiServiceClass) {
        this.context.aiServiceClass = aiServiceClass;
    }

    public static <T> T create(Class<T> aiService, ChatLanguageModel chatLanguageModel) {
        return AiServices.builder(aiService).chatLanguageModel(chatLanguageModel).build();
    }

    public static <T> T create(Class<T> aiService, StreamingChatLanguageModel streamingChatLanguageModel) {
        return AiServices.builder(aiService).streamingChatLanguageModel(streamingChatLanguageModel).build();
    }

    public static <T> AiServices<T> builder(Class<T> aiService) {
        return new AiServices<T>(aiService);
    }

    public AiServices<T> chatLanguageModel(ChatLanguageModel chatLanguageModel) {
        this.context.chatLanguageModel = chatLanguageModel;
        return this;
    }

    public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
        this.context.streamingChatLanguageModel = streamingChatLanguageModel;
        return this;
    }

    public AiServices<T> chatMemory(ChatMemory chatMemory) {
        this.context.chatMemories = new ConcurrentHashMap<Object, ChatMemory>();
        this.context.chatMemories.put(DEFAULT, chatMemory);
        return this;
    }

    public AiServices<T> chatMemoryProvider(ChatMemoryProvider chatMemoryProvider) {
        this.context.chatMemories = new ConcurrentHashMap<Object, ChatMemory>();
        this.context.chatMemoryProvider = chatMemoryProvider;
        return this;
    }

    public AiServices<T> moderationModel(ModerationModel moderationModel) {
        this.context.moderationModel = moderationModel;
        return this;
    }

    public AiServices<T> tools(Object ... objectsWithTools) {
        return this.tools(Arrays.asList(objectsWithTools));
    }

    public AiServices<T> tools(List<Object> objectsWithTools) {
        this.context.toolSpecifications = new ArrayList<ToolSpecification>();
        this.context.toolExecutors = new HashMap<String, ToolExecutor>();
        for (Object objectWithTool : objectsWithTools) {
            for (Method method : objectWithTool.getClass().getDeclaredMethods()) {
                if (!method.isAnnotationPresent(Tool.class)) continue;
                ToolSpecification toolSpecification = ToolSpecifications.toolSpecificationFrom((Method)method);
                this.context.toolSpecifications.add(toolSpecification);
                this.context.toolExecutors.put(toolSpecification.name(), new ToolExecutor(objectWithTool, method));
            }
        }
        return this;
    }

    public AiServices<T> retriever(Retriever<TextSegment> retriever) {
        this.context.retriever = retriever;
        return this;
    }

    public T build() {
        if (this.context.chatLanguageModel == null && this.context.streamingChatLanguageModel == null) {
            throw IllegalConfigurationException.illegalConfiguration("Please specify either chatLanguageModel or streamingChatLanguageModel");
        }
        for (Method method : this.context.aiServiceClass.getMethods()) {
            if (!method.isAnnotationPresent(Moderate.class) || this.context.moderationModel != null) continue;
            throw IllegalConfigurationException.illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
        }
        if (this.context.toolSpecifications != null && !this.context.hasChatMemory()) {
            throw IllegalConfigurationException.illegalConfiguration("Please set up chatMemory or chatMemoryProvider in order to use tools. A ChatMemory that can hold at least 3 messages is required for the tools to work properly. While the LLM can technically execute a tool without chat memory, if it only receives the result of the tool's execution without the initial message from the user, it won't interpret the result properly.");
        }
        Object proxyInstance = Proxy.newProxyInstance(this.context.aiServiceClass.getClassLoader(), new Class[]{this.context.aiServiceClass}, new InvocationHandler(){
            private final ExecutorService executor = Executors.newCachedThreadPool();

            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                ArrayList<ChatMessage> messages;
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke((Object)this, args);
                }
                AiServices.validateParameters(method);
                Optional systemMessage = AiServices.this.prepareSystemMessage(method, args);
                ChatMessage userMessage = AiServices.prepareUserMessage(method, args);
                if (((AiServices)AiServices.this).context.retriever != null) {
                    List relevant = ((AiServices)AiServices.this).context.retriever.findRelevant(userMessage.text());
                    if (relevant == null || relevant.isEmpty()) {
                        AiServices.this.log.debug("No relevant information was found");
                    } else {
                        String relevantConcatenated = relevant.stream().map(TextSegment::text).collect(Collectors.joining("\n\n"));
                        AiServices.this.log.debug("Retrieved relevant information:\n" + relevantConcatenated + "\n");
                        userMessage = dev.langchain4j.data.message.UserMessage.userMessage((String)(userMessage.text() + "\n\nHere is some information that might be useful for answering:\n\n" + relevantConcatenated));
                    }
                }
                String memoryId = AiServices.this.memoryId(method, args).orElse(AiServices.DEFAULT);
                if (AiServices.this.context.hasChatMemory()) {
                    ChatMemory chatMemory = AiServices.this.context.chatMemory(memoryId);
                    systemMessage.ifPresent(it -> this.addIfNeeded((ChatMessage)it, chatMemory));
                    chatMemory.add(userMessage);
                }
                if (AiServices.this.context.hasChatMemory()) {
                    messages = AiServices.this.context.chatMemory(memoryId).messages();
                } else {
                    messages = new ArrayList();
                    systemMessage.ifPresent(messages::add);
                    messages.add(userMessage);
                }
                Future<Moderation> moderationFuture = this.triggerModerationIfNeeded(method, messages);
                if (method.getReturnType() == TokenStream.class) {
                    return new AiServiceTokenStream(messages, AiServices.this.context, memoryId);
                }
                AiMessage aiMessage = ((AiServices)AiServices.this).context.chatLanguageModel.sendMessages(messages, ((AiServices)AiServices.this).context.toolSpecifications);
                this.verifyModerationIfNeeded(moderationFuture);
                while (true) {
                    ToolExecutionRequest toolExecutionRequest;
                    if (AiServices.this.context.hasChatMemory()) {
                        AiServices.this.context.chatMemory(memoryId).add((ChatMessage)aiMessage);
                    }
                    if ((toolExecutionRequest = aiMessage.toolExecutionRequest()) == null) break;
                    ToolExecutor toolExecutor = ((AiServices)AiServices.this).context.toolExecutors.get(toolExecutionRequest.name());
                    String toolExecutionResult = toolExecutor.execute(toolExecutionRequest);
                    ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.toolExecutionResultMessage((String)toolExecutionRequest.name(), (String)toolExecutionResult);
                    ChatMemory chatMemory = AiServices.this.context.chatMemory(memoryId);
                    chatMemory.add((ChatMessage)toolExecutionResultMessage);
                    aiMessage = ((AiServices)AiServices.this).context.chatLanguageModel.sendMessages(chatMemory.messages(), ((AiServices)AiServices.this).context.toolSpecifications);
                }
                return ServiceOutputParser.parse(aiMessage, method.getReturnType());
            }

            private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
                if (method.isAnnotationPresent(Moderate.class)) {
                    return this.executor.submit(() -> {
                        List<ChatMessage> messagesToModerate = this.removeToolMessages(messages);
                        return ((AiServices)AiServices.this).context.moderationModel.moderate(messagesToModerate);
                    });
                }
                return null;
            }

            private List<ChatMessage> removeToolMessages(List<ChatMessage> messages) {
                return messages.stream().filter(it -> !(it instanceof ToolExecutionResultMessage)).filter(it -> !(it instanceof AiMessage) || ((AiMessage)it).toolExecutionRequest() == null).collect(Collectors.toList());
            }

            private void verifyModerationIfNeeded(Future<Moderation> moderationFuture) {
                if (moderationFuture != null) {
                    try {
                        Moderation moderation = moderationFuture.get();
                        if (moderation.flagged()) {
                            throw new ModerationException(String.format("Text \"%s\" violates content policy", moderation.flaggedText()));
                        }
                    }
                    catch (InterruptedException | ExecutionException e) {
                        throw new RuntimeException(e);
                    }
                }
            }

            private void addIfNeeded(ChatMessage systemMessage, ChatMemory chatMemory) {
                boolean shouldAddSystemMessage = true;
                List messages = chatMemory.messages();
                for (int i = messages.size() - 1; i >= 0; --i) {
                    if (!(messages.get(i) instanceof dev.langchain4j.data.message.SystemMessage)) continue;
                    if (!((ChatMessage)messages.get(i)).equals(systemMessage)) break;
                    shouldAddSystemMessage = false;
                    break;
                }
                if (shouldAddSystemMessage) {
                    chatMemory.add(systemMessage);
                }
            }
        });
        return (T)proxyInstance;
    }

    private Optional<Object> memoryId(Method method, Object[] args) {
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(MemoryId.class)) continue;
            Object memoryId = args[i];
            if (memoryId == null) {
                throw Exceptions.illegalArgument((String)"The value of parameter %s annotated with @MemoryId in method %s must not be null", (Object[])new Object[]{parameters[i].getName(), method.getName()});
            }
            return Optional.of(memoryId);
        }
        return Optional.empty();
    }

    private Optional<ChatMessage> prepareSystemMessage(Method method, Object[] args) {
        Parameter[] parameters = method.getParameters();
        Map<String, Object> variables = AiServices.getPromptTemplateVariables(args, parameters);
        SystemMessage annotation = method.getAnnotation(SystemMessage.class);
        if (annotation != null) {
            String systemMessageTemplate = String.join((CharSequence)annotation.delimiter(), annotation.value());
            if (systemMessageTemplate.isEmpty()) {
                throw IllegalConfigurationException.illegalConfiguration("@SystemMessage's template cannot be empty");
            }
            Prompt prompt = PromptTemplate.from((String)systemMessageTemplate).apply(variables);
            return Optional.of(prompt.toSystemMessage());
        }
        return Optional.empty();
    }

    private static ChatMessage prepareUserMessage(Method method, Object[] args) {
        Parameter[] parameters = method.getParameters();
        Map<String, Object> variables = AiServices.getPromptTemplateVariables(args, parameters);
        String outputFormatInstructions = ServiceOutputParser.outputFormatInstructions(method.getReturnType());
        String userName = AiServices.getUserName(parameters, args);
        UserMessage annotation = method.getAnnotation(UserMessage.class);
        if (annotation != null) {
            String userMessageTemplate = String.join((CharSequence)annotation.delimiter(), annotation.value()) + outputFormatInstructions;
            if (userMessageTemplate.contains("{{it}}")) {
                if (parameters.length != 1) {
                    throw IllegalConfigurationException.illegalConfiguration("Error: The {{it}} placeholder is present but the method does not have exactly one parameter. Please ensure that methods using the {{it}} placeholder have exactly one parameter.");
                }
                variables = Collections.singletonMap("it", AiServices.toString(args[0]));
            }
            Prompt prompt = PromptTemplate.from((String)userMessageTemplate).apply(variables);
            return dev.langchain4j.data.message.UserMessage.userMessage((String)userName, (String)prompt.text());
        }
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserMessage.class)) continue;
            return dev.langchain4j.data.message.UserMessage.userMessage((String)userName, (String)(AiServices.toString(args[i]) + outputFormatInstructions));
        }
        if (args == null || args.length == 0) {
            throw IllegalConfigurationException.illegalConfiguration("Method should have at least one argument");
        }
        if (args.length == 1) {
            return dev.langchain4j.data.message.UserMessage.userMessage((String)userName, (String)(AiServices.toString(args[0]) + outputFormatInstructions));
        }
        throw IllegalConfigurationException.illegalConfiguration("For methods with multiple parameters, each parameter must be annotated with @V, @UserMessage, @UserName or @MemoryId");
    }

    private static String getUserName(Parameter[] parameters, Object[] args) {
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserName.class)) continue;
            return args[i].toString();
        }
        return null;
    }

    private static void validateParameters(Method method) {
        Parameter[] parameters = method.getParameters();
        if (parameters == null || parameters.length < 2) {
            return;
        }
        for (Parameter parameter : parameters) {
            V v = parameter.getAnnotation(V.class);
            UserMessage userMessage = parameter.getAnnotation(UserMessage.class);
            MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
            UserName userName = parameter.getAnnotation(UserName.class);
            if (v != null || userMessage != null || memoryId != null || userName != null) continue;
            throw IllegalConfigurationException.illegalConfiguration("Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId", parameter.getName(), method.getName());
        }
    }

    private static Map<String, Object> getPromptTemplateVariables(Object[] args, Parameter[] parameters) {
        HashMap<String, Object> variables = new HashMap<String, Object>();
        for (int i = 0; i < parameters.length; ++i) {
            V varAnnotation = parameters[i].getAnnotation(V.class);
            if (varAnnotation == null) continue;
            String variableName = varAnnotation.value();
            Object variableValue = args[i];
            variables.put(variableName, variableValue);
        }
        return variables;
    }

    private static Object toString(Object arg) {
        if (arg.getClass().isArray()) {
            return AiServices.arrayToString(arg);
        }
        if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt((Object)arg).text();
        }
        return arg.toString();
    }

    private static String arrayToString(Object arg) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(arg);
        for (int i = 0; i < length; ++i) {
            sb.append(AiServices.toString(Array.get(arg, i)));
            if (i >= length - 1) continue;
            sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}

