#!/usr/bin/python3

import argparse
import atexit
import os
import signal
import subprocess
import sys
from argparse import Namespace


def initialize_parser():
    """Build the argument parser for the ``ramalama-run-core`` command line.

    ``--ctx-size`` and ``--temp`` declare no ``type``, so values supplied on
    the command line are plain strings. Their defaults are strings as well so
    that the parsed attributes always have the same type and can be placed
    directly into an argument vector.

    Returns:
        argparse.ArgumentParser: parser exposing ``context``, ``debug``,
        ``jinja``, ``temp``, ``ngl``, ``threads``, ``v``, ``MODEL`` and ``ARGS``.
    """
    parser = argparse.ArgumentParser(description="Run ramalama run core")
    parser.add_argument(
        "-c",
        "--ctx-size",
        dest="context",
        default="2048",
        help="size of the prompt context (0 = loaded from model)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
    )
    parser.add_argument("--jinja", action="store_true", help="enable jinja")
    parser.add_argument("--temp", default="0.8", help="temperature of the response from the AI model")
    parser.add_argument(
        "--ngl",
        dest="ngl",
        type=int,
        help="number of layers to offload to the gpu, if available",
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        help="number of cpu threads to use",
    )
    parser.add_argument(
        "-v",
        help="verbose",
        action="store_true",
    )
    parser.add_argument("MODEL", type=str, help="Path to the model")  # positional argument
    parser.add_argument(
        "ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
    )

    return parser


def initialize_args():
    """Parse the process command line and obtain a port for the server.

    Returns:
        tuple: ``(parsed_args, port)`` where ``parsed_args`` is the namespace
        produced by the parser from ``initialize_parser()`` and ``port`` is
        whatever ``ramalama.model.get_available_port_if_any()`` returns.
    """
    # Resolved at call time rather than at module import time.
    from ramalama.model import get_available_port_if_any

    # Parse first so that argument errors surface before a port is requested.
    namespace = initialize_parser().parse_args()
    return namespace, get_available_port_if_any()


def _ensure_no_proxy(var_name, required_hosts):
    """Merge *required_hosts* into the comma-separated list in ``os.environ[var_name]``."""
    current = {item.strip() for item in os.environ.get(var_name, '').split(',') if item.strip()}
    os.environ[var_name] = ",".join(sorted(current | set(required_hosts)))


def main(args):
    """Start a ``ramalama serve`` child process and exec the client against it.

    The process forks: the child execs ``ramalama --nocontainer serve`` on the
    selected port, the parent execs ``ramalama-client-core`` pointed at
    ``http://127.0.0.1:<port>`` and passes it the child's pid via
    ``--kill-server``.

    Args:
        args: the process argument vector; only ``args[0]`` is used, to derive
            the path of the client executable. Options themselves are parsed
            from ``sys.argv`` by ``initialize_args()``.

    Returns:
        int: always 0 (only reached if ``exec_cmd`` returns).
    """
    # This logic ensures that for internal HTTP calls (to 127.0.0.1), the system's proxy settings are bypassed.
    no_proxy_essentials = {"localhost", "127.0.0.1"}
    for no_proxy_var in ('NO_PROXY', 'no_proxy'):
        _ensure_no_proxy(no_proxy_var, no_proxy_essentials)

    for proxy_var in ('http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY'):
        os.environ.pop(proxy_var, None)

    sys.path.append('./')
    # NOTE(review): serve_cli is not referenced below; the import is kept in
    # case loading ramalama.cli is relied upon -- confirm before removing.
    from ramalama.cli import serve_cli
    from ramalama.common import exec_cmd

    parsed_args, port = initialize_args()

    pid = os.fork()
    if pid == 0:
        # Child: becomes the server. SIGINT is blocked here so that only the
        # parent (client) process reacts to it.
        signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
        # at this point we are already in a container so, --nocontainer always makes sense
        dont_print_output = (not parsed_args.debug) and (not parsed_args.v)

        server_cmd = ["ramalama", "--nocontainer", "serve"]
        # Forward --threads only when it was given; str(None) would otherwise
        # be passed as the literal thread count "None".
        if parsed_args.threads is not None:
            server_cmd += ["--threads", str(parsed_args.threads)]
        # Every argv entry must be a string, so stringify values that may
        # still carry a non-string default.
        server_cmd += ["--temp", str(parsed_args.temp), "--port", str(port), parsed_args.MODEL]
        exec_cmd(server_cmd, stdout2null=dont_print_output, stderr2null=dont_print_output)

        return 0

    # Parent: becomes the client and is told which pid to kill on exit.
    client_args = build_args(args, "ramalama-client-core")
    client_args += [
        '-c',
        str(parsed_args.context),
        '--temp',
        str(parsed_args.temp),
        "--kill-server",
        str(pid),
        f"http://127.0.0.1:{port}",
    ]
    client_args += parsed_args.ARGS
    exec_cmd(client_args)

    return 0

def build_args(args, new_command):
    """Return a fresh argv holding only the program path, retargeted at *new_command*.

    ``"ramalama-run-core"`` is replaced by *new_command* in the final path
    component of ``args[0]`` only, so a directory that happens to contain that
    name is left untouched. All remaining entries of *args* are dropped and
    *args* itself is not modified.

    Args:
        args: argument vector whose first element is the program path.
        new_command: executable name to substitute for ``ramalama-run-core``.

    Returns:
        list: a new single-element list with the rewritten program path.

    Raises:
        IndexError: if *args* is empty.
    """
    directory, separator, program = args[0].rpartition(os.sep)
    return [directory + separator + program.replace("ramalama-run-core", new_command)]

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status
    # instead of discarding it.
    sys.exit(main(sys.argv))
