#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright © 2021 Ye Chang yech1990@gmail.com
# Distributed under terms of the MIT license.
#
# Created: 2021-11-03 16:29

"""Realign bam file.

Demo version by python.
"""

import re

import click
import parasail
import pysam

re_split_cigar = re.compile(r"(?P<len>\d+)(?P<op>\D+)")


def sam2align(record):
    """Render the alignment of a read as three rows of text.

    One column is produced for each tuple returned by
    ``record.get_aligned_pairs(with_seq=True)``:

    * no query position and no reference base: ``.`` over ``N``;
    * no query position but a reference position: ``-`` over the
      reference base;
    * a query position but no reference position: the query base over ``-``;
    * both positions present: the query base over the reference base, with
      ``|`` in the middle row when the two bases compare equal.

    Pairs that have a reference base but neither position are ignored.

    :param record: object providing ``get_aligned_pairs(with_seq=True)`` and
        ``query_sequence``.

    :returns: query row, match row and reference row joined by newlines.
    """
    columns = []
    for query_pos, ref_pos, ref_base in record.get_aligned_pairs(
        with_seq=True
    ):
        if query_pos is None:
            if ref_base is None:
                columns.append((".", " ", "N"))
            elif ref_pos is not None:
                columns.append(("-", " ", ref_base))
            # Otherwise there is nothing to draw for this pair.
            continue

        query_base = record.query_sequence[query_pos]
        if ref_pos is None:
            columns.append((query_base, " ", "-"))
        else:
            marker = "|" if query_base == ref_base else " "
            columns.append((query_base, marker, ref_base))

    # Transpose the columns into the three output rows.
    rows = ("".join(column[i] for column in columns) for i in range(3))
    return "\n".join(rows)


def cigar_ops_from_start(cigar):
    """Iterate over the operations of a cigar string, left to right.

    :param cigar: cigar string.

    :yields: ``(length, op)`` tuples in order of appearance; both items are
        the strings captured by ``re_split_cigar`` (the length is not
        converted to ``int``).
    """
    for match in re_split_cigar.finditer(cigar):
        yield match.group("len", "op")


def parasail_to_sam(result, seq):
    """Turn a parasail result into a reference start and a SAM style cigar.

    The unaligned query prefix (``result.cigar.beg_query``) and suffix
    (everything after ``result.end_query``) are reported as soft clips. A
    leading insertion is folded into the leading soft clip. A leading
    deletion is dropped from the cigar and its length is used as the
    reference start.

    :param result: parasail alignment result.
    :param seq: query sequence.

    :returns: reference start coordinate, cigar string.
    """
    raw_cigar = result.cigar.decode.decode()

    lead_len, lead_op = next(cigar_ops_from_start(raw_cigar))
    lead_token = lead_len + lead_op
    ref_start = result.cigar.beg_ref
    n_lead_clip = result.cigar.beg_query
    lead_clip = f"{n_lead_clip}S" if n_lead_clip != 0 else ""

    if lead_op == "I":
        # Inserted bases at the very start are soft clipped instead.
        head = f"{int(lead_len) + n_lead_clip}S"
    elif lead_op == "D":
        # Skip the deleted reference bases by moving the start coordinate.
        head = lead_clip
        ref_start = int(lead_len)
    else:
        head = lead_clip + lead_token

    body = raw_cigar[len(lead_token):]
    n_tail_clip = len(seq) - result.end_query - 1
    tail = f"{n_tail_clip}S" if n_tail_clip > 0 else ""
    return ref_start, head + body + tail


def parasail_alignment(query, ref):
    """Align a query to a reference with parasail's Smith-Waterman kernel.

    Calls ``parasail.sw_trace_striped_32`` with the constants ``3`` and ``2``
    (presumably the gap open and gap extension penalties -- confirm against
    the parasail documentation) and the ``parasail.dnafull`` matrix, then
    converts the result with ``parasail_to_sam``.

    :param query: the query sequence.
    :param ref: the reference sequence.

    :returns: reference start co-ordinate, cigar string
    """
    alignment = parasail.sw_trace_striped_32(
        query, ref, 3, 2, parasail.dnafull
    )
    return parasail_to_sam(alignment, query)


def generate_md_tag(ref, cigar):
    """Build the SAM ``MD`` tag for an alignment given in extended cigar form.

    The cigar must distinguish matches (``=``) from mismatches (``X``); a
    plain ``M`` cannot be resolved without the query and is ignored, as are
    ``I``, ``S`` and ``N``, none of which contribute to the tag here.

    The result follows the grammar of the SAM specification,
    ``[0-9]+(([A-Z]|\\^[A-Z]+)[0-9]+)*``:

    * every mismatched base and every deletion is preceded by the number of
      matching bases since the previous event (``0`` when there are none),
      so adjacent mismatches, and a mismatch next to a deletion, stay
      unambiguous;
    * the tag always ends with a match count, even when it is ``0``.

    :param ref: reference sequence, starting at the first aligned reference
        base (the reference start of the alignment).
    :param cigar: cigar string using ``=``/``X`` operations.

    :returns: the MD tag value as a string.
    """
    parts = []
    ref_pos = 0  # offset into ``ref`` of the next reference base to consume
    match_run = 0  # matching bases seen since the last mismatch/deletion
    for length, op in cigar_ops_from_start(cigar):
        length = int(length)
        if op == "=":
            # Match runs separated only by insertions/clips are merged.
            match_run += length
            ref_pos += length
        elif op == "X":
            # One entry per mismatched base, each preceded by a match count.
            for base in ref[ref_pos:ref_pos + length].upper():
                parts.append(f"{match_run}{base}")
                match_run = 0
            ref_pos += length
        elif op == "D":
            deleted = ref[ref_pos:ref_pos + length].upper()
            parts.append(f"{match_run}^{deleted}")
            match_run = 0
            # A deletion consumes reference bases.
            ref_pos += length
        # Every other operation leaves both the tag and ``ref_pos`` untouched.
    parts.append(str(match_run))
    return "".join(parts)


def get_splicing_site_from_cigar(cigar):
    """Collect the splice junctions (``N`` operations) of a cigar string.

    :param cigar: cigar string.

    :returns: list of ``(offset, length)`` tuples, one per ``N`` operation in
        cigar order. ``offset`` is the summed length of all ``X``, ``=``,
        ``M``, ``D`` and ``S`` operations that precede the ``N`` (counted
        from the start of the cigar, never reset) and ``length`` is the
        length of the ``N`` operation itself.
    """
    sites = []
    offset = 0
    for length, op in cigar_ops_from_start(cigar):
        length = int(length)
        if op == "N":
            sites.append((offset, length))
            continue
        if op in "X=MDS":
            offset += length
    return sites


def degenerate_cigar(cigar):
    """Collapse ``=``/``X`` operations of a cigar string into ``M``.

    Consecutive ``=`` and ``X`` operations are merged into one ``M``
    operation of their summed length. Every other operation is copied as is.

    :param cigar: cigar string.

    :returns: the rewritten cigar string.
    """
    pieces = []
    pending = 0  # length of the =/X run not yet written out
    for length, op in cigar_ops_from_start(cigar):
        length = int(length)
        if op in "X=":
            pending += length
            continue
        if pending > 0:
            pieces.append(f"{pending}M")
            pending = 0
        pieces.append(f"{length}{op}")
    if pending > 0:
        pieces.append(f"{pending}M")
    return "".join(pieces)


def split_cigar(cigar, split_points):
    """Insert ``N`` operations into a cigar string at the given split points.

    Expects the cigar string before degeneration, ``52=`` for example.

    :param cigar: cigar string.
    :param split_points: list of ``(position, length)`` tuples. ``position``
        is compared against the running sum of the ``X``/``=``/``M``/``S``
        operation lengths, and ``length`` is written as the length of the
        inserted ``N`` operation. The list is modified in place: a sentinel
        is appended and entries are popped from the front. It must hold at
        least one entry, otherwise the initial ``split_points[0]`` lookup
        raises ``IndexError``.

    :returns: the new cigar string.
    """
    new_cigar = ""
    # Sentinel at infinity so that there is always a "next" split point.
    split_points.append((float("inf"), 0))
    # w: summed length of the operations copied unchanged (reset below).
    w = 0
    # p: running sum of X/=/M/S lengths, including the current operation.
    p = 0
    # s, l: position and N length of the split point being processed.
    s, l = split_points.pop(0)
    # s2: position of the split point after the current one.
    s2, _ = split_points[0]

    for x, y in cigar_ops_from_start(cigar):
        x = int(x)
        if y in "X=MS":
            p += x
        if p <= s:
            # The operation ends at or before the split point: copy it as is.
            new_cigar += f"{x}{y}"
            w += x
        else:
            # The operation runs past the split point and has to be cut.
            while x > 0:
                s2, _ = split_points[0]
                if p - s <= s2 - s:
                    # The operation ends before the following split point:
                    # write the part before the split, the N operation and
                    # the part after the split.
                    new_cigar += f"{x - p + s}{y}{l}N{p - s}{y}"
                    # Subtracting the next position (infinite for the
                    # sentinel) is what ends the while loop.
                    x -= s2
                else:
                    # NOTE(review): x is reduced before it is written, so the
                    # emitted length is what remains after the split rather
                    # than the part in front of it -- confirm this is the
                    # intended output when several split points fall inside
                    # one operation.
                    x -= s - w
                    new_cigar += f"{x}{y}{l}N"
                    w = 0
                # Move on to the next split point.
                s, l = split_points.pop(0)
    return new_cigar


def _open_alignment_file(path, write, reference_filename, template=None):
    """Open ``path`` with pysam, picking SAM/BAM/CRAM mode from its suffix."""
    mode = "w" if write else "r"
    kwargs = {}
    if template is not None:
        kwargs["template"] = template
    if path.endswith(".bam"):
        mode += "b"
    elif path.endswith(".cram"):
        mode += "c"
        # CRAM needs the reference to decode/encode sequences.
        kwargs["reference_filename"] = reference_filename
    return pysam.AlignmentFile(path, mode, **kwargs)


def _realign_read(read, fafile, pad):
    """Realign one mapped read in place against its padded reference window."""
    original_cigar = read.cigarstring
    spliced = "N" in original_cigar
    align_start = max(read.reference_start - pad, 0)
    align_end = read.reference_end + pad

    if spliced:
        # For spliced reads, take the sequence from the read itself so that
        # only the flanking pads are fetched from the fasta file.
        ref = "".join(
            (
                fafile.fetch(
                    read.reference_name, align_start, read.reference_start
                ),
                read.get_reference_sequence(),
                fafile.fetch(
                    read.reference_name, read.reference_end, align_end
                ),
            )
        )
    else:
        ref = fafile.fetch(read.reference_name, align_start, align_end)

    s, c = parasail_alignment(read.query_sequence, ref)
    read.reference_start = align_start + s

    if spliced:
        split_points = get_splicing_site_from_cigar(original_cigar)
        c2 = degenerate_cigar(split_cigar(c, split_points))
    else:
        # use the old cigar spec (M instead of =/X)
        c2 = degenerate_cigar(c)

    if original_cigar != c2:
        # keep the original cigar and MD tag
        read.set_tag("OC", original_cigar)
        read.set_tag("OM", read.get_tag("MD"))
        # update cigar string and MD tag
        read.cigarstring = c2
        read.set_tag("MD", generate_md_tag(ref[s:], c))


@click.command(
    context_settings=dict(help_option_names=["-h", "--help"]),
    no_args_is_help=True,
)
@click.option(
    "-r",
    "--input-fa",
    "input_fa_name",
    help="Path of reference fasta file.",
    required=True,
)
@click.option(
    "-i",
    "--input-sam",
    "input_sam_name",
    help="Path of input SAM/BAM file.",
    required=True,
)
@click.option(
    "-o",
    "--output-sam",
    "output_sam_name",
    help="Path of output SAM/BAM file.",
    required=True,
)
@click.option(
    "--debug/--no-debug",
    default=False,
    help="Write error into file for debuging.",
)
def run_realign(input_fa_name, input_sam_name, output_sam_name, debug):
    """Realign reads against the reference and write the updated records.

    \f
    A read is realigned when its cigar contains ``D`` or ``S`` or its ``MD``
    tag is not purely numeric, provided its reference is present in the
    fasta file. When the cigar changes, the previous cigar and MD tag are
    kept in the ``OC`` and ``OM`` tags.

    Records whose cigar-implied query length differs from the stored query
    length are not written to the output; with ``debug`` they go to
    ``<output stem>_debug.sam`` instead. Reads without a cigar (unmapped
    reads) cannot be realigned and are skipped.

    All files are closed on exit, including on error. The input is read
    sequentially (``until_eof=True``), so BAM/CRAM input needs no index.

    :param input_fa_name: path of the reference fasta file.
    :param input_sam_name: path of the input SAM/BAM/CRAM file.
    :param output_sam_name: path of the output SAM/BAM/CRAM file.
    :param debug: whether to write rejected records to the debug file.

    :raises KeyError: if the MD tag of a read has to be consulted but the
        read does not carry one.
    """
    from contextlib import ExitStack

    pad = 20
    with ExitStack() as stack:
        fafile = stack.enter_context(pysam.FastaFile(input_fa_name))
        samfile = stack.enter_context(
            _open_alignment_file(input_sam_name, False, input_fa_name)
        )
        outfile = stack.enter_context(
            _open_alignment_file(
                output_sam_name, True, input_fa_name, template=samfile
            )
        )
        debug_output = None
        if debug:
            debug_output = stack.enter_context(
                pysam.AlignmentFile(
                    output_sam_name.rsplit(".", 1)[0] + "_debug.sam",
                    "w",
                    template=samfile,
                )
            )

        # Build the lookup once instead of scanning the names for each read.
        known_contigs = set(fafile.references)

        for read in samfile.fetch(until_eof=True):
            cigar = read.cigarstring
            if read.is_unmapped or cigar is None:
                # Nothing to realign, and the substring tests below would
                # fail on a missing cigar.
                continue
            needs_realign = (
                "D" in cigar
                or "S" in cigar
                or not read.get_tag("MD").isnumeric()
            )
            if needs_realign and read.reference_name in known_contigs:
                _realign_read(read, fafile, pad)
            if read.infer_query_length() != read.query_length:
                if debug_output is not None:
                    debug_output.write(read)
            else:
                outfile.write(read)


# Run the click command only when the file is executed as a script.
if __name__ == "__main__":
    run_realign()
