#!/usr/bin/env python
"""
meta-sweeper - for performing parametric sweeps of simulated
metagenomic sequencing experiments.
Copyright (C) 2016 "Matthew Z DeMaere"

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
import logging
import numba as nb
import numpy as np
import os
import scipy.stats as st

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from collections import OrderedDict
from typing import Dict, Union, List, Optional, TextIO, Iterator, Any

from .exceptions import Sim3CException
from .random import uniform, randint
from .sequence import Sequence

# suppress type safety warning of cast, which is ok.
# TODO adjust types to eliminate cast.
from numba.core.errors import NumbaTypeSafetyWarning
import warnings
warnings.simplefilter('ignore', category=NumbaTypeSafetyWarning)

logger = logging.getLogger(__name__)
logging.getLogger('numba').setLevel(logging.ERROR)

"""
The following module was transcribed and adapted from the original project's C++ code:

ART -- Artificial Read Transcription, Illumina Q version
Authors: Weichun Huang 2008-2016
License: GPL v3
"""


class ArtException(Sim3CException):
    """Root of the exception hierarchy for errors raised by this module."""

    def __init__(self, message: str):
        # defer entirely to the project-wide base exception
        super().__init__(message)


class IllegalSymbolException(ArtException):
    """Raised when a sequence contains a symbol (e.g. an ambiguous IUPAC code) that is not supported."""

    def __init__(self, symb: Union[bytes, str]):
        # accept raw bytes as well as text, reporting the symbol as text either way
        shown = symb.decode() if isinstance(symb, bytes) else symb
        message = ('Encountered symbol [{}] in sequence. '
                   'Ambiguous IUPAC symbols are not supported.').format(shown)
        super().__init__(message)


def _clear_list(l: List):
    """Empty the list *l* in place, so that every reference to it observes the change."""
    l.clear()


# Directory containing this (Art.py) source file. The profile paths in ILLUMINA_PROFILES are
# relative to it, so the Illumina_profiles directory is expected to be co-located with this module.
MODULE_PATH = os.path.dirname(os.path.abspath(__file__))


# A catalog of empirical profiles for Illumina machine types.
# Maps a profile name to a pair of profile files (first read / R1, second read / R2), each given
# relative to MODULE_PATH. get_profile() resolves these names to absolute paths.
ILLUMINA_PROFILES = OrderedDict({
    'DRR': ('Illumina_profiles/DRRR1.txt',
            'Illumina_profiles/DRRR2.txt'),
    'Emp100': ('Illumina_profiles/Emp100R1.txt',
               'Illumina_profiles/Emp100R2.txt'),
    'Emp36': ('Illumina_profiles/Emp36R1.txt',
              'Illumina_profiles/Emp36R2.txt'),
    'Emp44': ('Illumina_profiles/Emp44R1.txt',
              'Illumina_profiles/Emp44R2.txt'),
    'Emp50': ('Illumina_profiles/Emp50R1.txt',
              'Illumina_profiles/Emp50R2.txt'),
    'Emp75': ('Illumina_profiles/Emp75R1.txt',
              'Illumina_profiles/Emp75R2.txt'),
    'EmpMiSeq250': ('Illumina_profiles/EmpMiSeq250R1.txt',
                    'Illumina_profiles/EmpMiSeq250R2.txt'),
    'EmpR36': ('Illumina_profiles/EmpR36R1.txt',
               'Illumina_profiles/EmpR36R2.txt'),
    'EmpR44': ('Illumina_profiles/EmpR44R1.txt',
               'Illumina_profiles/EmpR44R2.txt'),
    'EmpR50': ('Illumina_profiles/EmpR50R1.txt',
               'Illumina_profiles/EmpR50R2.txt'),
    'EmpR75': ('Illumina_profiles/EmpR75R1.txt',
               'Illumina_profiles/EmpR75R2.txt'),
    'HiSeq2500L125': ('Illumina_profiles/HiSeq2500L125R1.txt',
                      'Illumina_profiles/HiSeq2500L125R2.txt'),
    'HiSeq2500L150': ('Illumina_profiles/HiSeq2500L150R1.txt',
                      'Illumina_profiles/HiSeq2500L150R2.txt'),
    'HiSeq2500L150filt': ('Illumina_profiles/HiSeq2500L150R1filter.txt',
                          'Illumina_profiles/HiSeq2500L150R2filter.txt'),
    'HiSeq2kL100': ('Illumina_profiles/HiSeq2kL100R1.txt',
                    'Illumina_profiles/HiSeq2kL100R2.txt'),
    'HiSeqXPCRfreeL150': ('Illumina_profiles/HiSeqXPCRfreeL150R1.txt',
                          'Illumina_profiles/HiSeqXPCRfreeL150R2.txt'),
    'HiSeqXtruSeqL150': ('Illumina_profiles/HiSeqXtruSeqL150R1.txt',
                         'Illumina_profiles/HiSeqXtruSeqL150R2.txt'),
    'MiSeqv3L250': ('Illumina_profiles/MiSeqv3L250R1.txt',
                    'Illumina_profiles/MiSeqv3L250R2.txt'),
    'NextSeq500v2L75': ('Illumina_profiles/NextSeq500v2L75R1.txt',
                        'Illumina_profiles/NextSeq500v2L75R2.txt')})


def get_profile(name: str) -> Iterator[str]:
    """
    Return the absolute paths to the two files (first read, then second read) of a named Illumina profile.
    :param name: the name of the profile; one of the keys of ILLUMINA_PROFILES.
    :return: an iterator over the two absolute (full) paths, R1 then R2
    :raises ValueError: if name is not a known profile
    """
    if name not in ILLUMINA_PROFILES:
        # an explicit check rather than assert: asserts are stripped under ``python -O``
        raise ValueError('Unknown profile name. Try one of: {}'.format(', '.join(ILLUMINA_PROFILES)))

    return (os.path.join(MODULE_PATH, pi) for pi in ILLUMINA_PROFILES[name])


# Translation table mapping the IUPAC ambiguity codes M,R,W,S,Y,K,V,H,D,B to N,
# preserving case (lower-case codes become 'n', upper-case codes become 'N').
AMBIGUOUS_CONVERSION_TABLE = str.maketrans('mrwsykvhdb' + 'MRWSYKVHDB',
                                           'n' * 10 + 'N' * 10)


@nb.jit(nb.uint8[:](nb.int64[:, :, :], nb.int64[:]), nopython=True)
def _random_to_quality_numba(q3d: np.ndarray, rv_list: np.ndarray) -> np.ndarray:
    """
    Translate an array of random values to quality scores by inverting a set of empirical CDFs,
    one CDF per read position.
    :param q3d: the contiguous qCDF 3D array; q3d[i] is a table of (scaled cumulative count, quality)
        rows for position i. Column 0 is assumed sorted ascending, as np.searchsorted requires.
    :param rv_list: an array of random values, one per position; rv_list[i] is looked up in q3d[i]
    :return: quality score array of equal length to the array of random values

    NOTE(review): q3d must hold at least len(rv_list) positions. Nothing is checked here; within this
    module, EmpDist.get_read_qual validates the read length against the profile before calling in.
    """
    quals = np.zeros(shape=len(rv_list), dtype=np.uint8)
    for i in range(len(rv_list)):
        qcdf = q3d[i]
        # first row whose cumulative value is >= the draw (searchsorted default side='left')
        quals[i] = qcdf[np.searchsorted(qcdf[:, 0], rv_list[i]), 1]
    return quals


def to_numba_dict(d: Dict, key_type: Any, value_type: Any) -> Dict:
    """
    Copy a Python mapping into a Numba typed dictionary, so that it can be passed to nopython functions.
    :param d: source mapping
    :param key_type: Numba type of the keys
    :param value_type: Numba type of the values
    :return: a numba.typed.Dict holding the same items, inserted in the iteration order of d
    """
    typed_dict = nb.typed.Dict.empty(key_type=key_type, value_type=value_type)
    for key in d:
        typed_dict[key] = d[key]
    return typed_dict


class EmpDist(object):
    """
    Per-position empirical distributions of Illumina base quality scores, read from ART-format profiles.

    For each read position, a profile supplies a line of quality values followed by a line of their
    cumulative counts. The counts are rescaled to integers in [0, MAX_DIST_NUMBER], so that a quality can be
    drawn by mapping a random integer to the first value whose scaled count is greater than or equal to it.
    Only the combined ('.') model is supported; separate per-base qualities (sep_quals) are rejected.
    """

    FIRST = True
    SECOND = False
    HIGHEST_QUAL = 80
    MAX_DIST_NUMBER = 1.0e6
    CMB_SYMB = '.'
    A_SYMB = ord(b'A')
    C_SYMB = ord(b'C')
    G_SYMB = ord(b'G')
    T_SYMB = ord(b'T')
    N_SYMB = ord(b'N')
    PRIMARY_SYMB = {A_SYMB, C_SYMB, G_SYMB, T_SYMB}
    # NOTE: CMB_SYMB is a str while the other symbols are ASCII codes (int). Profile lines are parsed as
    # str, so only the combined '.' entries can match these keys; per-base (sep_quals) parsing would need
    # consistent key types before it could be supported.
    ALL_SYMB = PRIMARY_SYMB | {CMB_SYMB, N_SYMB}

    # lookup table of error probability indexed by quality score: 10^(-q/10) for q in [0, HIGHEST_QUAL)
    PROB_ERR = np.apply_along_axis(
        lambda xi: 10.0**(-xi*0.1), 0, np.arange(HIGHEST_QUAL))

    # a dictionary of all combinations of single-symbol
    # knockouts used in random substitutions
    KO_SUBS_LOOKUP = to_numba_dict(OrderedDict({ord(b'A'): np.array([ord(b'C'), ord(b'G'), ord(b'T')], dtype=np.uint8),
                                                ord(b'C'): np.array([ord(b'A'), ord(b'G'), ord(b'T')], dtype=np.uint8),
                                                ord(b'G'): np.array([ord(b'A'), ord(b'C'), ord(b'T')], dtype=np.uint8),
                                                ord(b'T'): np.array([ord(b'A'), ord(b'C'), ord(b'G')], dtype=np.uint8)}),
                                   nb.types.uint8,
                                   nb.types.uint8[:])

    # alternatively, the list of all symbols for uniform random selection
    ALL_SUBS = np.array(sorted(PRIMARY_SYMB))

    @staticmethod
    def create(name: str, sep_quals: bool = False) -> 'EmpDist':
        """
        Instantiate an EmpDist with the specified profile name.
        :param name: empirically derived machine profile
        :param sep_quals: independent quality model per base (A,C,G,T)
        :return: instance of EmpDist
        """
        profile_r1, profile_r2 = get_profile(name)
        return EmpDist(profile_r1, profile_r2, sep_quals)

    def __init__(self, fname_first: str, fname_second: str, sep_quals: bool = False):
        """
        :param fname_first: file name of first read profile
        :param fname_second: file name of second read profile
        :param sep_quals: independent quality model per base (A,C,G,T); not currently supported
        """
        self.sep_quals = sep_quals
        assert not sep_quals, 'Separate base qualities are not currently supported by this python implementation'
        self.qual_dist_first = {symb: [] for symb in EmpDist.ALL_SYMB}
        self.qual_dist_second = {symb: [] for symb in EmpDist.ALL_SYMB}
        self.dist_max = {EmpDist.FIRST: 0, EmpDist.SECOND: 0}
        # contiguous 3D embeddings of the combined distributions for numba calls; built by init_dist
        self.q3d_first = None
        self.q3d_second = None
        self.init_dist(fname_first, fname_second)

    @staticmethod
    def embed_ragged(qual_dist: List[np.ndarray]) -> np.ndarray:
        """
        Embed the ragged array of empirically derived Q CDFs into a 3D numpy array, where left over elements are
        initialized with large values. This datatype is much more efficient to pass to Numba JIT methods.
        :param qual_dist: the list of per-position (cdf, quality) tables for one symbol
        :return: a 3D numpy matrix of qcdfs for a given symbol
        :raises ValueError: if qual_dist is empty
        """
        if len(qual_dist) == 0:
            raise ValueError('No quality distributions were read from the profile')

        # largest counting value used in the set of CDFs
        qv_max = max(qi.shape[0] for qi in qual_dist)

        # a 3D matrix which can contain all the CDFs, though some can be shorter
        q3d = np.empty(shape=(len(qual_dist), qv_max, 2), dtype=np.int64)

        # unassigned elements will be one larger than largest defined, so a lookup never lands on them
        q3d.fill(int(EmpDist.MAX_DIST_NUMBER)+1)

        for i in range(len(qual_dist)):
            j, k = qual_dist[i].shape
            # initialise from the left
            q3d[i, :j, :k] = qual_dist[i]

        return q3d

    def init_dist(self, fname_first: str, fname_second: str):
        """
        (Re)initialise the per-position distributions of quality scores from a pair of profiles,
        replacing any previously loaded distributions, and rebuild the contiguous arrays used for sampling.
        :param fname_first: profile for first read
        :param fname_second: profile for second read
        """
        # explicit loops: map() is lazy in Python 3, so a bare map(_clear_list, ...) never runs and a
        # repeated call would append to the old distributions
        for dist_list in self.qual_dist_first.values():
            _clear_list(dist_list)
        for dist_list in self.qual_dist_second.values():
            _clear_list(dist_list)

        with open(fname_first, 'rt') as hndl:
            self.read_emp_dist(hndl, True)

        with open(fname_second, 'rt') as hndl:
            self.read_emp_dist(hndl, False)

        if not self.sep_quals:
            self.dist_max[EmpDist.FIRST] = len(self.qual_dist_first[self.CMB_SYMB])
            self.dist_max[EmpDist.SECOND] = len(self.qual_dist_second[self.CMB_SYMB])
        else:
            dist_len = np.array([len(self.qual_dist_first[k]) for k in self.PRIMARY_SYMB])
            assert np.all(dist_len == dist_len.max()), \
                'Invalid first profile, not all symbols represented over full range'
            self.dist_max[EmpDist.FIRST] = dist_len.max()

            dist_len = np.array([len(self.qual_dist_second[k]) for k in self.PRIMARY_SYMB])
            assert np.all(dist_len == dist_len.max()), \
                'Invalid second profile, not all symbols represented over full range'
            self.dist_max[EmpDist.SECOND] = dist_len.max()

        # create a contiguous datatype for numba calls
        self.q3d_first = EmpDist.embed_ragged(self.qual_dist_first[self.CMB_SYMB])
        self.q3d_second = EmpDist.embed_ragged(self.qual_dist_second[self.CMB_SYMB])

    def verify_length(self, length: int, is_first: bool):
        """
        Verify that profile and requested length are agreeable.
        :param length: read length
        :param is_first: first or second read
        :raises ValueError: if length exceeds the number of positions described by the profile
        """
        # explicit check rather than assert: the numba sampler indexes the profile by position without
        # bounds checking, so this guard must not disappear under ``python -O``
        if length > self.dist_max[is_first]:
            raise ValueError('Requested length exceeds that of profile')

    def get_read_qual(self, read_len: int, is_first: bool) -> np.ndarray:
        """
        Read qualities for a given read-length
        :param read_len: length of read to simulate
        :param is_first: first or second read
        :return: simulated qualities
        """
        self.verify_length(read_len, is_first)
        if is_first:
            return self._get_from_dist(self.q3d_first, read_len)
        else:
            return self._get_from_dist(self.q3d_second, read_len)

    @staticmethod
    def _get_from_dist(qual_dist_for_symb: np.ndarray, read_len: int) -> np.ndarray:
        """
        Generate simulated quality scores for a given length using an initialised
        distribution. Scores are related to the empirically determined CDFs specified
        at initialisation.
        :param qual_dist_for_symb: combined or separate symbols
        :param read_len: read length to simulate
        :return: simulated quality scores
        """
        # draw a set of random values equal to the length of a read
        rv_list = randint(1, int(EmpDist.MAX_DIST_NUMBER)+1, size=read_len, dtype=np.int64)

        # convert these random rolls to quality scores
        quals = _random_to_quality_numba(qual_dist_for_symb, rv_list)

        assert len(quals) > 0
        assert len(quals) == read_len

        return quals

    def read_emp_dist(self, hndl: TextIO, is_first: bool) -> bool:
        """
        Read an empirical distribution from an ART-format profile.

        Each read position is described by a pair of tab-delimited lines sharing a symbol and position:
        the first lists quality values, the second their cumulative counts. Blank lines and lines starting
        with '#' are skipped.

        :param hndl: open file handle
        :param is_first: first or second read profile
        :return: True -- profile was not empty
        :raises IOError: on malformed or truncated profile content
        """
        dists = self.qual_dist_first if is_first else self.qual_dist_second
        n = 0
        while True:
            raw = hndl.readline()
            if not raw:
                # end of file: readline() returns '' only at EOF, whereas a blank line is '\n'
                break
            line = raw.strip()
            if not line or line.startswith('#'):
                # skip empty and comment lines
                continue

            tok = line.split('\t')
            symb, read_pos, values = tok[0], int(tok[1]), np.array(tok[2:], dtype=int)

            # skip lines pertaining to unrequested mode
            if self.sep_quals:
                if symb == self.CMB_SYMB or symb == self.N_SYMB:
                    # requested separate quals but this pertains to combined or N
                    continue
            elif symb != self.CMB_SYMB:
                # requested combined but this pertains to separate
                continue

            if read_pos != n:
                # a new section may restart at position 0; any other jump is malformed
                if read_pos != 0:
                    raise IOError('Error: invalid format in profile at [{}]'.format(line))
                n = 0

            # the following line holds the cumulative counts for the values above
            line = hndl.readline().strip()
            tok = line.split('\t')
            if len(tok) < 2:
                # truncated file or missing count line
                raise IOError('Error: invalid format in profile at [{}]'.format(line))
            symb, read_pos, counts = tok[0], int(tok[1]), np.array(tok[2:], dtype=int)

            if read_pos != n:
                raise IOError('Error: invalid format in profile at [{}]'.format(line))

            if len(values) != len(counts) or len(counts) == 0 or counts[-1] <= 0:
                raise IOError('Error: invalid format in profile at [{}]'.format(line))

            # rescale the cumulative counts onto [0, MAX_DIST_NUMBER], rounding up. True division is
            # required here: with floor division the ceil() would be a no-op.
            cdf = np.ceil(counts * EmpDist.MAX_DIST_NUMBER / counts[-1]).astype(np.int64)
            dist = np.column_stack((cdf, values))

            n += 1
            try:
                dists[symb].append(dist)
            except KeyError:
                raise IOError('Error: unexpected base symbol [{}] linked to distribution'.format(symb))

        return n != 0


class SeqRead(object):
    """
    Container for a single simulated read.

    Holds the reference fragment and the read sequence, both as uint8 arrays of byte codes. It also holds
    the read's quality scores, origin position (bpos) and strand, and the insertion/deletion rate tables
    supplied at construction.
    """

    PHRED33_OFFSET = 33

    def __init__(self, read_len: int, ins_rate: List[float], del_rate: List[float],
                 max_num: int = 2, plus_strand: Optional[bool] = None):
        """
        :param read_len: requested read length; sizes the reference and read buffers
        :param ins_rate: insertion rate table, stored as-is
        :param del_rate: deletion rate table, stored as-is
        :param max_num: stored as-is; presumably the number of indel length classes -- confirm with callers
        :param plus_strand: True if the read is taken from the forward strand
        """
        self.read_len = read_len
        self.max_num = max_num
        self.is_plus_strand = plus_strand
        self.ins_rate = ins_rate
        self.del_rate = del_rate
        # zero-filled byte buffers for the reference fragment and the resulting read
        self.seq_ref = np.zeros(read_len, dtype=np.uint8)
        self.seq_read = np.zeros(read_len, dtype=np.uint8)
        # initially unset
        self.quals = None
        self.bpos = None
        self.indel = {}

    def __str__(self):
        head = self.seq_ref[0:10]
        ref_len = self.seq_ref.shape[0]
        return 'from {}...{}bp created {}'.format(head, ref_len, self.seq_read)

    def read_record(self, seq_id: str, desc: str = '') -> SeqRecord:
        """
        Build a Biopython SeqRecord for this read, suitable for writing to disk in the format
        generated by ART_illumina.
        :param seq_id: sequence id for read
        :param desc: sequence description
        :return: Bio.SeqRecord with the read's qualities attached as 'phred_quality'
        """
        record = SeqRecord(Seq(self._read_str()), id=seq_id, description=desc)
        # per-letter annotations are how a SeqRecord carries quality scores
        record.letter_annotations['phred_quality'] = self.quals
        return record

    def _read_desc(self) -> str:
        """
        Describe this read as its origin position followed by 'F' (plus strand) or 'R' (otherwise).
        :return: a string description
        """
        strand = 'F' if self.is_plus_strand else 'R'
        return '{}{}'.format(self.bpos, strand)

    @staticmethod
    def read_id(ref_id: str, n: int) -> str:
        """
        Form a read id from the mother sequence id and a read index, following ART_illumina practice.
        :param ref_id: mother sequence id
        :param n: an index for the read
        :return: the id "<ref_id>-<n>"
        """
        return '{}-{}'.format(ref_id, n)

    def _read_str(self) -> str:
        """
        Decode the read's byte-code array into a text string.
        :return: the read sequence as str
        """
        raw = self.seq_read.tobytes()
        return raw.decode()

    def length(self) -> int:
        """
        Return the length of the read sequence currently held.
        :return: len(seq_read)
        """
        return len(self.seq_read)


# marker stored in an indel map to denote a deleted position: the byte code of '-'
DELETION = b'-'[0]

# t_indel_dict = nb.typeof(nb.typed.Dict.empty(key_type=nb.types.uint, value_type=nb.types.uint))
# @nb.jit(nb.int64(t_indel_dict, nb.float64[:], nb.float64[:], nb.int64), nopython=True)
def get_indel(indels: Dict[int, int], del_rate: np.ndarray, ins_rate: np.ndarray, read_len: int) -> int:
    """
    Generate insertions and deletions for a read, deletions first.

    Each rate table is scanned from its last (longest) class down. The first class i whose rate is at least
    a fresh uniform draw is chosen: i+1 distinct positions are then marked in *indels*, and later classes
    are not considered.

    :param indels: mapping of read position -> DELETION marker or inserted base; populated in place
        (it is not cleared here)
    :param del_rate: deletion rate per length class (index i corresponds to i+1 deletions)
    :param ins_rate: insertion rate per length class (index i corresponds to i+1 insertions)
    :param read_len: read length
    :return: net change in length, i.e. insertion_length - deletion_length
    """
    ins_len = 0
    del_len = 0

    # deletion: one uniform variate per deletion-length class
    rv = uniform(size=len(del_rate))
    for i in range(len(del_rate)-1, -1, -1):
        # only positions 1..read_len-1 may be deleted; skip classes that cannot fit, which would
        # otherwise spin forever in the rejection loop below
        if read_len - 1 < i+1:
            continue
        if del_rate[i] >= rv[i]:
            del_len = i+1
            j = i
            while j >= 0:
                # position 0 is never deleted
                # NOTE(review): ART's original comment also excludes read_len-1, but positions here are drawn
                # from randint(0, read_len) (upper bound presumably exclusive) -- confirm the intended range
                pos = randint(0, read_len)
                if pos == 0:
                    continue
                if pos not in indels:
                    indels[pos] = DELETION
                    j -= 1
            break

    # insertion: sized by ins_rate, since rv is indexed by the insertion-length loop below
    rv = uniform(size=len(ins_rate))
    for i in range(len(ins_rate)-1, -1, -1):
        # ensure that enough unchanged positions remain for mutation
        if read_len - del_len - ins_len < i+1:
            continue
        if ins_rate[i] >= rv[i]:
            ins_len = i+1
            j = i
            while j >= 0:
                pos = randint(0, read_len)
                if pos not in indels:
                    indels[pos] = random_base()
                    j -= 1
            break

    return ins_len - del_len


def get_indel_2(indels: Dict[int, int], del_rate: np.ndarray, ins_rate: np.ndarray, read_len: int) -> int:
    """
    Second method for creating indels, called in some situations when the first method has returned
    an unusable result. Insertions are drawn first, then deletions; no deletions are drawn when no
    insertion was chosen.

    NOTE(review): once an insertion has been chosen, a deletion class longer than ins_len can still be
    selected (the del_len == ins_len test only stops the loop before the first success) -- confirm whether
    deletions should be capped at ins_len.

    :param indels: mapping of read position -> DELETION marker or inserted base; populated in place
        (it is not cleared here)
    :param del_rate: deletion rate per length class (index i corresponds to i+1 deletions)
    :param ins_rate: insertion rate per length class (index i corresponds to i+1 insertions)
    :param read_len: read length
    :return: net change in length, i.e. insertion_length - deletion_length
    """
    ins_len = 0
    del_len = 0

    # insertion: one uniform variate per insertion-length class
    rv = uniform(size=len(ins_rate))
    for i in range(len(ins_rate)-1, -1, -1):
        if ins_rate[i] >= rv[i]:
            ins_len = i+1
            j = i
            while j >= 0:
                pos = randint(0, read_len)
                if pos not in indels:
                    indels[pos] = random_base()
                    j -= 1
            break

    # deletion: sized by del_rate, since rv is indexed by the deletion-length loop below
    rv = uniform(size=len(del_rate))
    for i in range(len(del_rate)-1, -1, -1):
        if del_len == ins_len:
            break

        # ensure that enough unchanged positions remain for mutation
        if read_len - del_len - ins_len < i+1:
            continue

        if del_rate[i] >= rv[i]:
            del_len = i+1
            j = i
            while j >= 0:
                # position 0 is never deleted
                pos = randint(0, read_len)
                if pos == 0:
                    continue
                if pos not in indels:
                    indels[pos] = DELETION
                    j -= 1
            break

    return ins_len - del_len


@nb.jit(nopython=True)
def mutate_read(original: np.ndarray, indels: Dict[int, int], read_len: int) -> np.ndarray:
    """
    From the reference (mother) sequence, generate the read's sequence by applying indels.

    A template index i and an indel-coordinate counter k advance together.
    - A position k absent from indels copies the next template base.
    - A DELETION entry skips a template base.
    - Any other entry is an inserted base, emitted without consuming the template.
    Entries at consecutive k just past the end of the template walk are then appended.

    :param original: template sequence as byte codes (not modified)
    :param indels: mapping of position -> DELETION marker or inserted base, e.g. as filled by get_indel
    :param read_len: length of the returned array
    :return: a new array of length read_len holding the mutated read

    NOTE(review): the output is allocated with np.empty, so any slot that is not written stays uninitialised.
    This presumably relies on len(original) == read_len - (net change returned by get_indel) -- confirm at
    the call site. Writing more than read_len bases would go out of bounds, since nopython mode does not
    bounds-check by default. The trailing loop copies whatever indels[k] holds, which would include a
    DELETION marker -- confirm that cannot occur.
    """
    # walk the template, applying any indel keyed at the current coordinate k
    org_seq = original
    mut_seq = np.empty(read_len, org_seq.dtype)

    n = 0
    k = 0
    i = 0
    while i < len(org_seq):
        if k not in indels:
            mut_seq[n] = org_seq[i]
            n += 1
            i += 1
            k += 1
        elif indels[k] == DELETION:
            # deletion
            i += 1
            k += 1
        else:
            # insertion
            mut_seq[n] = indels[k]
            n += 1
            k += 1

    # append any indel entries keyed immediately after the template was exhausted
    while k in indels:
        mut_seq[n] = indels[k]
        n += 1
        k += 1

    return mut_seq


@nb.vectorize('boolean(float64)', nopython=True)
def uniform_chance(prob_q):
    """
    Element-wise Bernoulli trial: True with probability prob_q.

    Compiled by Numba as a ufunc with signature boolean(float64), so applied to a float64 array it returns
    a boolean array of the same shape.

    NOTE(review): np.random inside nopython code draws from Numba's own generator state, not NumPy's
    global state -- confirm how (or whether) it is seeded for reproducible runs.
    """
    return np.random.uniform(0, 1) < prob_q


# Numba type of a typed Dict mapping uint8 -> uint8[:] (the key/value types used for
# EmpDist.KO_SUBS_LOOKUP); needed to spell out parse_error's explicit signature.
t_dict = nb.typeof(nb.typed.Dict.empty(key_type=nb.types.uint8, value_type=nb.types.uint8[:]))


@nb.jit(nb.void(nb.uint8[:], nb.uint8[:], t_dict, nb.float64[:]), nopython=True)
def parse_error(qual: np.ndarray, seq: np.ndarray, subs_table: Dict[int, np.ndarray], err_prob: np.ndarray):
    """
    Apply simulated substitution errors to a sequence according to its quality scores.

    Positions holding 'N' (byte 78) have their quality set to 1 and are never substituted. Each other
    position i is substituted with probability err_prob[qual[i]], using a replacement drawn uniformly from
    the first three entries of subs_table[seq[i]].

    Beginning with the basic transcription from Art C/C++ code, this function was reimplemented with
    vectorised Numpy operations, and it is compiled by Numba in nopython mode.

    NOTE(review): the random draws here (np.random.randint, and uniform_chance) come from Numba's internal
    generator, which is separate from NumPy's global state. They therefore presumably do not follow the
    project's random state -- confirm whether reproducible runs need Numba's generator seeded explicitly.
    NOTE(review): only upper-case 'N' is masked; a lower-case base would be looked up in subs_table as-is.

    :param qual: quality scores, modified in place
    :param seq: DNA sequence as byte codes, modified in place
    :param subs_table: base byte code -> candidate replacement codes (presumably EmpDist.KO_SUBS_LOOKUP)
    :param err_prob: error probability indexed by quality score (presumably EmpDist.PROB_ERR)
    """

    # boolean mask of the positions holding 'N' (byte 78)
    n_ix = seq == 78

    # set the quality of these locations to 1
    qual[n_ix] = 1

    # select sites of mutation randomly depending on quality
    to_mutate = uniform_chance(err_prob[qual])

    # exclude those having N
    ix_mut = np.where((~n_ix) & to_mutate)[0]

    # random choice of substitution among the three alternatives
    for ix in np.nditer(ix_mut):
        seq[ix] = subs_table[seq[ix]][np.random.randint(0, 3)]

def random_base(excl: Optional[int] = None) -> int:
    """
    Draw one of the four primary bases (as a byte code) uniformly at random.
    :param excl: optional base byte code to exclude; when given (truthy), the draw is made from the
        three remaining bases
    :return: the byte code of the chosen base
    """
    if excl:
        # one of the three bases other than excl
        return EmpDist.KO_SUBS_LOOKUP[excl][randint(0, 3)]
    # any of A, C, G, T
    return EmpDist.ALL_SUBS[randint(0, 4)]


class Art(object):
    """
    Illumina read simulator, transcribed from ART's C++ implementation.

    Reads can be drawn either from a reference sequence supplied at construction
    (next_read_indel, next_read_indel_at) or from a template supplied on each call
    (next_pair_simple_seq, next_read_simple_seq, next_pair_indel_seq, next_read_indel_seq).
    """

    # Complement translation tables: a/c/g/t/u (either case) map to their uppercase complement,
    # while the remaining IUPAC codes map to 'n' (lowercase input) or 'N' (uppercase input).
    COMPLEMENT_TABLE_STR = str.maketrans('acgtumrwsykvhdbnACGTUMRWSYKVHDBN',
                                         'TGCAAnnnnnnnnnnnTGCAANNNNNNNNNNN')

    COMPLEMENT_TABLE_BYTES = bytes.maketrans(b'acgtumrwsykvhdbnACGTUMRWSYKVHDBN',
                                             b'TGCAAnnnnnnnnnnnTGCAANNNNNNNNNNN')

    def __init__(self, read_len: int, emp_dist: EmpDist, ins_prob: float, del_prob: float, max_num: int = 2,
                 ref_seq: Optional[Sequence] = None, default_qual: Optional[int] = 40):
        """
        :param read_len: nominal length of simulated reads
        :param emp_dist: empirical quality profile, checked with verify_length for both of its profiles
        :param ins_prob: per-base insertion probability
        :param del_prob: per-base deletion probability
        :param max_num: number of binomial tail rates computed per indel type (clamped to read_len)
        :param ref_seq: optional reference used by next_read_indel/next_read_indel_at. NOTE: its
                        quality array is overwritten in place with default_qual.
        :param default_qual: constant quality given to the reference and to error-free reads
        """
        # fail fast if the profiles cannot produce reads of this length
        emp_dist.verify_length(read_len, True)
        emp_dist.verify_length(read_len, False)
        self.emp_dist = emp_dist

        if ref_seq is not None:
            self.ref_seq = ref_seq
            self.ref_seq.qual.fill(default_qual)
            self.ref_seq_cmp = ~ref_seq
            # upper bound for the random start positions drawn by next_read_indel
            self.valid_region = len(ref_seq) - read_len
        else:
            logger.warning('No reference supplied, calls will have to supply a template')
            self.ref_seq = None
            self.ref_seq_cmp = None
            self.valid_region = None

        # read_len and max_num must be set before _make_rate, which reads (and may clamp) them
        self.read_len = read_len
        self.max_num = max_num
        self.ins_rate = self._make_rate(ins_prob)
        self.del_rate = self._make_rate(del_prob)
        self.default_qual = default_qual

        # reusable container for the indels of the current read; cleared before every use
        self.indels_buffer = nb.typed.Dict.empty(key_type=nb.types.uint, value_type=nb.types.uint)

    def _make_rate(self, prob: float) -> np.ndarray:
        """
        Create the binomial tail rates for one error type.

        Element i-1 is P(X > i) for X ~ Binomial(read_len, prob), for i = 1..max_num.
        As a side effect, max_num is clamped to read_len.

        :param prob: per-base probability of the error
        :return: float array of length max_num
        """
        if self.max_num > self.read_len:
            self.max_num = self.read_len
        counts = np.arange(1, self.max_num + 1)
        return 1 - st.binom.cdf(counts, self.read_len, prob)

    def _check_reference(self) -> None:
        """Raise ArtException when no reference sequence was supplied at construction."""
        if self.ref_seq is None:
            raise ArtException('No reference was supplied at construction, reads must be generated from a template')

    def next_pair_simple_seq(self, template: Sequence) -> Dict[str, Sequence]:
        """
        Get a fwd/rev pair of simple error-free reads for a template, where each read is sequenced off the ends.
        :param template: the target template to sequence fwd/rev
        :return: a dict {'fwd': Sequence, 'rev': Sequence}
        """
        return {'fwd': self.next_read_simple_seq(template, True),
                'rev': self.next_read_simple_seq(template, False)}

    def next_read_simple_seq(self, template: Sequence, plus_strand: bool) -> Sequence:
        """
        Generate a simple error-free read with constant quality values (default_qual).
        :param template: the target template to sequence
        :param plus_strand: True reads the template as given, False reads its reverse complement
        :return: the read, tagged with 'plus_strand'
        """
        # templates shorter than the requested length are sequenced over their total extent
        _len = min(self.read_len, len(template))

        read = template[:_len] if plus_strand else template.revcomp()[:_len]

        read.tags['plus_strand'] = plus_strand
        read.qual = np.full_like(read.seq, self.default_qual)
        return read

    def next_pair_indel_seq(self, template: Sequence) -> Dict[str, Sequence]:
        """
        Get a fwd/rev pair of simulated reads for a template, where each read is sequenced off the ends.
        :param template: the target template to sequence fwd/rev
        :return: a dict {'fwd': Sequence, 'rev': Sequence}
        """
        return {'fwd': self.next_read_indel_seq(template, True),
                'rev': self.next_read_indel_seq(template, False)}

    def next_read_indel_seq(self, template: Sequence, plus_strand: bool) -> Sequence:
        """
        Generate a read with simulated indels and quality-driven substitution errors off a template.
        :param template: the target template to sequence
        :param plus_strand: True reads the template as given, False reads its reverse complement. Also
                            passed to emp_dist.get_read_qual to choose the quality profile.
        :return: the simulated read
        """
        # templates shorter than the requested length are sequenced over their total extent
        _len = min(self.read_len, len(template))

        indels = self.indels_buffer
        indels.clear()

        slen = get_indel(indels, self.del_rate, self.ins_rate, _len)

        # the read consumes _len - slen template bases; redraw if that exceeds the template
        if _len - slen > len(template):
            indels.clear()
            slen = get_indel_2(indels, self.del_rate, self.ins_rate, _len)

        if plus_strand:
            # copy so that the in-place error simulation below cannot alter the template
            read = (template[:_len - slen]).copy()
        else:
            read = template.revcomp()[:_len - slen]

        if len(indels) > 0:
            read.seq = mutate_read(read.seq, indels, _len)

        # simulated quality scores from profiles
        read.qual = self.emp_dist.get_read_qual(_len, plus_strand)

        # the returned quality scores can spawn sequencing errors
        parse_error(read.qual, read.seq, EmpDist.KO_SUBS_LOOKUP, EmpDist.PROB_ERR)

        return read

    def next_read_indel_at(self, pos: int, plus_strand: bool) -> Sequence:
        """
        Create a read from the reference with an already determined position and direction.
        Quality scores are always drawn with get_read_qual(..., True), regardless of strand.
        :param pos: start position of the read on the (forward or complemented) reference
        :param plus_strand: True = forward, False = reverse
        :return: the simulated read, tagged with 'coord'
        :raises ArtException: if no reference was supplied at construction
        """
        self._check_reference()

        indels = self.indels_buffer
        indels.clear()

        slen = get_indel(indels, self.del_rate, self.ins_rate, self.read_len)

        # the read consumes read_len - slen reference bases; redraw if that runs off the end
        if pos + self.read_len - slen > len(self.ref_seq):
            indels.clear()
            slen = get_indel_2(indels, self.del_rate, self.ins_rate, self.read_len)

        if plus_strand:
            read = self.ref_seq[pos: pos + self.read_len - slen]
        else:
            read = self.ref_seq_cmp[pos: pos + self.read_len - slen]

        read.tags['coord'] = str(pos)
        # target length passed explicitly, consistent with next_read_indel_seq
        read.seq = mutate_read(read.seq, indels, self.read_len)

        # simulated quality scores from profiles
        read.qual = self.emp_dist.get_read_qual(len(read), True)
        # the returned quality scores can spawn sequencing errors
        parse_error(read.qual, read.seq, EmpDist.KO_SUBS_LOOKUP, EmpDist.PROB_ERR)

        return read

    def next_read_indel(self) -> Sequence:
        """
        Create the next simulated read and its accompanying quality scores. Position and direction
        are determined by uniform random selection.

        :return: the simulated read
        :raises ArtException: if no reference was supplied at construction
        """
        self._check_reference()

        # random position anywhere in valid range
        pos = randint(0, self.valid_region)

        # is it the forward strand?
        plus_strand = uniform() < 0.5

        return self.next_read_indel_at(pos, plus_strand)


# if __name__ == '__main__':
#
#     from Bio import SeqIO
#     import argparse
#     import math
#
#     parser = argparse.ArgumentParser(description='Generate Illumina reads')
#     parser.add_argument('-S', '--seed', type=int, default=None, help='Random seed')
#     parser.add_argument('--profile1', help='ART sequencer profile for R1', required=True)
#     parser.add_argument('--profile2', help='ART sequencer profile for R2', required=True)
#     parser.add_argument('-l', '--read-length', type=int, help='Read length', required=True)
#     parser.add_argument('-X', '--xfold', type=float, help='Depth of coverage')
#     parser.add_argument('-N', '--num-reads', type=int, help='Number of reads')
#     parser.add_argument('--ins-rate', type=float, default=0.00009, help='Insert rate')
#     parser.add_argument('--del-rate', type=float, default=0.00011, help='Deletion rate')
#     parser.add_argument('fasta', help='Reference fasta')
#     parser.add_argument('outbase', help='Output base name')
#     args = parser.parse_args()
#
#     if args.xfold and args.num_reads:
#         raise RuntimeError('xfold and num-reads are mutually exclusive options')
#
#     with open('{}.r1.fq'.format(args.outbase), 'wt', buffering=262144) as r1_h, \
#             open('{}.r2.fq'.format(args.outbase), 'wt', buffering=262144) as r2_h:
#
#         for input_record in SeqIO.parse(args.fasta, 'fasta'):
#
#             # ref to string
#             ref_seq = str(input_record.seq).encode()
#
#             # empirical distribution from files
#             emp_dist = EmpDist(args.profile1, args.profile2)
#
#             # init Art
#             art = Art(args.read_length, emp_dist, args.ins_rate, args.del_rate)
#
#             if args.xfold:
#                 num_seq = int(math.ceil(len(art.ref_seq) / args.read_length * args.xfold))
#             else:
#                 num_seq = args.num_reads
#
#             logger.info('Generating {} reads for {}'.format(num_seq, input_record.id))
#
#             print_rate = num_seq // 10
#
#             for n in range(num_seq):
#
#                 ins_len = None
#                 while True:
#                     ins_len = int(np.ceil(normal(500, 50)))
#                     if ins_len > 200:
#                         break
#
#                 # pick a random position on the chromosome, but we're lazy and don't
#                 # handle the edge case of crossing the origin
#                 pos = randint(0, len(ref_seq)-ins_len)
#
#                 # get next read and quals
#                 pair = art.next_pair_indel_seq(ref_seq[pos: pos + ins_len])
#
#                 # create file records
#                 SeqIO.write(pair['fwd'].read_record(SeqRead.read_id(input_record.id, n)), r1_h, 'fastq')
#                 SeqIO.write(pair['rev'].read_record(SeqRead.read_id(input_record.id, n)), r2_h, 'fastq')
#
#                 if ((n+1)*100) % print_rate == 0:
#                     logger.info('Wrote {} pairs'.format(n+1))
#             break
