#!/usr/bin/env python3

"""OCI Registry Helper

Usage:
    ocirh <subcommand> [<args>...]

Subcommands:
    repos           Lists repositories in the registry. Repos correspond to images
                    pushed to the registry.
    tags            Lists tags of the given repository.
    manifests       Lists manifests of the given repository for the given tag.
    rmi             Removes a tag from an image. If given tag is the only tag,
                    removes the image.
    gc              Runs garbage collection on the registry. Requires SSH public key
                    access to registry server.
    rmr             Removes given repository from the registry. Requires SSH public
                    key access to registry server.

Examples:
    Suppose we have an image called 'fedora-toolbox' tagged with 'latest'.

    ocirh repos
    ocirh tags fedora-toolbox
    ocirh manifests fedora-toolbox latest
    ocirh rmi fedora-toolbox latest
    ocirh gc
    ocirh rmr fedora-toolbox
"""
import http.client
import json
import logging
import math
import os
import shlex
import subprocess

from docopt import docopt
from rich import print
from rich.console import Group
from rich.logging import RichHandler
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from rich.traceback import install
from rich.tree import Tree

# Render uncaught exceptions with Rich, including each frame's local variables.
install(show_locals=True)

# Route log records through Rich's handler. The root level is NOTSET, so no
# record is filtered out by level at the root logger.
FORMAT = "%(message)s"
_rich_handler = RichHandler(rich_tracebacks=True)
logging.basicConfig(
    handlers=[_rich_handler],
    format=FORMAT,
    datefmt="[%X]",
    level="NOTSET",
)
log = logging.getLogger("rich")


# Taken from https://stackoverflow.com/a/14822210
#
# How this function works:
#   If size_bytes == 0, returns 0 B.
#   size_name is a tuple containing binary prefixes for bytes.
#
#   math.log takes the logarithm of size_bytes to base 1024.
#   math.floor rounds down the result of math.log to the nearest integer.
#   int ensures the result of math.floor is of type int, and stores it in i.
#   The value of i is used to determine which binary prefix to use from
#   size_name.
#
#   math.pow returns the value of 1024 raised to the power of i, stores it in p.
#
#   round takes the value of size_bytes, divides it by p, and stores the result
#   in s at precision of 2 decimal places.
#
#   A formatted string with size s and binary prefix size_name[i] is returned.
def convert_size(size_bytes: int) -> str:
    """
    Converts a decimal integer of bytes to its respective binary-prefixed size.

        Parameters:
            size_bytes (int): A decimal integer.

        Returns:
            (str): Binary-prefixed size of size_bytes formatted as a string.
    """
    if size_bytes == 0:
        return "0 B"
    size_name = ("B", "KiB", "MiB", "GiB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


# Registry host name. Despite the name this is a bare host (no scheme): it is
# the host for HTTPS API connections, the SSH destination for the gc/rmr
# subcommands, and the key looked up under "auths" in the auth.json file.
REGISTRY_URL = "registry.hyperreal.coffee"


def get_auth(auth_file: str | None = None) -> str:
    """
    Get the base64 encoded credential string for registry authentication.

    The value is the "auth" entry stored for REGISTRY_URL under "auths" in a
    containers auth.json file. It is sent verbatim in an HTTP Basic
    Authorization header.

        Parameters:
            auth_file (str | None): Path to the auth.json file. When omitted,
                                    $REGISTRY_AUTH_FILE is used if set,
                                    otherwise
                                    $XDG_RUNTIME_DIR/containers/auth.json,
                                    falling back to
                                    /run/user/<uid>/containers/auth.json for
                                    the current user.

        Returns:
            auth (str): A string containing the base64 encoded credential.

        Raises:
            OSError, ValueError: If the file cannot be read or is not valid
                                 JSON (the error is logged first).
            KeyError: If the file holds no credential for REGISTRY_URL
                      (logged first).
    """
    if auth_file is None:
        auth_file = os.environ.get("REGISTRY_AUTH_FILE")
    if not auth_file:
        # Use the calling user's runtime directory instead of assuming UID 1000.
        runtime_dir = os.environ.get("XDG_RUNTIME_DIR") or "/run/user/%d" % os.getuid()
        auth_file = os.path.join(runtime_dir, "containers", "auth.json")

    try:
        with open(auth_file, "r") as authfile:
            json_data = json.load(authfile)
    except (OSError, ValueError) as ex:
        # Propagate after logging: continuing would only fail later with an
        # unrelated UnboundLocalError on json_data.
        log.exception(ex)
        raise

    try:
        return json_data["auths"][REGISTRY_URL]["auth"]
    except KeyError:
        log.error("No credentials for %s found in %s", REGISTRY_URL, auth_file)
        raise


def get_headers() -> dict:
    """
    Builds the HTTP headers attached to every registry API request.

        Returns:
            headers (dict): An Accept header requesting an OCI image manifest
                            and a Basic Authorization header whose credential
                            comes from get_auth()
    """
    headers = {}
    headers["Accept"] = "application/vnd.oci.image.manifest.v1+json"
    headers["Authorization"] = "Basic %s" % (get_auth(),)
    return headers


def get_json_response(request: str, url: str) -> dict:
    """
    Connects to registry and returns response data as JSON.

    The connection is always closed before returning. Network, HTTP-protocol
    and decoding errors are logged and re-raised. An HTTP error status
    (>= 400) is logged, and the decoded response body is still returned so
    callers can inspect it.

        Parameters:
            request (str): A string like "GET" or "DELETE"
            url (str)    : A string containing the URL of the requested data

        Returns:
            json_data (dict): JSON data as a dict object
    """
    conn = http.client.HTTPSConnection(REGISTRY_URL)
    try:
        conn.request(request, url, "", get_headers())
        res = conn.getresponse()
        data = res.read()
        if res.status >= 400:
            log.error("%s %s returned HTTP %d %s", request, url, res.status, res.reason)
        return json.loads(data.decode("utf-8"))
    except (OSError, http.client.HTTPException, ValueError) as ex:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        log.exception(ex)
        raise
    finally:
        conn.close()


def get_repositories():
    """
    Fetches the registry catalog and prints its repositories as a Rich Tree.
    """
    catalog = get_json_response("GET", "/v2/_catalog")
    tree = Tree("[green]Repositories")
    for name in catalog["repositories"]:
        tree.add(f"[blue]{name}")
    print(tree)


def get_tags(repo: str):
    """
    Prints a Rich Tree that lists the tags for the given repository.

    If the registry answers with an "errors" body (e.g. unknown repository),
    the errors are logged and nothing is printed.

        Parameters:
            repo (str): A string containing the name of the repo
    """
    json_data = get_json_response("GET", "/v2/" + repo + "/tags/list")
    if "errors" in json_data:
        log.error("Could not list tags for %s: %s", repo, json_data["errors"])
        return

    tags_tree = Tree("[green]%s tags" % repo)
    # The registry may report "tags": null once every tag of a repository has
    # been deleted; treat that the same as an empty list.
    for tag in json_data.get("tags") or []:
        tags_tree.add("[cyan]:%s" % tag)

    print(tags_tree)


def _key_value_grid(rows) -> Table:
    """Build a two-column Rich grid of bold-green keys and plain-text values."""
    grid = Table.grid(expand=True)
    grid.add_column()
    grid.add_column()
    for key, value in rows:
        key_text = Text(key)
        key_text.stylize("bold green", 0)
        grid.add_row(key_text, Text(value))
    return grid


def get_manifests(repo: str, tag: str):
    """
    Prints a Rich grid table that displays the manifests and metadata of the
    image repository.

    If the response has no "config" object (for example a registry error body,
    or something other than an image manifest), the problem is logged and
    nothing is printed.

        Parameters:
            repo (str): A string containing the name of the repo
            tag (str) : A string containing the tag of the desired image
    """
    json_data = get_json_response("GET", "/v2/" + repo + "/manifests/" + tag)
    if "config" not in json_data:
        log.error(
            "%s:%s did not return an image manifest: %s",
            repo,
            tag,
            json_data.get("errors", json_data.get("mediaType", "unknown response")),
        )
        return

    config = json_data["config"]
    # Optional fields are read with defaults so their absence doesn't crash.
    layers = json_data.get("layers") or []
    annotations = json_data.get("annotations") or {}

    grid_meta = _key_value_grid(
        [
            ("Schema version", str(json_data.get("schemaVersion", "n/a"))),
            ("Media type", str(json_data.get("mediaType", "n/a"))),
        ]
    )
    grid_config = _key_value_grid(
        [
            ("Media type", config["mediaType"]),
            ("Digest", config["digest"]),
            ("Size", convert_size(config["size"])),
        ]
    )
    # Values are stringified because Text() expects a str.
    grid_annotations = _key_value_grid(
        (str(key), str(value)) for key, value in annotations.items()
    )

    total_size = sum(layer.get("size", 0) for layer in layers)
    table_layers = Table(box=None, show_footer=True)
    table_layers.add_column(
        "Digest", justify="right", style="yellow", no_wrap=True, footer="Total size:"
    )
    table_layers.add_column(
        "Size",
        justify="left",
        style="cyan",
        no_wrap=True,
        footer=convert_size(total_size),
    )
    for layer in layers:
        table_layers.add_row(layer.get("digest"), convert_size(layer.get("size", 0)))

    # Show every distinct layer media type in first-seen order rather than
    # assuming all layers match the first one (which also failed on an empty
    # layer list).
    layer_media_types = ", ".join(
        dict.fromkeys(str(layer.get("mediaType", "unknown")) for layer in layers)
    ) or "none"

    panel_group = Group(
        Panel(grid_meta, title="[bold blue]Metadata"),
        Panel(grid_config, title="[bold blue]Config"),
        Panel(grid_annotations, title="[bold blue]Annotations"),
        Panel(table_layers, title="[bold blue]Layers: %s" % layer_media_types),
    )
    print(Panel(panel_group, title="[bold blue]%s:%s" % (repo, tag)))


def delete_image(repo: str, tag: str):
    """
    Removes the given tag from the image. If the given tag is the only tag,
    removes the image.

    The tag is first resolved to the digest the registry reports in the
    Docker-Content-Digest header. Then the manifest is deleted by that digest.
    NOTE(review): deleting by digest removes the manifest itself, so any other
    tags pointing at the same digest presumably go with it. Confirm against
    the registry in use.

    On any failure the problem is logged and the success message is not
    printed.

        Parameters:
            repo (str): A string containing the name of the repo
            tag (str) : A string containing the tag to be removed
    """
    manifests_path = "/v2/" + repo + "/manifests/"
    conn = http.client.HTTPSConnection(REGISTRY_URL)
    try:
        headers = get_headers()

        conn.request("GET", manifests_path + tag, "", headers)
        res = conn.getresponse()
        # http.client requires the whole response to be read before the same
        # connection can carry the next request.
        res.read()
        digest = res.getheader("Docker-Content-Digest")
        if not 200 <= res.status < 300 or not digest:
            log.error(
                "Could not resolve %s:%s to a digest (HTTP %d %s)",
                repo,
                tag,
                res.status,
                res.reason,
            )
            return

        conn.request("DELETE", manifests_path + digest, "", headers)
        res = conn.getresponse()
        res.read()
        if not 200 <= res.status < 300:
            log.error(
                "Deleting %s@%s failed (HTTP %d %s)",
                repo,
                digest,
                res.status,
                res.reason,
            )
            return
    except (OSError, http.client.HTTPException) as ex:
        log.exception(ex)
        return
    finally:
        conn.close()

    print("Untagged %s:%s successfully" % (repo, tag))


def _run_remote_command(command: str):
    """Run a command line on the registry host over SSH and report its output."""
    try:
        proc = subprocess.run(
            ["ssh", REGISTRY_URL, command],
            capture_output=True,
            text=True,
        )
    except OSError as ex:
        # e.g. the ssh executable could not be started.
        log.exception(ex)
        return

    if proc.returncode != 0:
        log.error(
            "Remote command %r failed with exit status %d: %s",
            command,
            proc.returncode,
            proc.stderr.strip(),
        )
        return
    if proc.stdout:
        # Wrap in Text so Rich does not interpret brackets in the output as markup.
        print(Text(proc.stdout.rstrip("\n")))
    if proc.stderr:
        log.warning(proc.stderr.rstrip("\n"))


def garbage_collection():
    """
    Runs garbage collection command on the remote registry server. Requires
    SSH public key access.
    """
    _run_remote_command("/usr/local/bin/registry-gc")


def remove_repo(repo: str):
    """
    Runs command on remote registry server to remove the given repo. Requires
    SSH public key access.

        Parameters:
            repo (str): A string containing the name of the repo.
    """
    # ssh hands the command string to the remote shell, so quote the
    # user-supplied repo name to prevent shell injection.
    _run_remote_command("/usr/local/bin/registry-rm-repo " + shlex.quote(repo))


if __name__ == "__main__":
    args = docopt(__doc__, options_first=True)
    match args["<subcommand>"]:
        case "repos":
            get_repositories()
        case "tags":
            get_tags(args["<args>"][0])
        case "manifests":
            get_manifests(args["<args>"][0], args["<args>"][1])
        case "rmi":
            delete_image(args["<args>"][0], args["<args>"][1])
        case "gc":
            garbage_collection()
        case "rmr":
            remove_repo(args["<args>"])
        case _:
            if args["<subcommand>"] in ["help", None]:
                exit(subprocess.call(["python3", "ocirh", "--help"]))
            else:
                exit(
                    "%r is not a ocirh subcommand. See 'ocirh --help."
                    % args["<subcommand>"]
                )