#!/usr/bin/env python3

"""CSV filter that converts (selected columns) from Postgres HEX sequences to UTF-8
strings, respecting CSV escaping

"""

__copyright__ = "Copyright (C) 2020  Stefano Zacchiroli"
__license__ = "GPL-3.0-or-later"

import binascii
import click
import csv
import datetime
import logging
import sys

from contextlib import ExitStack
from dateutil.tz import tzoffset
from typing import Callable, Generator, IO, List, Optional


# Text encoding used by default to decode hex payloads and to write the output.
DEFAULT_ENCODING = "utf-8"
# Field separator of both the input and the output CSV.
CSV_SEP = ","

# Marker prefixing Postgres hex-escaped values, and its length (used to strip it).
PG_HEX_PREFIX = r"\x"
PG_HEX_PREFIX_len = len(PG_HEX_PREFIX)

# strptime() format for Postgres timestamps. The time zone is expected to always be
# zero (Postgres does not keep the original one), hence the literal "+00" suffix.
PG_TS_FMT = "%Y-%m-%d %H:%M:%S+00"

# Largest absolute time zone displacement accepted, in minutes (i.e., 14:59 hours).
# Source: https://www.postgresql.org/docs/12/datatype-datetime.html
# Experiments suggest the actual limit is *15* hours, but we stick to the documented
# value to stay on the safe side.
PG_MAX_TS_OFFSET = (14 * 60) + 59


def pghex_to_str(s: str, encoding: str = DEFAULT_ENCODING) -> str:
    """convert a Postgres hex encoded string to a (decoded) Python string

    Args:
        s: postgres hex encoded string starting with "\\x" (e.g. "\\x616263" for
            "abc"); an empty string is returned unchanged
        encoding: encoding to use for decoding the hex-decoded bytes

    Returns:
        the decoded string

    Raises:
        ValueError: if s does not start with "\\x", is not valid hex, or the
            resulting bytes are not valid in the given encoding (binascii.Error and
            UnicodeDecodeError are both ValueError subclasses)

    """
    if not s:
        # empty fields have no prefix to check and decode to the empty string
        return s
    if not s.startswith(PG_HEX_PREFIX):
        # without this check the first two characters would be silently dropped
        # and the remainder decoded as if it were hex
        raise ValueError(
            f"missing {PG_HEX_PREFIX!r} prefix in hex string: {s[:32]!r}"
        )
    return binascii.unhexlify(s[PG_HEX_PREFIX_len:]).decode(encoding)


def pgts_to_iso(ts: str, tz: int = 0) -> str:
    """convert a Postgres timestamp to ISO format

    Args:
        ts: textual serialization of a Postgres "timestamp with time zone" value in
            UTC, i.e., matching PG_TS_FMT (e.g. "2020-01-02 03:04:05+00"), optionally
            with fractional seconds (e.g. "2020-01-02 03:04:05.678+00")
        tz: timezone offset in minutes (default: 0, i.e., UTC) of the returned value

    Returns:
        a ISO8601-formatted datetime string denoting the same instant as ts,
        expressed in the requested offset

    Raises:
        ValueError: if tz is out of the range supported by Postgres, or ts does not
            match the expected format

    """
    if abs(tz) > PG_MAX_TS_OFFSET:
        raise ValueError(f"time zone displacement out of (postgres) range: {tz}")

    offset = tzoffset("custom", tz * 60)  # "* 60" to obtain seconds
    # fractional seconds, when present, sit between the seconds and the "+00" suffix
    fmt = PG_TS_FMT.replace("%S", "%S.%f") if "." in ts else PG_TS_FMT
    # date = dateutil.parser.parse(ts)  # more robust, but slower and not needed
    date = datetime.datetime.strptime(ts, fmt)
    # strptime() matches "+00" as literal text and returns a *naive* datetime, which
    # astimezone() would interpret as system local time. Attach UTC explicitly so the
    # result does not depend on the machine's TZ setting.
    date = date.replace(tzinfo=datetime.timezone.utc)
    return date.astimezone(offset).isoformat()


def csv_filter(
    csv_reader, csv_writer, transform: Callable[[List[str]], List[str]]
) -> int:
    """copy rows from csv_reader to csv_writer, passing each one through transform

    Args:
        csv_reader: iterable of rows (lists of fields), e.g. a csv.reader
        csv_writer: object with a writerow() method, e.g. a csv.writer
        transform: row to row transformer; raising ValueError signals that the row
            cannot be converted, in which case it is logged and not written

    Returns:
        number of rows not written due to encountered errors

    """
    skipped = 0

    for lineno, row in enumerate(csv_reader, start=1):
        try:
            converted = transform(row)
        except ValueError as e:
            logging.error(f"error while converting row {lineno}: {e}")
            skipped += 1
        else:
            # outside the try: write errors are not counted as conversion errors
            csv_writer.writerow(converted)

    return skipped


def convert_to_utf8(
    csv_reader, csv_writer, columns: Optional[List[int]], encoding=DEFAULT_ENCODING
) -> int:
    """decode the hex-encoded fields of the CSV records read from csv_reader before
    writing them using csv_writer

    Only the (0-based) indexes listed in columns are converted, other fields are left
    untouched. If columns is None, all the fields of each row are converted.

    Rows that contain even a single decode error, or that are too short to contain
    one of the requested columns, are skipped (and logged).

    Args:
        csv_reader: iterable of rows (lists of fields), e.g. a csv.reader
        csv_writer: object with a writerow() method, e.g. a csv.writer
        columns: indexes of the fields to decode, or None for all of them
        encoding: NOTE(review): accepted but not forwarded to pghex_to_str, so fields
            are always decoded with pghex_to_str's default encoding; main() passes
            the output file encoding here. Confirm intent before wiring it through.

    Returns:
        number of rows skipped due to errors

    """

    def row_to_utf8(row: List[str]) -> List[str]:
        # with no explicit column list, convert every field of *this* row, so that
        # rows wider than the first one are fully converted too
        col_idxs = range(len(row)) if columns is None else columns
        for col_idx in col_idxs:
            try:
                field = row[col_idx]
            except IndexError:
                # csv_filter only skips rows on ValueError; letting IndexError
                # escape would abort the whole conversion on one short row
                raise ValueError(
                    f"no column {col_idx} (row has {len(row)} fields)"
                ) from None
            row[col_idx] = pghex_to_str(field)

        return row

    return csv_filter(csv_reader, csv_writer, row_to_utf8)


def dumb_csv_reader(
    f: IO[str], delimiter: str = CSV_SEP
) -> Generator[List[str], None, None]:
    """fast but unsafe CSV reader, which doesn't care about escapes

    Each line is split on delimiter as-is. Quoting is not honored, so delimiters or
    newlines inside quoted fields produce wrong rows. Only the line terminator is
    stripped, so trailing whitespace and trailing empty fields are preserved.

    Args:
        f: text stream to read lines from
        delimiter: field separator

    Yields:
        the list of fields of each line

    """
    for line in f:
        # strip only the line terminator: a bare rstrip() would also eat trailing
        # spaces of the last field, and trailing empty fields when the delimiter is
        # whitespace (e.g. a tab)
        yield line.rstrip("\r\n").split(delimiter)


@click.command(
    help="""
    CSV to CSV filter that converts (selected columns) from Postgres HEX escapes to
    native strings in a given encoding.

    Input and output files default to stdin/stdout; "-" can be given to specify
    stdin/stdout explicitly.

    """
)
@click.argument("in-file", default="-")
@click.argument("out-file", default="-")
@click.option(
    "-c",
    "--columns",
    help="comma separated list of columns to encode, 0-based (default: no columns)",
    default="",
)
@click.option(
    "-e",
    "--encoding",
    help=f'target encoding (default: "{DEFAULT_ENCODING}")',
    type=str,
    default=DEFAULT_ENCODING,
)
def main(in_file: str, out_file: str, columns: str, encoding: str) -> None:
    # CLI entry point: decode the selected hex columns of in_file into out_file.

    # "-c 1,3" -> [1, 3]; an empty value means that no column is converted
    try:
        col_idxs = [int(col) for col in columns.split(",")] if columns else []
    except ValueError:
        # report as a CLI usage error instead of a raw traceback
        raise click.BadParameter(
            f"expected a comma separated list of integers, got {columns!r}",
            param_hint="'-c' / '--columns'",
        ) from None

    with ExitStack() as stack:
        # only files opened here are registered on the stack, so stdin/stdout are
        # never closed
        if in_file == "-":
            input_f = sys.stdin
        else:
            # the csv module requires newline="" so that quoted fields containing
            # line breaks are parsed correctly
            input_f = stack.enter_context(
                open(in_file, "rt", encoding="ascii", newline="")
            )
        if out_file == "-":
            # NOTE(review): stdout keeps whatever encoding Python configured for it,
            # so --encoding only takes effect when writing to a file; confirm intent
            output_f = sys.stdout
        else:
            # newline="" stops the "\r\n" terminators written by csv.writer from
            # being translated again (e.g. into "\r\r\n" on Windows)
            output_f = stack.enter_context(
                open(out_file, "wt", encoding=encoding, newline="")
            )
        # dumb_csv_reader(input_f, delimiter=CSV_SEP) would be faster, but it
        # ignores CSV quoting/escaping
        csv_in = csv.reader(input_f, delimiter=CSV_SEP)
        csv_out = csv.writer(output_f, delimiter=CSV_SEP)

        errors = convert_to_utf8(csv_in, csv_out, col_idxs, encoding)
        if errors:
            logging.error(f"skipped {errors} row(s) due to encoding errors")


if __name__ == "__main__":
    # run the click command only when executed as a script, not when imported
    main()
