#!/usr/bin/env python3

# All imports first
import argparse
import cmd
import http.client
import os
import socket
import sys
import time
import warnings
from pathlib import Path

import openai
import qdrant_client
from fastembed import SparseTextEmbedding, TextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
from pymilvus import AnnSearchRequest, MilvusClient, RRFRanker

# Ignore every UserWarning emitted from this point on. This is a process-wide
# filter, so it also hides UserWarnings raised by this script's own code paths.
warnings.filterwarnings("ignore", category=UserWarning)

# Global Vars
# Model identifiers; each can be overridden by an environment variable of the same name.
EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
RANK_MODEL = os.getenv("RANK_MODEL", "Xenova/ms-marco-MiniLM-L-6-v2")
# Name of the collection searched in the vector database.
COLLECTION_NAME = "rag"
# Needed for mac to not give errors
# NOTE(review): presumably this silences the HuggingFace tokenizers
# fork/parallelism warning — confirm. It unconditionally overrides any value
# already present in the environment.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


def eprint(e, exit_code):
    """Print *e* to stderr as an ``Error: ...`` line and exit with *exit_code*.

    If the message is wrapped in a matching pair of single or double quotes
    (as ``str()`` of a ``KeyError`` is), that one pair is removed. Quotes that
    are part of the message, such as a quoted value at the end of a
    ``ValueError``, are left intact.
    """
    message = str(e)
    # Only unwrap a *matching* surrounding pair; stripping each end independently
    # would chop the closing quote off messages like "...: 'abc'".
    if len(message) >= 2 and message[0] == message[-1] and message[0] in "'\"":
        message = message[1:-1]
    print("Error:", message, file=sys.stderr)
    sys.exit(exit_code)


# Helper Classes and Functions


def wait_for_llama_server(host: str, port: int, timeout: int = 10, retry_interval: float = 0.5):
    """Poll ``GET /health`` on the llama-server until it answers with HTTP 200.

    Args:
        host: Hostname or IP address of the server.
        port: TCP port of the server.
        timeout: Total number of seconds to keep retrying.
        retry_interval: Pause in seconds between attempts, so a refused
            connection does not turn into a busy loop.

    Returns:
        True as soon as the health endpoint returns status 200.

    Exits the process with status 1 if *host* cannot be resolved, or if the
    server does not become ready within *timeout* seconds.
    """
    end_time = time.monotonic() + timeout
    first_attempt = True

    while time.monotonic() < end_time:
        conn = None
        try:
            conn = http.client.HTTPConnection(host, port, timeout=5)
            conn.request("GET", "/health")
            resp = conn.getresponse()
            if resp.status == 200:
                return True
            if first_attempt:
                print(f"Server at {host}:{port} is running but not ready (status={resp.status}), retrying...")
                first_attempt = False
        except socket.gaierror:
            # Must come before the OSError handler: gaierror is an OSError
            # subclass, so listing it second would make this branch unreachable.
            print(f"Cannot resolve host '{host}'", file=sys.stderr)
            sys.exit(1)
        except (OSError, http.client.HTTPException):
            # Connection refused/reset, or a malformed HTTP reply while starting up.
            if first_attempt:
                print(f"No server responding at {host}:{port}, retrying for up to {timeout} seconds...")
                first_attempt = False
        finally:
            if conn is not None:
                conn.close()
        # Back off before the next attempt without sleeping past the deadline.
        time.sleep(max(0.0, min(retry_interval, end_time - time.monotonic())))

    print(f"Error: llama-server at {host}:{port} did not become ready after {timeout} seconds.", file=sys.stderr)
    sys.exit(1)


class qdrant:
    """Retriever backed by a local on-disk Qdrant database.

    The Qdrant client embeds queries itself: ``set_model`` and
    ``set_sparse_model`` configure the dense (EMBED_MODEL) and sparse
    (SPARSE_MODEL) encoders that ``query`` uses.
    """

    def __init__(self, vector_path):
        self.client = qdrant_client.QdrantClient(path=vector_path)
        self.client.set_model(EMBED_MODEL)
        self.client.set_sparse_model(SPARSE_MODEL)

    def query(self, prompt):
        """Return the document text of up to 20 hits for *prompt*."""
        # Use the shared collection name so both backends stay in sync.
        results = self.client.query(
            collection_name=COLLECTION_NAME,
            query_text=prompt,
            limit=20,
        )
        return [r.document for r in results]


class milvus:
    """Hybrid dense + sparse retriever over a local Milvus database file.

    Opens ``milvus.db`` inside *vector_path* and embeds queries locally with
    fastembed's dense (EMBED_MODEL) and sparse (SPARSE_MODEL) encoders.
    """

    def __init__(self, vector_path):
        db_file = os.path.join(vector_path, "milvus.db")
        self.milvus_client = MilvusClient(uri=db_file)
        self.dmodel = TextEmbedding(model_name=EMBED_MODEL)
        self.smodel = SparseTextEmbedding(model_name=SPARSE_MODEL)

    def query(self, prompt):
        """Return the ``text`` field of up to 20 chunks matching *prompt*.

        Runs one dense and one sparse search (10 candidates each, inner-product
        metric) and fuses the two result lists with ``RRFRanker(100)``.
        """
        dense_vec = next(self.dmodel.embed([prompt]))
        sparse_vec = next(self.smodel.embed([prompt])).as_dict()

        searches = [
            AnnSearchRequest(
                data=[dense_vec],
                anns_field="dense",
                param={"metric_type": "IP", "params": {"nprobe": 10}},
                limit=10,
            ),
            AnnSearchRequest(
                data=[sparse_vec],
                anns_field="sparse",
                param={"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
                limit=10,
            ),
        ]

        # One query was issued, so only the first result list is relevant.
        hits = self.milvus_client.hybrid_search(
            collection_name=COLLECTION_NAME,
            reqs=searches,
            ranker=RRFRanker(100),
            limit=20,
            output_fields=["text"],
        )[0]
        return [hit["entity"]["text"] for hit in hits]


class Rag(cmd.Cmd):
    """Interactive shell that answers questions with retrieval-augmented generation.

    Every non-empty input line is treated as a question. Candidate chunks are
    fetched from the vector database and reranked with a cross-encoder. The
    best chunks, together with the recent chat history, are sent to the
    OpenAI-compatible server on ``localhost:$RAMALAMA_PORT``. ``/bye`` or EOF
    ends the session.
    """

    prompt = "🦭  > "
    # Upper bound on stored chat messages (user + assistant) sent as history.
    MAX_HISTORY = 10
    # Number of top-scoring reranked chunks placed in the prompt context.
    CONTEXT_CHUNKS = 5

    def __init__(self, vector_path):
        # Initialize cmd.Cmd (stdin/stdout streams, completion key, command queue).
        super().__init__()

        self.reranker = TextCrossEncoder(model_name=RANK_MODEL)

        # A directory containing a *.db file is treated as Milvus, anything else as Qdrant.
        if self.is_milvus(vector_path):
            self.vectordb = milvus(vector_path)
        else:
            self.vectordb = qdrant(vector_path)

        # The LLM is reached through an OpenAI-compatible API on localhost.
        self.port = int(os.getenv("RAMALAMA_PORT", "8080"))
        wait_for_llama_server("localhost", self.port)

        # NOTE(review): placeholder API key; presumably ignored by the local server — confirm.
        self.llm = openai.OpenAI(api_key="your-api-key", base_url=f"http://localhost:{self.port}")
        self.chat_history = []

    def is_milvus(self, vector_path):
        """Return True if *vector_path* directly contains a regular file ending in ``.db``."""
        # NOTE(review): any *.db file qualifies here, but milvus() always opens "milvus.db".
        return any(f.suffix == ".db" and f.is_file() for f in Path(vector_path).iterdir())

    def do_EOF(self, user_content):
        """End the session on EOF, moving the terminal to a fresh line."""
        print("")
        return True

    def emptyline(self):
        """Ignore blank input.

        cmd.Cmd's default behavior re-executes the previous command, which here
        would silently resend the last question to the LLM.
        """
        return False

    def onecmd(self, line):
        """Dispatch one input line, treating all free text as a question.

        cmd.Cmd would otherwise read the first word as a command name, so inputs
        such as ``help me ...`` or ``?...`` would be routed to its built-in help
        instead of the LLM. Returns True when the command loop should stop.
        """
        line = line.strip()
        if not line:
            return self.emptyline()
        if line == "EOF":
            return self.do_EOF(line)
        return self.default(line)

    def query(self, prompt):
        """Answer *prompt* from retrieved context and chat history, streaming the reply to stdout."""
        self._append_history("user", prompt)

        # Retrieve candidates and keep the best-scoring chunks as context.
        reranked_context = self._build_context(prompt)

        # Prepare the metaprompt with chat history and context
        metaprompt = f"""
            You are an expert software architect.
            Use the provided context and chat history to answer the question accurately and concisely.
            If the answer is not explicitly stated, infer the most reasonable answer based on the available information.
            If there is no relevant information, respond with "I don't know"—do not fabricate details.

            ### Chat History:
            {self.format_chat_history()}

            ### Context:
            {reranked_context.strip()}

            ### Question:
            {prompt.strip()}

            ### Answer:
            """

        # Query the LLM with the metaprompt and stream the answer as it arrives.
        response = self.llm.chat.completions.create(
            model="your-model-name", messages=[{"role": "user", "content": metaprompt}], stream=True
        )

        parts = []
        for chunk in response:
            # Some OpenAI-compatible servers may emit chunks without choices
            # (e.g. a trailing usage chunk); indexing [0] would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                parts.append(content)
                print(content, end="", flush=True)

        self._append_history("assistant", "".join(parts))
        print(" ")

    def _build_context(self, prompt):
        """Return the top CONTEXT_CHUNKS retrieved chunks for *prompt*, joined by spaces."""
        documents = self.vectordb.query(prompt)
        if not documents:
            # Nothing retrieved: skip the reranker rather than scoring an empty batch.
            return ""
        scores = self.reranker.rerank(prompt, documents)
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        return " ".join(str(documents[i]) for i, _ in ranked[: self.CONTEXT_CHUNKS])

    def _append_history(self, role, content):
        """Append a message and trim the history to at most MAX_HISTORY messages.

        Trimming removes the oldest messages and then any leading assistant
        reply, so the history always starts with a user question and whole
        turns are dropped together.
        """
        self.chat_history.append({"role": role, "content": content})
        excess = len(self.chat_history) - self.MAX_HISTORY
        if excess > 0:
            del self.chat_history[:excess]
        while self.chat_history and self.chat_history[0]["role"] != "user":
            del self.chat_history[0]

    def format_chat_history(self):
        """Format the chat history as ``User: ...`` / ``AI: ...`` lines for the metaprompt.

        Labels come from each message's role rather than its position, so the
        transcript stays correct even if messages are not perfectly paired. If
        the last message is from the user, a trailing empty ``AI: `` line marks
        the pending answer.
        """
        labels = {"user": "User", "assistant": "AI"}
        lines = [f"{labels.get(m['role'], m['role'])}: {m['content']}" for m in self.chat_history]
        if self.chat_history and self.chat_history[-1]["role"] == "user":
            lines.append("AI: ")
        return "\n".join(lines)

    def default(self, user_content):
        """Handle one line of user input: ``/bye`` ends the session, anything else is a question."""
        if user_content == "/bye":
            return True

        self.query(user_content)


def run_rag(vector_path):
    """Start the interactive RAG shell over the vector database at *vector_path*.

    Ctrl-C ends the session quietly, printing a newline so the terminal prompt
    starts on a fresh line. This also applies while the models are loading and
    while the LLM server is being awaited, since construction happens inside
    the ``try``.
    """
    try:
        Rag(vector_path).cmdloop()
    except KeyboardInterrupt:
        print("")


def load():
    """Instantiate the dense, sparse and rerank models once, for their side effects.

    Backs the ``load`` subcommand ("Preload RAG Embedding Models"). An
    in-memory Qdrant client is enough to trigger the dense/sparse model setup.
    Nothing is returned.
    """
    # NOTE(review): presumably this fills the local model cache so later runs
    # need no download — confirm against fastembed's caching behavior.
    warmup = qdrant_client.QdrantClient(":memory:")
    for configure, model_name in (
        (warmup.set_model, EMBED_MODEL),
        (warmup.set_sparse_model, SPARSE_MODEL),
    ):
        configure(model_name)
    TextCrossEncoder(model_name=RANK_MODEL)


# Command-line interface: `run <vector_path>` starts the chat shell,
# `load` pre-fetches the embedding and rerank models.
parser = argparse.ArgumentParser(description="A script to enable Rag")
subparsers = parser.add_subparsers(dest='command')

run_parser = subparsers.add_parser('run', help='Run RAG interactively')
run_parser.add_argument("vector_path", type=str, help="Path to the vector database")
run_parser.set_defaults(func=run_rag)

load_parser = subparsers.add_parser('load', help='Preload RAG Embedding Models')
load_parser.set_defaults(func=load)

try:
    args = parser.parse_args()

    if args.command == 'run':
        args.func(args.vector_path)
    elif args.command:
        # no argument for 'load'
        args.func()
    else:
        # No subcommand given: show usage instead of silently doing nothing.
        parser.print_help(sys.stderr)
        sys.exit(2)
except (ValueError, OSError) as e:
    # ValueError: e.g. a non-integer RAMALAMA_PORT.
    # OSError: e.g. a vector path that does not exist or is not a directory.
    eprint(e, 1)
