/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.chat.client.advisor.vectorstore;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;

/**
 * Chat-memory advisor that stores conversation messages in a {@link VectorStore}.
 *
 * <p>Before a request is sent, the store is searched for documents that are similar to the
 * current user message and tagged with the current conversation id. The text of those
 * documents is rendered into the system prompt as "long term memory", and the user message
 * itself is written to the store. After a response is received, the assistant messages of
 * the response are written to the store as well.
 *
 * <p>Instances are created through {@link #builder(VectorStore)}.
 */
public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
    /** Context key under which a per-request top-K override may be supplied. */
    public static final String TOP_K = "chat_memory_vector_store_top_k";

    /** Metadata key holding the conversation id of a stored document. */
    private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";

    /** Metadata key holding the {@link MessageType} name of a stored document. */
    private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";

    /** Number of documents retrieved when neither the builder nor the context says otherwise. */
    private static final int DEFAULT_TOP_K = 20;

    /** System prompt template used when the builder is not given a custom one. */
    private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate(
            "{instructions}\n"
                    + "\n"
                    + "Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.\n"
                    + "\n"
                    + "---------------------\n"
                    + "LONG_TERM_MEMORY:\n"
                    + "{long_term_memory}\n"
                    + "---------------------\n");

    private final PromptTemplate systemPromptTemplate;
    private final int defaultTopK;
    private final String defaultConversationId;
    private final int order;
    private final Scheduler scheduler;
    private final VectorStore vectorStore;

    /**
     * Validates and stores the advisor configuration.
     *
     * @param systemPromptTemplate template rendered with {@code instructions} and {@code long_term_memory}
     * @param defaultTopK number of documents to retrieve when the context has no {@link #TOP_K} entry
     * @param defaultConversationId conversation id used as the fallback for {@code getConversationId}
     * @param order value returned by {@link #getOrder()}
     * @param scheduler scheduler on which the streaming pipeline publishes
     * @param vectorStore store used for both retrieval and persistence of messages
     * @throws IllegalArgumentException if any argument fails validation
     */
    private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK,
            String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
        Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
        Assert.isTrue(defaultTopK > 0, "topK must be greater than 0");
        Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
        Assert.notNull(scheduler, "scheduler cannot be null");
        Assert.notNull(vectorStore, "vectorStore cannot be null");
        this.systemPromptTemplate = systemPromptTemplate;
        this.defaultTopK = defaultTopK;
        this.defaultConversationId = defaultConversationId;
        this.order = order;
        this.scheduler = scheduler;
        this.vectorStore = vectorStore;
    }

    /**
     * Creates a builder for an advisor backed by the given vector store.
     *
     * @param chatMemory the vector store used to persist and retrieve conversation messages
     * @return a new builder
     */
    public static Builder builder(VectorStore chatMemory) {
        return new Builder(chatMemory);
    }

    /** Returns the order value configured through the builder. */
    @Override
    public int getOrder() {
        return this.order;
    }

    /** Returns the scheduler configured through the builder. */
    @Override
    public Scheduler getScheduler() {
        return this.scheduler;
    }

    /**
     * Augments the system message with documents retrieved from the vector store and
     * writes the current user message to the store.
     *
     * @param request the incoming request
     * @param advisorChain the advisor chain (not used by this method)
     * @return a copy of the request whose prompt carries the augmented system message
     */
    @Override
    public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
        String conversationId = this.getConversationId(request.context(), this.defaultConversationId);
        UserMessage incomingUserMessage = request.prompt().getUserMessage();
        String query = incomingUserMessage != null ? incomingUserMessage.getText() : "";
        int topK = this.getChatMemoryTopK(request.context());

        // NOTE(review): the conversation id is concatenated into the filter expression unescaped;
        // an id containing a single quote would presumably break the expression - confirm whether
        // ids can come from untrusted input.
        String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
        SearchRequest searchRequest = SearchRequest.builder()
                .query(query)
                .topK(topK)
                .filterExpression(filter)
                .build();

        List<Document> documents = this.vectorStore.similaritySearch(searchRequest);
        String longTermMemory = documents == null
                ? ""
                : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));

        // Map.of rejects null values, so fall back to empty instructions when there is no system text.
        SystemMessage systemMessage = request.prompt().getSystemMessage();
        String instructions = systemMessage != null && systemMessage.getText() != null
                ? systemMessage.getText()
                : "";
        String augmentedSystemText = this.systemPromptTemplate
                .render(Map.of("instructions", instructions, "long_term_memory", longTermMemory));

        ChatClientRequest processedChatClientRequest = request.mutate()
                .prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
                .build();

        // Persist the user message after retrieval so the search above does not see it.
        UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
        if (userMessage != null) {
            this.vectorStore.write(this.toDocuments(List.of(userMessage), conversationId));
        }
        return processedChatClientRequest;
    }

    /**
     * Resolves the number of documents to retrieve for a request.
     *
     * @param context the request context
     * @return the parsed {@link #TOP_K} context value, or the configured default when the key is
     *         absent or mapped to {@code null}
     * @throws NumberFormatException if the context value's string form is not a valid integer
     */
    private int getChatMemoryTopK(Map<String, Object> context) {
        Object configured = context.get(TOP_K);
        if (configured == null) {
            return this.defaultTopK;
        }
        return Integer.parseInt(configured.toString());
    }

    /**
     * Writes the assistant messages of the response to the vector store.
     *
     * @param chatClientResponse the response to record
     * @param advisorChain the advisor chain (not used by this method)
     * @return the same response instance, unchanged
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        List<Message> assistantMessages = new ArrayList<>();
        if (chatClientResponse.chatResponse() != null) {
            assistantMessages = chatClientResponse.chatResponse()
                    .getResults()
                    .stream()
                    .map(generation -> (Message) generation.getOutput())
                    .toList();
        }
        String conversationId = this.getConversationId(chatClientResponse.context(), this.defaultConversationId);
        List<Document> documents = this.toDocuments(assistantMessages, conversationId);
        // Nothing to persist: skip the store call instead of writing an empty batch.
        if (!documents.isEmpty()) {
            this.vectorStore.write(documents);
        }
        return chatClientResponse;
    }

    /**
     * Streaming variant: runs {@link #before} on the configured scheduler, delegates to the
     * next advisor in the chain, and invokes {@link #after} through a
     * {@link ChatClientMessageAggregator} on the resulting stream.
     *
     * @param chatClientRequest the incoming request
     * @param streamAdvisorChain the streaming advisor chain to continue with
     * @return the response stream produced by the rest of the chain
     */
    @Override
    public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
            StreamAdvisorChain streamAdvisorChain) {
        Scheduler scheduler = this.getScheduler();
        return Mono.just(chatClientRequest)
                .publishOn(scheduler)
                .map(request -> this.before(request, streamAdvisorChain))
                .flatMapMany(streamAdvisorChain::nextStream)
                .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
                        response -> this.after(response, streamAdvisorChain)));
    }

    /**
     * Converts user and assistant messages into documents tagged with the conversation id
     * and message type. Messages of any other type are dropped.
     *
     * @param messages the messages to convert
     * @param conversationId the conversation id stored in each document's metadata
     * @return one document per user or assistant message, in encounter order
     * @throws RuntimeException if a message reports a USER or ASSISTANT type but is neither a
     *         {@link UserMessage} nor an {@link AssistantMessage}
     */
    private List<Document> toDocuments(List<Message> messages, String conversationId) {
        return messages.stream()
                .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
                .map(message -> this.toDocument(message, conversationId))
                .toList();
    }

    /** Builds a single document from a user or assistant message, copying its metadata. */
    private Document toDocument(Message message, String conversationId) {
        Map<String, Object> metadata = new HashMap<>();
        if (message.getMetadata() != null) {
            metadata.putAll(message.getMetadata());
        }
        metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
        metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
        if (message instanceof UserMessage userMessage) {
            return Document.builder().text(userMessage.getText()).metadata(metadata).build();
        }
        if (message instanceof AssistantMessage assistantMessage) {
            return Document.builder().text(assistantMessage.getText()).metadata(metadata).build();
        }
        throw new RuntimeException("Unknown message type: " + message.getMessageType());
    }

    /**
     * Fluent builder for {@link VectorStoreChatMemoryAdvisor}. Argument validation happens in
     * {@link #build()}, not in the individual setters.
     */
    public static class Builder {
        private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
        private int defaultTopK = DEFAULT_TOP_K;
        private String conversationId = "default";
        private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
        // Same value as the literal -2147482648.
        private int order = Integer.MIN_VALUE + 1000;
        private VectorStore vectorStore;

        /**
         * @param vectorStore the vector store the advisor will read from and write to
         */
        protected Builder(VectorStore vectorStore) {
            this.vectorStore = vectorStore;
        }

        /**
         * Sets the template used to render the augmented system message.
         *
         * @param systemPromptTemplate the template to use
         * @return this builder
         */
        public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
            this.systemPromptTemplate = systemPromptTemplate;
            return this;
        }

        /**
         * Sets the number of documents retrieved when the request context has no override.
         *
         * @param defaultTopK the default top-K; must be greater than 0 when {@link #build()} is called
         * @return this builder
         */
        public Builder defaultTopK(int defaultTopK) {
            this.defaultTopK = defaultTopK;
            return this;
        }

        /**
         * Sets the default conversation id.
         *
         * @param conversationId the default id; must contain text when {@link #build()} is called
         * @return this builder
         */
        public Builder conversationId(String conversationId) {
            this.conversationId = conversationId;
            return this;
        }

        /**
         * Sets the scheduler used by the streaming pipeline.
         *
         * @param scheduler the scheduler to publish on
         * @return this builder
         */
        public Builder scheduler(Scheduler scheduler) {
            this.scheduler = scheduler;
            return this;
        }

        /**
         * Sets the value returned by {@link VectorStoreChatMemoryAdvisor#getOrder()}.
         *
         * @param order the advisor order
         * @return this builder
         */
        public Builder order(int order) {
            this.order = order;
            return this;
        }

        /**
         * Creates the advisor from the current builder state.
         *
         * @return a new advisor
         * @throws IllegalArgumentException if any configured value fails validation
         */
        public VectorStoreChatMemoryAdvisor build() {
            return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK,
                    this.conversationId, this.order, this.scheduler, this.vectorStore);
        }
    }
}

