#!/usr/bin/env python3

from __future__ import annotations

import base64
import dataclasses
import fcntl
import hmac
import http
import http.server
import logging
import math
import os
import os.path
import shutil
import shlex
import signal
import subprocess
import sys
import traceback
import urllib.parse

import trio
import yaml


# Configure the root logger once at import time (timestamp, level, message at
# INFO and above) and create the logger used throughout this script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s: %(levelname)s: %(message)s',
)
_log = logging.getLogger('git-build-triggers')


class UnixFileLock:
    """Non-blocking exclusive advisory lock on a file, based on ``flock``.

    The lock file is created if missing (and truncated on every attempt).
    An instance holds at most one lock at a time; it is not re-entrant.
    """

    __slots__ = ('_path', '_fd')

    def __init__(self, path: str) -> None:
        """Remember *path*; nothing is opened or locked until ``acquire()``."""
        self._path = path
        # File descriptor of the locked file while the lock is held, else None.
        self._fd: int | None = None

    def acquire(self) -> bool:
        """Try to take the lock without blocking.

        Returns:
            True if the lock was obtained, False if it is currently held
            through another open file description.

        Raises:
            RuntimeError: this instance already holds the lock.
            OSError: opening the lock file failed, or ``flock`` failed for a
                reason other than the lock being busy.
        """
        if self._fd is not None:
            raise RuntimeError(f"UnixFileLock({self._path!r}) already locked; re-entry not allowed")

        fd = os.open(self._path, os.O_RDWR | os.O_CREAT | os.O_TRUNC)
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except (BlockingIOError, PermissionError):
            # EWOULDBLOCK/EAGAIN (or EACCES where flock is emulated via
            # fcntl locks) is how contention is reported: the lock is busy.
            os.close(fd)
            return False
        except BaseException:
            # Any other failure is a real error, not contention: do not leak
            # the descriptor and let the caller see what went wrong.
            os.close(fd)
            raise
        self._fd = fd
        return True

    def release(self) -> None:
        """Drop the lock if held; calling it without holding the lock is a no-op."""
        fd, self._fd = self._fd, None
        if fd is None:
            return
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            # Always close the descriptor, even if the explicit unlock failed.
            os.close(fd)


@dataclasses.dataclass(slots=True, kw_only=True)
class Job:
    """A queued build request: a repository together with the lock taken for it."""

    # Repository whose ``update()`` is awaited by the job queue worker.
    repository: Repository
    # Lock acquired by the requester before queueing; released by the worker.
    lock: UnixFileLock


@dataclasses.dataclass(slots=True)
class JobQueue:
    """Unbounded queue of build jobs, processed by a limited number of workers.

    Jobs are pushed from a non-trio thread via ``queue()`` and consumed by
    ``run()``, which must be running inside the trio event loop.
    """

    # Maximum number of builds executing at the same time.
    parallel: int
    # Sending and receiving halves of the in-memory job channel.
    _queue: trio.MemorySendChannel[Job]
    _rx: trio.MemoryReceiveChannel[Job]

    def __init__(self, *, parallel: int = 1) -> None:
        """Create an empty queue allowing *parallel* concurrent builds."""
        self.parallel = parallel
        self._queue, self._rx = trio.open_memory_channel(math.inf)

    async def run(self) -> None:
        """Process jobs until the queue is stopped and all started jobs finished.

        Each received job gets its own task; a capacity limiter keeps at most
        ``parallel`` of them building at once.  After a build, if the
        repository's rebuild flag file exists it is removed and the job is
        queued again.
        """
        limit = trio.CapacityLimiter(self.parallel)

        async def work(job: Job) -> None:
            # Whether ownership of the lock was handed on to a re-queued job.
            lock_handed_over = False
            try:
                async with limit:
                    await job.repository.update()
                    try:
                        os.remove(job.repository._path_rebuild)
                    except FileNotFoundError:
                        return
                    # build again
                    await self._queue.send(job)
                    # The re-queued job reuses this lock.  Releasing it here
                    # would let the rebuild run unlocked and allow a second,
                    # concurrent build of the same repository to be queued.
                    lock_handed_over = True
            finally:
                if not lock_handed_over:
                    job.lock.release()

        async with trio.open_nursery() as nursery:
            job: Job
            async for job in self._rx:
                nursery.start_soon(work, job)

    def queue(self, job: Job) -> None:
        """Enqueue *job*; must be called from a thread trio can re-enter from."""
        trio.from_thread.run(self._queue.send, job)

    def stop(self) -> None:
        """Close the sending side so ``run()`` ends once the queue is drained."""
        self._queue.close()


# Process-wide job queue; declared here and assigned in main().
JOB_QUEUE: JobQueue


@dataclasses.dataclass(slots=True, kw_only=True)
class Repository:
    """A git working copy that is reset to its upstream and rebuilt on request.

    Bookkeeping files (lock, last build status, rebuild flag) live inside the
    working copy's ``.git`` directory.
    """

    # Name under which the repository is addressed.
    name: str
    # Path of the git working copy; all commands run with this as cwd.
    workdir: str
    # Per-repository secret (kept out of repr()).
    token: str = dataclasses.field(repr=False)
    # Build command line; split with shlex and executed without a shell.
    command: str
    # Owning configuration (provides the git binary path).
    config: Config = dataclasses.field(repr=False)
    # Derived paths, filled in by __post_init__.
    _path_gitdir: str = dataclasses.field(init=False)
    _path_lockfile: str = dataclasses.field(init=False)
    _path_lastbuild: str = dataclasses.field(init=False)
    _path_rebuild: str = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        """Derive the bookkeeping file paths from ``workdir``."""
        self._path_gitdir = gitdir = os.path.join(self.workdir, ".git")
        self._path_lockfile = os.path.join(gitdir, "build.lock")
        self._path_lastbuild = os.path.join(gitdir, "build.status")
        self._path_rebuild = os.path.join(gitdir, "rebuild_flag")

    def _writestatus(self, commit: str, message: str|bytes) -> None:
        """Store *commit* (first line) and *message* as the last build status.

        The file is written to a temporary name and renamed into place.
        Failures are logged, not raised.
        """
        if isinstance(message, str):
            message = message.encode()
        tmpname = self._path_lastbuild + ".tmp"
        try:
            with open(tmpname, "wb") as lastbuild:
                lastbuild.write(commit.encode() + b"\n" + message)
            os.rename(tmpname, self._path_lastbuild)
        except OSError as e:
            _log.error(f"{self.name}: Failed to update {self._path_lastbuild}: {e}")

    async def _run_git(self, cmd: list[str]) -> bytes:
        """Run git with arguments *cmd* in the working copy and return its stdout."""
        result: subprocess.CompletedProcess[bytes] = await trio.run_process(
            [self.config.git_path] + cmd,
            cwd=self.workdir,
            capture_stdout=True,
        )
        return result.stdout

    async def _update(self, last_commit: str) -> None:
        """Sync the working copy with ``origin`` and build if HEAD changed.

        The build is skipped when the resulting commit equals *last_commit*.
        Otherwise the build command runs (a non-zero exit is recorded, not
        raised) and its combined output is stored as the new status.
        """
        _log.info(f"{self.name}: Updating git")
        # remove all ignored/untracked files and directories:
        await self._run_git(["clean", "--force", "-d", "-x"])
        await self._run_git(["remote", "update", "origin"])
        branch_name = (await self._run_git(["rev-parse", "--abbrev-ref", "HEAD"])).decode().strip()
        await self._run_git(["reset", "--hard", f"origin/{branch_name}"])
        commit_id = (await self._run_git(["rev-parse", "HEAD"])).decode().strip()
        if last_commit == commit_id:
            _log.info(f"{self.name}: No changes (still {last_commit})")
            return  # no changes

        build: subprocess.CompletedProcess[bytes] = await trio.run_process(
            shlex.split(self.command),
            cwd=self.workdir,
            capture_stdout=True,
            stderr=subprocess.STDOUT,
            check=False,
        )
        self._writestatus(commit_id, f"Exit code: {build.returncode}\n".encode() + build.stdout)
        # again: remove all ignored/untracked files and directories:
        await self._run_git(["clean", "--force", "-d", "-x"])
        _log.info(f"{self.name}: Built {commit_id} with exit status {build.returncode}")

    async def update(self) -> None:
        """Update and build the repository; should only be called while holding lock.

        Errors are logged and recorded in the status file (under the previous
        commit id) instead of being raised.
        """
        try:
            with open(self._path_lastbuild, "rb") as lastbuild:
                last_commit = lastbuild.readline().strip().decode()
        except FileNotFoundError:
            last_commit = "none"
        try:
            await self._update(last_commit)
        except Exception as e:
            _log.error(f"{self.name}: Failed {last_commit}: {e}")
            self._writestatus(last_commit, str(e))

    def _enqueue(self, lock: UnixFileLock) -> None:
        """Queue a build owning *lock*; release the lock if queueing fails."""
        try:
            JOB_QUEUE.queue(Job(repository=self, lock=lock))
        except BaseException:
            # Nobody will ever run (and unlock) this job: don't stay locked.
            lock.release()
            raise

    def check(self) -> tuple[int, bytes | str]:
        """Trigger a build and report the last recorded build status.

        If no build holds the lock a job is queued; otherwise the rebuild
        flag is written so the running job builds again when it finishes.

        Returns:
            ``(http_status, body)``: 500 if the working copy has no ``.git``
            directory, else 200 with the raw status file content (bytes) or
            a placeholder text if nothing was built yet.
        """
        if not os.path.isdir(self._path_gitdir):
            return (500, "Missing .git directory")

        lock = UnixFileLock(self._path_lockfile)
        if lock.acquire():
            self._enqueue(lock)
        else:
            # tell current job to restart when finished:
            with open(self._path_rebuild, "w"):
                pass

            # if current job finished before seeing our trigger,
            # remove the trigger and just run it ourself
            if lock.acquire():
                try:
                    os.remove(self._path_rebuild)
                except FileNotFoundError:
                    # something saw the trigger and handled it
                    # - no need to build again, but we must not keep the
                    # lock: no job owns it, so it would never be released.
                    lock.release()
                except BaseException:
                    lock.release()
                    raise
                else:
                    self._enqueue(lock)

        try:
            # The status file holds raw build output that need not be valid
            # text in any encoding; return it unmodified as bytes.
            with open(self._path_lastbuild, "rb") as f:
                return (200, f.read())
        except FileNotFoundError:
            return (200, "never built yet")


@dataclasses.dataclass(slots=True, kw_only=True)
class Config:
    repositories: dict[str, Repository]
    address: str
    port: int
    basepath: str
    admin_token: str
    git_path: str
    parallel_jobs: int

    def handle(self, req_path: str, auth: str) -> tuple[int, bytes | str]:
        url = urllib.parse.urlparse(req_path)
        if auth.startswith("Basic "):
            creds = base64.decodebytes(auth.removeprefix("Basic ").strip().encode())
            token = creds.split(b":", maxsplit=1)[1].decode()
        elif auth.startswith("Bearer "):
            token = auth.removeprefix("Bearer ").strip()
        else:
            return (401, "Missing authentication")
        name = url.path.removeprefix(self.basepath).strip("/")
        if not name in self.repositories:
            return (404, "Not found")
        repo = self.repositories[name]
        if (
            not hmac.compare_digest(token, repo.token)
            and not (self.admin_token and hmac.compare_digest(token, self.admin_token))
        ):
            return (401, "Invalid token")
        return repo.check()


def load_config(path: str) -> Config:
    """Load and validate the YAML configuration file at *path*.

    Recognised top-level keys: ``repositories`` (required mapping of name to
    ``workdir``/``token``/``command``), ``port`` (required), ``address``
    (default "127.0.0.1"), ``parallel-jobs`` (default 1), ``base-path``
    (default "/") and ``admin-token`` (default: none).

    Raises:
        KeyError: a required key is missing.
        ValueError: a value has the wrong type or is not acceptable.
        RuntimeError: no ``git`` executable was found on the search path.
    """

    def require(condition: object, message: str) -> None:
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O``, which would silently disable e.g. the minimum
        # token length.
        if not condition:
            raise ValueError(f"{path}: {message}")

    with open(path) as f:
        data = yaml.safe_load(f)

    require(isinstance(data, dict), "top level must be a mapping")
    data_repositories = data.pop('repositories')
    require(isinstance(data_repositories, dict), "'repositories' must be a mapping")

    address = data.pop("address", "127.0.0.1")
    require(isinstance(address, str), "'address' must be a string")

    port = data.pop("port")
    require(isinstance(port, int), "'port' must be an integer")

    parallel_jobs = data.pop("parallel-jobs", 1)
    require(isinstance(parallel_jobs, int), "'parallel-jobs' must be an integer")

    basepath = data.pop("base-path", "/")
    require(isinstance(basepath, str), "'base-path' must be a string")
    require(not basepath or basepath.startswith("/"), "'base-path' must start with '/'")

    # Any empty value (missing key, null, "") means "no admin token".
    admin_token = data.pop("admin-token", "") or ""
    require(isinstance(admin_token, str), "'admin-token' must be a string")
    require(not admin_token or len(admin_token) >= 16, "'admin-token' must have at least 16 characters")

    git_path = shutil.which("git")
    if not git_path:
        raise RuntimeError("Missing git binary")

    config = Config(
        repositories={},
        address=address,
        port=port,
        basepath=basepath,
        git_path=git_path,
        admin_token=admin_token,
        parallel_jobs=parallel_jobs,
    )
    for repo_name, repo_data in data_repositories.items():
        require(isinstance(repo_data, dict), f"repository {repo_name!r} must be a mapping")
        workdir = repo_data.pop("workdir")
        require(isinstance(workdir, str), f"repository {repo_name!r}: 'workdir' must be a string")
        token = repo_data.pop("token")
        require(
            isinstance(token, str) and len(token) >= 16,
            f"repository {repo_name!r}: 'token' must be a string of at least 16 characters",
        )
        command = repo_data.pop("command")
        require(
            isinstance(command, str) and command,
            f"repository {repo_name!r}: 'command' must be a non-empty string",
        )
        config.repositories[repo_name] = Repository(
            name=repo_name,
            config=config,
            workdir=workdir,
            token=token,
            command=command,
        )

    return config


# Process-wide configuration; declared here and assigned in main().
CONFIG: Config


class RequestHandler(http.server.BaseHTTPRequestHandler):
    """HTTP front end: hands every POST/GET to ``CONFIG.handle``.

    The reply is always plain text.  An exception raised while handling is
    printed with its traceback and answered with status 500 and the
    exception text as body.
    """

    server_version = "BuildTrigger"

    def do_POST(self) -> None:
        """Handle one request and send the resulting status and body."""
        authorization = self.headers.get("Authorization", "")
        try:
            status, body = CONFIG.handle(self.path, authorization)
        except Exception as exc:
            status, body = 500, str(exc)
            traceback.print_exception(exc)

        # Text bodies are encoded; anything else has to be bytes already.
        payload = body.encode() if isinstance(body, str) else body
        assert isinstance(payload, bytes)

        self.send_response(status)
        for header, value in (
            ("Cache-Control", "no-store"),
            ("Content-Type", "text/plain; charset=utf-8"),
            ("Content-Length", str(len(payload))),
        ):
            self.send_header(header, value)
        if status == 401:
            # Ask the client for credentials.
            self.send_header("WWW-Authenticate", "Basic realm=\"trigger\"")
        self.end_headers()
        self.wfile.write(payload)

    # GET requests are treated exactly like POST requests.
    do_GET = do_POST


def run():
    """Serve HTTP requests and process build jobs until SIGINT is received.

    The blocking HTTP server runs in a worker thread while the job queue
    runs in the trio event loop.  On SIGINT the server is shut down first,
    then the job queue is stopped; the function returns once already queued
    jobs have finished.
    """
    server = http.server.HTTPServer((CONFIG.address, CONFIG.port), RequestHandler)

    async def shutdown_on_sigint() -> None:
        # Receive the signal as an event inside the loop rather than in a
        # plain signal handler: server.shutdown() blocks until the serving
        # thread has left its loop, and that thread may itself be waiting
        # for this event loop (trio.from_thread.run in JobQueue.queue).
        # Blocking the loop's thread in a handler could therefore deadlock.
        with trio.open_signal_receiver(signal.SIGINT) as signals:
            async for _signum in signals:
                _log.info("Shutdown")
                await trio.to_thread.run_sync(server.shutdown)
                JOB_QUEUE.stop()
                return

    async def go() -> None:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(shutdown_on_sigint)
            nursery.start_soon(trio.to_thread.run_sync, server.serve_forever)
            nursery.start_soon(JOB_QUEUE.run)

    try:
        trio.run(go)
    finally:
        # Release the listening socket.
        server.server_close()


def main():
    """Command line entry point.

    Parses ``--config``, loads the configuration, creates the global job
    queue and runs the server.
    """
    global CONFIG, JOB_QUEUE
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--config', required=True, help="Path to YAML config file")
    config_path = cli.parse_args().config

    # Both globals must be set before run(), which reads them.
    CONFIG = load_config(config_path)
    JOB_QUEUE = JobQueue(parallel=CONFIG.parallel_jobs)

    run()


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    main()