#!/usr/bin/env python

from Bio.SeqIO.FastaIO import SimpleFastaParser
import sys
import argparse

def build_parser():
    """Build the command-line parser for the FASTA header filter.

    Returns:
        An ``argparse.ArgumentParser`` with the required ``-i``/``-w`` options
        and the optional ``-o`` and ``--inverse`` options.
    """
    parser = argparse.ArgumentParser(description='This script is for parsing a fasta file by pulling out sequences with the desired headers. If you want all sequences EXCEPT the ones with the headers you are providing, add the flag "--inverse".')

    # A dedicated group so the mandatory options are listed separately in --help.
    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-fasta", help="Original fasta file", action="store", required=True)
    required.add_argument("-w", "--wanted-headers", help="Single-column file with sequence headers", action="store", required=True)
    parser.add_argument("-o", "--output-fasta", help='Output fasta file default: "Wanted.fa"', action="store", default="Wanted.fa")
    parser.add_argument("--inverse", help="Add this flag to pull out all sequences with headers NOT in the provided header file.", action="store_true")

    return parser


def read_wanted_headers(path):
    """Read the headers of interest from a single-column text file.

    Args:
        path: Path to a file holding one header per line.

    Returns:
        A set with every line of the file, stripped of surrounding whitespace.
        A set is used so that each membership test while filtering is O(1)
        instead of a linear scan over all headers.
    """
    # The context manager guarantees the handle is closed once it is read.
    with open(path) as headers_file:
        return {line.strip() for line in headers_file}


def select_records(records, wanted_headers, inverse=False):
    """Yield the records selected by the header collection.

    Args:
        records: Iterable of ``(header, seq)`` pairs.
        wanted_headers: Container of header strings; matching is exact.
        inverse: When false, yield records whose header is in
            ``wanted_headers``; when true, yield those whose header is not.

    Yields:
        The selected ``(header, seq)`` pairs, in their original order.
    """
    for header, seq in records:
        # Keep the record when "is wanted" disagrees with "inverse":
        # wanted & normal mode, or not wanted & inverse mode.
        if (header in wanted_headers) != inverse:
            yield header, seq


def main():
    """Parse the command line and write the selected sequences to the output."""
    parser = build_parser()

    # With no arguments at all, show the help text instead of an argparse error.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args()

    wanted_headers = read_wanted_headers(args.wanted_headers)

    # The output file is opened (and truncated) before the input file.
    with open(args.output_fasta, "w") as output_file, open(args.input_fasta, "r") as input_file:
        records = SimpleFastaParser(input_file)

        for header, seq in select_records(records, wanted_headers, args.inverse):
            output_file.write(">%s\n%s\n" % (header, seq))


if __name__ == "__main__":
    main()
