import {createSlice, PayloadAction} from "@reduxjs/toolkit";
import {chatApi} from "./chatApi";
import {
    Message,
    RagProcessMessage,
    RagSourceMessage,
    System,
    Thread,
    FilteredContentMessage,
} from "../../types";

/**
 * Shape of the slice state for the currently open chat thread.
 *
 * NOTE(review): `datestamp` stores a `Date`, which is not a serializable
 * value; Redux Toolkit's default serializability check warns about this.
 * Consider an ISO string instead — that would change action payload types,
 * so it is left as-is here.
 */
type ThreadState = {
    // Multi-thread cache, currently disabled:
    // threads: Record<string, Thread>;

    /** Identifier of the open thread; undefined in the initial state. */
    id: string | undefined;
    /** Display name of the open thread; undefined in the initial state. */
    name: string | undefined;
    /** Timestamp of the open thread; undefined in the initial state. */
    datestamp: Date | undefined;
    /** Messages of the thread in display order (last entry is the newest). */
    messages: Message[];
    /** True while an assistant response is being streamed in. */
    isStreaming: boolean;
    /** Currently selected system, or null when none is selected. */
    system: System | null;
    /** Content-filter error for the last send, or null when there is none. */
    filteredContent: FilteredContentMessage | null;
};

/** System selected when the store is first created. */
const defaultSystem: System = {name: "DIB"};

/**
 * Empty thread: no id/name/date, no messages, not streaming, no filter error,
 * with the default system selected.
 */
const initialState: ThreadState = {
    /* threads: {}, */
    id: undefined,
    name: undefined,
    datestamp: undefined,
    messages: [],
    isStreaming: false,
    system: defaultSystem,
    filteredContent: null,
};

/**
 * Redux slice holding the currently open chat thread.
 *
 * Streaming-related state transitions visible here: the `sendMessage`
 * pending matcher sets `isStreaming`, `addPartialMessage` appends text to a
 * trailing assistant message, and `finishAssistantMessage` writes the final
 * text and clears `isStreaming`. The dispatch order of these actions lives
 * in calling code not shown here.
 */
const threadSlice = createSlice({
    name: "thread",
    initialState,
    reducers: {
        /**
         * Overwrite the trailing assistant message with the finalized
         * response and mark streaming as finished. `isStreaming` is cleared
         * even when the last message is not an assistant message.
         */
        finishAssistantMessage(state, action: PayloadAction<string>) {
            const lastMessage = state.messages[state.messages.length - 1];

            if (lastMessage?.role === "assistant") {
                lastMessage.content = action.payload;
            }

            state.isStreaming = false;
        },
        /**
         * Append a streamed chunk to the trailing assistant message.
         * No-op when the thread is empty or the last message is not from
         * the assistant.
         */
        addPartialMessage(state, action: PayloadAction<string>) {
            const lastMessage = state.messages[state.messages.length - 1];

            if (lastMessage?.role === "assistant") {
                lastMessage.content += action.payload;
            }
        },
        /**
         * Overwrite the content of the trailing assistant message without
         * touching `isStreaming`. No-op when there is no such message.
         */
        replaceMessage(state, action: PayloadAction<string>) {
            const lastMessage = state.messages[state.messages.length - 1];

            if (lastMessage?.role === "assistant") {
                lastMessage.content = action.payload;
            }
        },
        /** Append a message to the end of the thread. */
        addMessage(state, action: PayloadAction<Message>) {
            state.messages.push(action.payload);
        },
        /** Set the thread's display name. */
        updateName(state, action: PayloadAction<string>) {
            state.name = action.payload;
        },
        /** Set the thread's timestamp. */
        updateDatestamp(state, action: PayloadAction<Date>) {
            state.datestamp = action.payload;
        },
        /** Set the thread's identifier. */
        updateId(state, action: PayloadAction<string>) {
            state.id = action.payload;
        },
        /**
         * Apply thread metadata and merge in messages.
         *
         * Messages that already carry an `id` are kept and the incoming
         * messages are appended after them. Messages without an `id` are
         * dropped. This removes the temporary user message that was added
         * locally before the server assigned it an ID.
         */
        updateThread(
            state,
            action: PayloadAction<{
                id: string;
                name: string;
                messages: Message[];
                datestamp: Date;
            }>
        ) {
            const {id, name, messages, datestamp} = action.payload;

            state.id = id;
            state.name = name;
            state.datestamp = datestamp;
            state.messages = [
                ...state.messages.filter((m) => !!m.id),
                ...messages,
            ];
        },
        /**
         * Attach a RAG process entry to the last message in the thread.
         * No-op when the thread has no messages yet.
         */
        addRagProcess(state, action: PayloadAction<RagProcessMessage>) {
            const lastMessage = state.messages[state.messages.length - 1];

            // Previously this threw a TypeError on an empty thread; the other
            // "last message" reducers already tolerate that case.
            if (!lastMessage) {
                return;
            }

            if (lastMessage.ragProcesses) {
                lastMessage.ragProcesses.push(action.payload);
            } else {
                lastMessage.ragProcesses = [action.payload];
            }
        },
        /**
         * Attach a RAG source entry to the last message in the thread.
         * No-op when the thread has no messages yet.
         */
        addRagSource(state, action: PayloadAction<RagSourceMessage>) {
            const lastMessage = state.messages[state.messages.length - 1];

            // Same empty-thread guard as addRagProcess.
            if (!lastMessage) {
                return;
            }

            if (lastMessage.ragSources) {
                lastMessage.ragSources.push(action.payload);
            } else {
                lastMessage.ragSources = [action.payload];
            }
        },
        /** Record a content-filter error for the current send. */
        addFilteredContentError(
            state,
            action: PayloadAction<FilteredContentMessage>
        ) {
            state.filteredContent = action.payload;
        },
        /** Reset to the initial state while keeping the selected system. */
        clearThread(state) {
            return {
                ...initialState,
                system: state.system,
            };
        },
        /** Select the system used for the thread. */
        setSystem(state, action: PayloadAction<System>) {
            state.system = action.payload;
        },
    },
    extraReducers: (builder) => {
        builder.addMatcher(
            chatApi.endpoints.sendMessage.matchPending,
            (state, action) => {
                const id = action.meta.arg.originalArgs.messageId;

                // On a resend, drop the referenced message and everything
                // after it. `slice(0, i)` excludes index `i`, so the resent
                // message itself is removed as well.
                if (id) {
                    const i = state.messages.findIndex((m) => m.id === id);

                    if (i !== -1) {
                        state.messages = state.messages.slice(0, i);
                    }
                }

                state.isStreaming = true;
                state.filteredContent = null;
            }
        );
        builder.addMatcher(
            chatApi.endpoints.getThread.matchFulfilled,
            (state, action) => {
                // Ignore fetched thread data while a response is streaming, so
                // the fetch result does not overwrite the locally built,
                // in-progress messages.
                // TODO: confirm which code path triggers getThread mid-stream.
                if (!state.isStreaming) {
                    state.messages = action.payload?.messages ?? [];
                    state.id = action.payload?.id;
                    state.name = action.payload?.name;
                    state.datestamp = action.payload?.datestamp;
                }
            }
        );
        // Disabled multi-thread cache matcher; `state.threads` is not part of
        // ThreadState while this stays commented out.
        /* builder.addMatcher(
            chatApi.endpoints.getThreads.matchFulfilled,
            (state, action) => {
                action.payload.forEach((thread) => {
                    state.threads[thread.id] = thread;
                });
            }
        ); */
    },
});

/**
 * Action creators for the thread slice.
 *
 * `updateName`, `updateDatestamp` and `updateId` were previously defined as
 * reducers but not exported, so they could not be dispatched. They are
 * exported here.
 */
export const {
    finishAssistantMessage,
    addPartialMessage,
    replaceMessage,
    updateThread,
    updateName,
    updateDatestamp,
    updateId,
    addRagProcess,
    addRagSource,
    addMessage,
    addFilteredContentError,
    setSystem,
    clearThread,
} = threadSlice.actions;

/** Reducer for the thread slice. */
export default threadSlice.reducer;
