#!/usr/bin/env python3

# Assumptions made by this code:
# - list_object_versions sorts all versions (including delete markers) by key
#   (in ascending order) and last modified date (in descending order, sometimes
#   with consecutive equal timestamps) and returns them in chunks of 1000 (by
#   default), with each chunk divided into proper versions and delete markers.
# - Objects that have been deleted will have a delete marker as their latest
#   version and at least one proper version dated before that.

# Third-party distributions that the imports below rely on.
# NOTE(review): presumably consumed by a script runner such as pip-run to
# install dependencies on the fly -- confirm against how this script is run.
__requires__ = ["boto3", "click >= 7.0", "humanize"]

from bisect import bisect
from datetime import datetime
import re
import sys
from typing import List, NamedTuple, Tuple
from urllib.parse import urlparse

import boto3
from botocore import UNSIGNED
from botocore.client import Config
import click
from humanize import naturalsize


class Version(NamedTuple):
    """An object version, as described by one entry of an S3 version listing."""

    bucket: str
    key: str
    version_id: str
    size: int
    last_modified: datetime

    @classmethod
    def from_data(cls, bucket, data):
        """
        Construct a `Version` in `bucket` from the listing entry `data`, a
        mapping with "Key", "VersionId", "Size", and "LastModified" items
        """
        details = (
            data["Key"],
            data["VersionId"],
            data["Size"],
            data["LastModified"],
        )
        return cls(bucket, *details)

    @property
    def key_url(self):
        """The ``s3://`` URL of the object's key, without any version ID"""
        return "s3://{}/{}".format(self.bucket, self.key)

    @property
    def url(self):
        """The ``s3://`` URL of this exact version of the object"""
        return "{}?versionId={}".format(self.key_url, self.version_id)

    def __str__(self):
        # Versioned URL followed by the size, separated by a single space
        return "{} {}".format(self.url, self.size)


class BucketStats:
    """
    Tally the number and total size of the object versions under an S3 prefix.

    Versions are fed in listing order (see the assumptions at the top of this
    file) and grouped by key.  Once every version of a key has been seen, the
    versions are classified into the report types that key ``qtys`` and
    ``sizes``:

    - ``"all"`` -- every proper (non-delete-marker) version
    - ``"visible"`` -- the first-listed (newest) version of a key that is not
      deleted
    - ``"old"`` -- the remaining versions of a key that is not deleted
    - ``"invisible"`` -- the ``"old"`` versions plus every version of a
      deleted key

    A group of versions is only tallied when at least one of its report types
    was requested via ``stat``, so the counters for types that were *not*
    requested may be incomplete.
    """

    def __init__(self, bucket, prefix, list_files=False, stat=("all",), exclude=()):
        """
        :param bucket: name of the S3 bucket to list
        :param prefix: only keys starting with this prefix are listed
        :param list_files: whether to print each tallied version
        :param stat: report types to tally and to summarize in `run()`
        :param exclude:
            regexes; a key whose ``s3://bucket/key`` URL matches any of them
            (via `re.search`) is left out of the tallies
        """
        self.bucket: str = bucket
        self.prefix: str = prefix
        self.list_files: bool = list_files
        self.stat: Tuple[str, ...] = stat
        self.exclude: Tuple[str, ...] = exclude
        #: Versions of the current key
        self.versions: List[Version] = []
        # Deleted keys, in ascending order:
        self.deleted: List[str] = []
        self.qtys = {
            "all": 0,
            "visible": 0,
            "invisible": 0,
            "old": 0,
        }
        self.sizes = {
            "all": 0,
            "visible": 0,
            "invisible": 0,
            "old": 0,
        }
        #: Whether at least one version has been tallied
        self.found_any = False

    def run(self):
        """
        List all object versions under the configured bucket & prefix, tally
        them, and print one summary line per requested report type
        """
        # Use s3 anonymously/without credentials:
        client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
        paginator = client.get_paginator("list_object_versions")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            self._process_page(page)
        # Flush the versions of the final key:
        self.end_key()
        for rtype in self.stat:
            print(
                f"{rtype.title()} files: {self.qtys[rtype]}",
                "/",
                f"Size: {self.sizes[rtype]} ({naturalsize(self.sizes[rtype])})",
            )

    def _process_page(self, page):
        """Feed one page of a ``list_object_versions`` response into the tally"""
        # TODO: Filter out keys that end with slashes?
        # Delete markers go first: `end_key()` can fire while this page's
        # versions are being added, and it consults `self.deleted`.
        for dm in page.get("DeleteMarkers", []):
            if dm["IsLatest"]:
                self.mark_deleted(dm["Key"])
        # A page need not contain a "Versions" entry (e.g., when nothing
        # matches the prefix or the page consists solely of delete markers),
        # so it must not be indexed unconditionally.
        for v in page.get("Versions", []):
            self.add_version(v)

    def add_version(self, data):
        """
        Record the version described by the listing entry `data`.  If it
        belongs to a different key than the versions collected so far, those
        are tallied via `end_key()` first.

        The assertions check the ordering assumptions that the grouping
        relies on.
        """
        v = Version.from_data(self.bucket, data)
        if self.versions:
            if self.versions[-1].key == v.key:
                assert (
                    self.versions[-1].last_modified >= v.last_modified
                ), f"Versions for key {v.key!r} not in reverse chronological order"
            else:
                assert self.versions[-1].key < v.key, (
                    "Keys not in lexicographic order;"
                    f" {self.versions[-1].key!r} listed before {v.key!r}"
                )
                self.end_key()
        self.versions.append(v)

    def mark_deleted(self, key):
        """
        Remember that `key` is deleted.  Keys must be passed in strictly
        ascending order so that `self.deleted` stays sorted for `end_key()`.
        """
        if self.deleted:
            assert (
                self.deleted[-1] < key
            ), f"DeleteMarkers not in lexicographic order; {self.deleted[-1]} listed before {key}"
        self.deleted.append(key)

    def end_key(self):
        """
        Classify and tally the versions collected for the current key, then
        reset the collection.  Does nothing if no versions are pending.
        """
        if self.versions:
            key = self.versions[0].key
            if self.deleted:
                # `i` is the number of remembered deleted keys that are <= `key`
                i = bisect(self.deleted, key)
                # Error/warn if i>1?
                deleted = i > 0 and self.deleted[i - 1] == key
                # Keys are processed in ascending order, so entries up to and
                # including `key` can never match a later key.
                del self.deleted[:i]
            else:
                deleted = False
            if deleted:
                self.report({"all", "invisible"}, self.versions)
            else:
                # Versions are listed newest-first, so the first one is the
                # current version and the rest are superseded.
                self.report({"all", "visible"}, self.versions[:1])
                self.report({"all", "old", "invisible"}, self.versions[1:])
            self.versions = []

    def report(self, rtypes, versions):
        """
        Tally `versions` (all belonging to the same key) under the report
        types in the set `rtypes`.

        Nothing is recorded if `versions` is empty, if the key's URL matches
        one of the exclude regexes, or if none of `rtypes` was requested in
        `self.stat`.  Otherwise, every version is counted under every type in
        `rtypes` and, if `self.list_files` is set, printed.
        """
        if not versions:
            return
        if any(re.search(rgx, versions[0].key_url) for rgx in self.exclude):
            return
        if rtypes.intersection(self.stat):
            for v in versions:
                self.found_any = True
                if self.list_files:
                    print(v)
                for t in rtypes:
                    self.qtys[t] += 1
                    self.sizes[t] += v.size


@click.command()
@click.option("--exclude", metavar="URLREGEX", multiple=True)
@click.option("--fail-if-any", is_flag=True)
@click.option("--list", "list_files", is_flag=True)
@click.option(
    "--stat",
    type=click.Choice(["all", "visible", "invisible", "old"]),
    multiple=True,
    default=["all"],
)
@click.argument("url")
def main(stat, list_files, fail_if_any, url, exclude):
    """
    Report the number and total size of the object versions under an S3 URL.

    URL must have the form s3://BUCKET/PREFIX.  With --fail-if-any, the exit
    status is 1 if any object version was counted.
    """
    # NOTE: click shows the docstring above as the command's --help text.
    try:
        bucket, prefix = parse_s3_url(url)
    except ValueError as exc:
        # Report a malformed URL as a usage error (exit status 2) instead of
        # a traceback whose exit status of 1 is indistinguishable from the
        # "--fail-if-any found something" status below.
        raise click.BadParameter(str(exc), param_hint="URL") from exc
    stats = BucketStats(
        bucket, prefix, list_files=list_files, stat=stat, exclude=exclude
    )
    stats.run()
    if fail_if_any and stats.found_any:
        sys.exit(1)


def parse_s3_url(url):
    """
    Split an ``s3://bucket/prefix`` URL into a ``(bucket, prefix)`` pair.

    All leading slashes are stripped from the path, so ``s3://bucket`` and
    ``s3://bucket/`` both yield an empty prefix.  ``#`` is kept as part of the
    prefix (fragments are not parsed), but anything from a ``?`` onwards is
    parsed as a query string and is not included in the prefix.

    :raises ValueError:
        if the URL's scheme is not ``s3`` or the URL does not name a bucket
    """
    parts = urlparse(url, allow_fragments=False)
    if parts.scheme != "s3":
        raise ValueError(f"not an S3 URL: {url}")
    if not parts.netloc:
        # E.g. "s3:///key"; fail here rather than passing an empty bucket name
        # on to the S3 client.
        raise ValueError(f"S3 URL lacks a bucket name: {url}")
    return (parts.netloc, parts.path.lstrip("/"))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
