#!/usr/bin/python3

import argparse
import cmd
import itertools
import json
import os
import signal
import sys
import time
import urllib.error
import urllib.request


def should_colorize():
    """Report whether ANSI colour codes should be written to stdout.

    Returns:
        bool: True only when the ``TERM`` environment variable is set to a
        non-empty value other than ``"dumb"`` and ``sys.stdout`` is a TTY.
    """
    term = os.getenv("TERM")
    # bool() keeps the result a real boolean; a bare and-chain would hand
    # back None or "" when TERM is unset or empty.
    return bool(term) and term != "dumb" and sys.stdout.isatty()


def res(response, color):
    """Echo a streamed chat-completion reply to stdout and return its text.

    Args:
        response: Iterable of byte strings, one per line of the stream.
            Only lines that start with ``data: {`` are parsed (as JSON);
            every other line (blank lines, ``data: [DONE]``, ...) is ignored.
        color: ``"always"`` wraps each printed fragment in yellow ANSI codes,
            ``"auto"`` does so only when ``should_colorize()`` is true, and
            any other value prints plain text.

    Returns:
        str: All non-empty ``content`` fragments concatenated in arrival
        order. Colour codes are never part of the returned text.

    Raises:
        json.JSONDecodeError: If a ``data: {`` line is not valid JSON.
    """
    use_color = color == "always" or (color == "auto" and should_colorize())
    color_yellow, color_default = ("\033[33m", "\033[0m") if use_color else ("", "")

    # Return to column 0 so the reply overwrites anything already printed on
    # the current terminal line.
    print("\r", end="")

    fragments = []
    for raw_line in response:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: {"):
            continue

        event = json.loads(line[len("data: "):])
        choices = event.get("choices")
        if not choices:
            # Some OpenAI-compatible servers emit chunks whose "choices" list
            # is empty (for example a trailing usage report); there is
            # nothing to print for those.
            continue

        # A choice may carry no "delta", and a delta may carry no "content"
        # or a null one (e.g. the role-only first chunk).
        fragment = (choices[0].get("delta") or {}).get("content")
        if fragment:
            print(f"{color_yellow}{fragment}{color_default}", end="", flush=True)
            fragments.append(fragment)

    print("")
    return "".join(fragments)


class RamaLamaShell(cmd.Cmd):
    """Interactive chat prompt backed by an OpenAI-style HTTP endpoint.

    Each input line that ``cmd.Cmd`` does not dispatch to a ``do_*`` method
    is sent, together with the conversation so far, to
    ``<host>/v1/chat/completions``; the streamed reply is echoed and stored.

    NOTE: dispatch is done by ``cmd.Cmd``, so a line whose first word is
    ``help`` (or that starts with ``?``) is handled by the inherited help
    command instead of being sent to the model.
    """

    def __init__(self, parsed_args):
        """Initialise prompt text, endpoint URLs and empty chat state.

        Args:
            parsed_args: Namespace-like object. This class reads its
                ``host``, ``prefix``, ``color``, ``pid2kill`` and ``ARGS``
                attributes.
        """
        super().__init__()
        self.conversation_history = []
        self.parsed_args = parsed_args
        # True while a request is being sent / streamed; consulted by loop()
        # to decide what to print on Ctrl-C.
        self.request_in_process = False
        # The environment variable takes precedence over --prefix.
        if "LLAMA_PROMPT_PREFIX" in os.environ:
            self.prompt = os.environ["LLAMA_PROMPT_PREFIX"]
        else:
            self.prompt = parsed_args.prefix

        self.url = f"{parsed_args.host}/v1/chat/completions"
        self.models_url = f"{parsed_args.host}/v1/models"
        # Cache of model ids, filled lazily by model().
        self.models = []

    def model(self, index=0):
        """Return the id of the served model at *index*, or ``""`` if unknown.

        The model list is fetched on first use and cached in ``self.models``;
        while the cache is empty every call retries the fetch.

        Args:
            index: Position in the server's model list.

        Returns:
            str: The model id, or ``""`` when the list cannot be fetched or
            parsed, or has no entry at *index*.
        """
        try:
            if not self.models:
                self.models = self.get_models()
            return self.models[index]
        except (OSError, ValueError, LookupError, TypeError):
            # OSError covers urllib.error.URLError and raw socket failures;
            # the others cover an undecodable or unexpectedly shaped payload
            # and an out-of-range index. The lookup is best-effort.
            return ""

    def get_models(self):
        """Fetch the ids of the models listed at ``self.models_url``.

        Returns:
            list: The ``id`` of every entry in the payload's ``data`` array,
            or an empty list when the response body is empty.

        Raises:
            urllib.error.URLError: If the request fails.
            ValueError: If the body is not valid UTF-8 encoded JSON.
            KeyError: If the payload lacks ``data`` or an entry lacks ``id``.
        """
        request = urllib.request.Request(self.models_url, method="GET")
        # Read the whole body (the JSON may span several lines) and make sure
        # the connection is closed afterwards.
        with urllib.request.urlopen(request) as response:
            body = response.read().decode("utf-8").strip()

        if not body:
            return []

        return [entry["id"] for entry in json.loads(body)["data"]]

    def do_EOF(self, user_content):
        """Handle end-of-file (Ctrl-D): print a newline and stop the loop."""
        print("")
        return True

    def emptyline(self):
        """Ignore empty input lines.

        Without this override ``cmd.Cmd`` repeats the last non-empty command,
        which would resend the previous prompt to the model.
        """
        return None

    def default(self, user_content):
        """Send one user message to the model and record the reply.

        Args:
            user_content: The line typed by the user.

        Returns:
            True to stop the command loop (on ``/bye``, ``exit`` or when no
            reply was obtained); otherwise None.
        """
        if user_content in ("/bye", "exit"):
            return True

        self.conversation_history.append({"role": "user", "content": user_content})
        self.request_in_process = True
        response = self._req()
        # NOTE(review): an empty reply ("") ends the session just like a
        # failed connection (None) does -- confirm this is intended.
        if not response:
            return True

        self.conversation_history.append({"role": "assistant", "content": response})
        self.request_in_process = False
        return None

    def _req(self):
        """POST the conversation and stream back the assistant's reply.

        Connection attempts are retried with a doubling delay that starts at
        0.01 s; the loop gives up once the delay has grown beyond 32 s
        (about 41 s of waiting in total).

        Returns:
            The reply text produced by ``res()``, or None when no connection
            could be made. In the latter case an error is printed to stderr
            and ``kills()`` is called.
        """
        # The model name is resolved once, before the retry loop starts.
        payload = {
            "stream": True,
            "messages": self.conversation_history,
            "model": self.model(),
        }
        body = json.dumps(payload).encode("utf-8")
        headers = {
            "Content-Type": "application/json",
        }
        request = urllib.request.Request(self.url, data=body, headers=headers, method="POST")

        delay = 0.01
        response = None
        for frame in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
            try:
                response = urllib.request.urlopen(request)
                break
            except Exception:
                # Deliberately broad: any failure is treated as "server not
                # ready yet" and retried until the back-off budget runs out.
                if sys.stdout.isatty():
                    print(f"\r{frame}", end="", flush=True)

                if delay > 32:
                    break

                time.sleep(delay)
                delay *= 2

        if response is not None:
            # Close the connection once the stream has been consumed, even if
            # streaming is interrupted.
            with response:
                return res(response, self.parsed_args.color)

        print(f"\rError: could not connect to: {self.url}", file=sys.stderr)
        self.kills()

        return None

    def kills(self):
        """Signal the process given by ``--kill-server``, if any.

        SIGINT, SIGTERM and SIGKILL are sent in that order. Nothing happens
        when ``pid2kill`` is unset, and a target that has already exited is
        ignored so that calling this method more than once is harmless.
        """
        pid = self.parsed_args.pid2kill
        if not pid:
            return

        for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGKILL):
            try:
                os.kill(pid, sig)
            except ProcessLookupError:
                # The process is already gone; nothing left to signal.
                return

    def loop(self):
        """Run the interactive prompt until the user quits.

        Ctrl-C does not end the session: the command loop is restarted, and
        a hint on how to quit is printed unless a request was in progress
        when the interrupt arrived.
        """
        while True:
            self.request_in_process = False
            try:
                self.cmdloop()
            except KeyboardInterrupt:
                print("")
                if not self.request_in_process:
                    print("Use Ctrl + d or /bye or exit to quit.")

                continue

            break

    def handle_args(self):
        """Run a single prompt taken from the command line, if one was given.

        Returns:
            bool: True when ``ARGS`` was non-empty (the words are joined with
            spaces, sent as one message, and ``kills()`` is called
            afterwards); False when there was nothing to do.
        """
        if self.parsed_args.ARGS:
            self.default(" ".join(self.parsed_args.ARGS))
            self.kills()
            return True

        return False


def parse_arguments(args):
    """Parse the client's command-line arguments.

    Args:
        args: List of argument strings, without the program name.

    Returns:
        argparse.Namespace: With attributes ``color``, ``prefix``, ``host``,
        ``context`` (int), ``jinja`` (bool), ``pid2kill`` (int or None),
        ``temp`` (float) and ``ARGS`` (list of str).

    Note:
        ``host`` is an optional positional declared before ``ARGS``, so the
        first bare positional argument is always taken as the host.
    """
    parser = argparse.ArgumentParser(description="Run ramalama client core")
    parser.add_argument(
        '--color',
        '--colour',
        default="auto",
        choices=['never', 'always', 'auto'],
        help='possible values are "never", "always" and "auto".',
    )
    parser.add_argument("--prefix", type=str, default="> ", help="prefix for the user prompt")
    parser.add_argument(
        "host", type=str, nargs="?", default="http://127.0.0.1:8080", help="the host to send requests to"
    )
    parser.add_argument(
        "-c",
        "--ctx-size",
        dest="context",
        # Convert explicitly so a user-supplied value has the same type as
        # the default instead of arriving as a string.
        type=int,
        default=2048,
        help="size of the prompt context (0 = loaded from model)",
    )
    parser.add_argument("--jinja", action="store_true", help="enable jinja")
    parser.add_argument("--kill-server", dest="pid2kill", type=int, help="server process to kill on termination of client")
    parser.add_argument(
        "--temp",
        # Same reasoning as --ctx-size: keep the value numeric.
        type=float,
        default=0.8,
        help="temperature of the response from the AI model",
    )
    parser.add_argument(
        "ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
    )

    return parser.parse_args(args)


def main(args):
    """Run the chat client.

    When the command line carries a prompt it is sent once and the function
    returns; otherwise the interactive shell runs until the user quits, after
    which ``kills()`` is called on the shell.

    Args:
        args: List of command-line argument strings, without the program name.

    Returns:
        int: Always 0, on both the one-shot and the interactive path.
    """
    # NOTE(review): nothing in this function imports anything after this
    # line; it is kept because code outside this view may depend on it.
    sys.path.append('./')

    parsed_args = parse_arguments(args)
    ramalama_shell = RamaLamaShell(parsed_args)
    if ramalama_shell.handle_args():
        return 0

    ramalama_shell.loop()
    ramalama_shell.kills()
    return 0


if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status
    # instead of discarding it (None counts as success).
    sys.exit(main(sys.argv[1:]))
