#!/usr/bin/env python

import sys
import os
import argparse
import textwrap
from Bio.SeqIO.QualityIO import FastqGeneralIterator
import gzip
import subprocess


# building the command-line interface for the script
parser = argparse.ArgumentParser(description = 'This script is for parsing a single fastq file by pulling out sequences with the desired headers (paired-end not supported yet).\
                                              For version info, run `bit-version`.')

# separate help group so the mandatory flags are listed under their own heading
required = parser.add_argument_group('required arguments')

# the two mandatory inputs: the fastq to filter and the file listing the target headers
required.add_argument("-i", "--input-fastq", metavar = "<FILE>", help = "Starting fastq file", action = "store", required = True)
required.add_argument("-w", "--sequence-headers", metavar = "<FILE>", help = "Single-column file with target sequence headers (if the headers in the fastq file have whitespace in them, it is okay to provide just the part up to the whitespace in this input file)", action = "store", required = True)
# optional settings: output name, inverted selection, and gzipped input/output
parser.add_argument("-o", "--output-fastq", metavar = "<FILE>", help='Output fastq file name (default: "wanted.fq", ".gz" will be added if compressed)', action = "store", default = "wanted.fq")
parser.add_argument("--inverse", help = "Add this flag to pull out all sequences with headers NOT in the provided header file.", action = "store_true")
parser.add_argument("--gz", help = "Add this flag if the input is gzipped (output will be too)", action = "store_true")

# printing the help menu (to stderr) and exiting with status 0 if no arguments were provided
if len(sys.argv)==1:
    parser.print_help(sys.stderr)
    sys.exit(0)

# parsed at module level, so `args` is a global that main() further down reads directly
args = parser.parse_args()

################################################################################

def main():
    """ run the header-based fastq filtering described by the command-line args

    Checks that the inputs exist and that no output would be overwritten, loads
    the target headers, writes the filtered fastq, and compresses it with pigz
    when --gz was given. Exits with status 1 (via the check helpers or a failed
    compression) rather than raising.
    """

    check_all_inputs_exist([args.input_fastq, args.sequence_headers])

    check_if_output_already_exists(args.output_fastq)

    # with --gz the final product is "<output>.gz", and pigz does not replace an
    # existing file unless forced, so catching that before any work is done
    if args.gz:

        check_if_output_already_exists(args.output_fastq + ".gz")

    # reading headers into a set (closing the file when done, and skipping blank
    # lines so an empty string never ends up as a "wanted" header)
    with open(args.sequence_headers, "r") as headers_file:

        stripped_lines = (line.strip() for line in headers_file)
        headers_of_int_set = set(header for header in stripped_lines if header)

    parse_fastq(headers_of_int_set, args.input_fastq, args.output_fastq, args.inverse, args.gz)

    if args.gz:

        # reporting a missing or failing pigz instead of silently leaving
        # uncompressed output behind
        try:

            subprocess.run(["pigz", args.output_fastq], check = True)

        except (OSError, subprocess.CalledProcessError) as err:

            print("")
            wprint(color_text("Compressing '" + str(args.output_fastq) + "' with pigz failed (" + str(err) + ").", "yellow"))
            print("")
            wprint("The uncompressed output should still be at '" + str(args.output_fastq) + "'.")
            print("\nExiting for now.\n")
            sys.exit(1)

################################################################################


# terminal color templates; the "%s" in each one is the slot for the text to colorize
tty_colors = dict(
    green='\033[0;32m%s\033[0m',
    yellow='\033[0;33m%s\033[0m',
    red='\033[0;31m%s\033[0m',
)


### functions ###
def color_text(text, color='green'):
    """ wrap text in the color codes for `color`, but only when stdout is a terminal

    When stdout is not a tty (e.g. piped or redirected), the text is returned
    unchanged so no escape codes end up in files.
    """

    if not sys.stdout.isatty():
        return text

    template = tty_colors[color]
    return template % text


def wprint(text):
    """ print text wrapped to 80 columns, with every line indented two spaces """

    # hyphenated words (like file names) are kept whole rather than split at the hyphen
    wrapper = textwrap.TextWrapper(
        width=80,
        initial_indent="  ",
        subsequent_indent="  ",
        break_on_hyphens=False,
    )

    print(wrapper.fill(text))


def check_all_inputs_exist(input_list):
    """ exit with status 1, after a message, at the first path in input_list that doesn't exist """

    for path in input_list:

        # nothing to report for paths that are present
        if os.path.exists(path):
            continue

        print("")
        not_found_message = "It seems the specified input file '" + str(path) + "' can't be found."
        wprint(color_text(not_found_message, "yellow"))
        print("\nExiting for now.\n")
        sys.exit(1)


def check_if_output_already_exists(planned_output):
    """ exit with status 1, after a message, if planned_output is already present on disk """

    # all clear when nothing is in the way
    if not os.path.exists(planned_output):
        return

    # otherwise refusing to continue so nothing gets overwritten
    print("")
    already_there_message = "It seems the expected output (or intermediate) file '" + str(planned_output) + "' already exists."
    wprint(color_text(already_there_message, "yellow"))
    print("")
    wprint("We don't want to overwrite something accidentally, so rename or remove that first if wanting to proceed.")
    print("\nExiting for now.\n")
    sys.exit(1)


def parse_fastq(headers_of_int_set, input_fastq, output_fastq, inverse, gz):
    """ write the reads selected by header from input_fastq to output_fastq

    Parameters:
        headers_of_int_set: set of target headers (without the leading "@");
            each may be the full header line or just the part before the
            first whitespace
        input_fastq: path to the fastq file to read
        output_fastq: path of the (uncompressed) fastq file to write
        inverse: if true, write the reads whose headers are NOT targets
        gz: if true, input_fastq is read as gzip-compressed text

    Records are written as 4 lines, with a bare "+" separator line.
    """

    # gzip.open in "rt" mode hands back decoded text just like the plain open()
    opener = gzip.open if gz else open

    # both handles are closed even if parsing or writing fails partway through
    with opener(input_fastq, "rt") as fastq_in, open(output_fastq, "w") as output_file:

        for header, seq, qual in FastqGeneralIterator(fastq_in):

            # the part of the header up to the first whitespace (space or tab);
            # falls back to the header itself when there is nothing to split
            header_fields = header.split(None, 1)
            short_header = header_fields[0] if header_fields else header

            is_target = header in headers_of_int_set or short_header in headers_of_int_set

            # using the `inverse` argument (not the global args) to decide:
            # targets are kept normally, everything else is kept when inverse
            if is_target != bool(inverse):

                output_file.write("@%s\n%s\n+\n%s\n" % (header, seq, qual))


# calling main() only when this file is run as a script
if __name__ == "__main__":
    main()
