#!/usr/bin/python3

import argparse
import atexit
import os
import signal
import subprocess
import sys
from argparse import Namespace


def initialize_parser():
    """Build the command-line parser for this script.

    Options defined: ``-c/--ctx-size`` (stored as ``context``, default 2048),
    ``--jinja``, ``--temp`` (default 0.8), ``--ngl``, ``-t/--threads`` and
    ``-v``, followed by the positional ``MODEL`` and zero or more ``ARGS``.

    ``--ctx-size`` and ``--temp`` declare no ``type``, so a value given on the
    command line stays a string while the defaults are an int and a float
    respectively; callers must not assume a single type for them.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Run ramalama run core")
    parser.add_argument(
        "-c",
        "--ctx-size",
        dest="context",
        default=2048,
        help="size of the prompt context (0 = loaded from model)",
    )
    parser.add_argument("--jinja", action="store_true", help="enable jinja")
    parser.add_argument("--temp", default=0.8, help="temperature of the response from the AI model")
    parser.add_argument(
        "--ngl",
        dest="ngl",
        type=int,
        help="number of layers to offload to the gpu, if available",
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        help="number of cpu threads to use",
    )
    parser.add_argument(
        "-v",
        help="verbose",
        action="store_true",
    )
    parser.add_argument("MODEL", type=str, help="Path to the model")  # positional argument
    parser.add_argument(
        "ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
    )

    return parser


def initialize_args():
    """Parse the process command line and obtain a port for the server.

    Returns:
        A ``(parsed_args, port)`` tuple: the namespace produced by the parser
        from ``initialize_parser()`` and the value returned by
        ``get_available_port_if_any(False)``.
    """
    # Imported lazily, as in the rest of this script, so the ramalama package
    # is only needed once the script actually runs.
    from ramalama.model import get_available_port_if_any

    cli_options = initialize_parser().parse_args()
    # NOTE(review): the meaning of the ``False`` argument is defined in
    # ramalama.model and is not visible from this file.
    server_port = get_available_port_if_any(False)

    return cli_options, server_port


def main(args):
    """Fork a local model server and hand off to the chat client.

    The child process builds a ``serve`` namespace and calls ``serve_cli``.
    The parent builds a ``ramalama-client-core`` command line that points at
    the child's port, passes the child's pid via ``--kill-server``, and runs
    it through ``exec_cmd``.

    Args:
        args: The process argv. Only ``args[0]`` is used (to derive the client
            executable path); the options themselves are parsed from
            ``sys.argv`` by ``initialize_args``.

    Returns:
        0 in both the child (after ``serve_cli`` returns) and the parent
        (after ``exec_cmd`` returns, if it does).
    """
    sys.path.append('./')
    from ramalama.cli import serve_cli
    from ramalama.common import exec_cmd

    parsed_args, port = initialize_args()

    pid = os.fork()
    if pid == 0:
        # Child: block SIGINT and silence stderr before starting the server.
        # Presumably this keeps Ctrl-C for the client, which is given this
        # process's pid through --kill-server.
        signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
        with open(os.devnull, 'w') as devnull:
            os.dup2(devnull.fileno(), sys.stderr.fileno())

        # Honour --ngl / --threads when supplied; otherwise keep the values
        # that were previously hard-coded here.
        ngl = -1 if parsed_args.ngl is None else parsed_args.ngl
        threads = 6 if parsed_args.threads is None else parsed_args.threads

        # NOTE(review): --jinja and -v are parsed but not forwarded anywhere.
        serve_args = Namespace(
            container=False, dryrun=False, engine=None, podman_keep_groups=False,
            image='quay.io/ramalama/ramalama', runtime='llama.cpp',
            store=os.path.expanduser("~/.local/share/ramalama"), use_model_store=False,
            quiet=False, debug=False, subcommand='serve', ngl=ngl, threads=threads,
            temp=parsed_args.temp, authfile=None, env=[], device=None, name=None,
            oci_runtime=None, privileged=False, pull='newer', seed=None,
            tlsverify=True, context=parsed_args.context, runtime_args=[], network=None,
            detach=False, host='127.0.0.1', webui='on', generate=None, port=str(port),
            rag=None, MODEL=parsed_args.MODEL, func=lambda: None,
            UNRESOLVED_MODEL=parsed_args.MODEL
        )
        serve_cli(serve_args)

        return 0

    # Parent: every argv element must be a string. The parser defaults for
    # context/temp are an int and a float, so convert them explicitly.
    client_args = build_args(args, "ramalama-client-core")
    client_args += [
        '-c', str(parsed_args.context),
        '--temp', str(parsed_args.temp),
        "--kill-server", str(pid),
        "http://127.0.0.1:" + str(port),
    ]
    client_args += parsed_args.ARGS
    exec_cmd(client_args)

    return 0

def build_args(args, new_command):
    """Return a fresh one-element argv naming the sibling executable.

    Takes the program path in ``args[0]`` and substitutes ``new_command`` for
    ``"ramalama-run-core"`` in its final path component only, so a directory
    that happens to contain that name is left untouched. All other elements
    of ``args`` are dropped, and ``args`` itself is never modified.

    Args:
        args: Sequence whose first element is this program's path.
        new_command: Executable name to substitute.

    Returns:
        A new list holding just the rewritten program path.

    Raises:
        IndexError: If ``args`` is empty.
    """
    directory, separator, program = args[0].rpartition(os.sep)
    return [directory + separator + program.replace("ramalama-run-core", new_command)]

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv))
