#!/usr/bin/env python

import argparse
import gzip

def _build_parser():
    """Return the command-line parser for the table-filtering script."""
    parser = argparse.ArgumentParser(
        description = 'Ad hoc script for filtering a table based on values in a specified column. For version info, run `bit-version`.'
    )

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-table", metavar = "<FILE>", help = 'Input table',
                          action = "store", dest = "in_tab", required = True)
    required.add_argument("-w", "--wanted-values", metavar = "<FILE>", help = 'Wanted values',
                          action = "store", dest = "wanted", required = True)

    parser.add_argument("-o", "--output-file", metavar = "<FILE>",
                        help='Output table filename (default: "Output.tsv")',
                        action = "store", dest = "out_tab", default = "Output.tsv")
    parser.add_argument("-d", "--delimiter", metavar = "<STR>", help = 'Delimiter (default: "\\t")',
                        action = "store", default = "\t")
    parser.add_argument("-c", "--column", metavar = "<INT>",
                        help = 'Index of column to filter on (default: 1)',
                        action = "store", default = 1, type = int)
    parser.add_argument("--no-header", help='Add if there is no header', action = "store_true")
    parser.add_argument("--gz", help = 'Add if the input is gzipped (output will not be)',
                        action = "store_true")

    return parser


def _filter_rows(rows, sink, targets, column_index, delimiter, has_header):
    """Copy the rows whose selected column holds a wanted value from rows to sink.

    rows         -- iterable of text lines (e.g. an open file handle)
    sink         -- writable text handle receiving the kept lines unchanged
    targets      -- set of wanted values (already whitespace-stripped)
    column_index -- 0-based index of the column to test
    delimiter    -- field separator passed to str.split
    has_header   -- when True, the first line is copied through unconditionally
    """
    rows = iter(rows)

    if has_header:
        header = next(rows, None)
        if header is None:
            # Empty input: nothing to copy, not even a header.
            return
        sink.write(header)

    for row in rows:
        # Only drop the line terminator. Stripping all surrounding whitespace
        # would also eat leading delimiters and shift every column to the left
        # whenever the first field(s) of a row are empty.
        fields = row.rstrip("\r\n").split(delimiter)

        # Blank or short rows have no such column, so they cannot match.
        if column_index >= len(fields):
            continue

        # The wanted values are stripped when read, so compare like with like.
        if fields[column_index].strip() in targets:
            sink.write(row)


def main(argv=None):
    """Filter a delimited table down to rows whose chosen column is in a wanted list.

    argv -- list of command-line arguments; None means use sys.argv[1:].

    Reads one wanted value per line from --wanted-values, then writes to
    --output-file the header (unless --no-header is given) followed by every
    row of --input-table whose --column field is one of the wanted values.
    Exits with a usage error if --column is smaller than 1.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.column < 1:
        # A value of 0 or below would silently index from the end of the row.
        parser.error("--column must be a 1-based index (1 or greater)")

    with open(args.wanted) as wanted_handle:
        targets = {line.strip() for line in wanted_handle}

    opener = gzip.open if args.gz else open

    # The input is opened first so that a missing/unreadable input does not
    # leave behind a freshly truncated output file.
    with opener(args.in_tab, "rt") as table, open(args.out_tab, "w") as output:
        _filter_rows(
            table,
            output,
            targets,
            column_index=args.column - 1,
            delimiter=args.delimiter,
            has_header=not args.no_header,
        )


if __name__ == "__main__":
    main()
