#!/usr/bin/env python
from signal import signal, SIGPIPE, SIG_DFL
from sys import stdin, stderr, argv
from collections import defaultdict
import argparse

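# Emulates `join`, but supports multiple files: lines are matched on the entry
# columns, and the selected columns of every file are printed side by side.
# Example: myJoin *.region.bedgraph -e 1,2,3 -c 4 -f "NA"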

# Restore the default SIGPIPE action so that writing into a closed pipe
# (e.g. `myJoin ... | head`) terminates the script quietly instead of
# raising an exception.
signal(SIGPIPE, SIG_DFL)

# parse args
# NOTE: -e/-c/-f deliberately have no argparse default; the fallbacks
# (entry column 1, value column 2, empty fill) are applied after parsing so
# that eval mode (-E) never evaluates a default string.
parser = argparse.ArgumentParser(
	description="Join selected columns of multiple files on shared entry columns (like join, for many files).")
parser.add_argument('inFiles', nargs='*',
					help="input files (wildcards are expanded by the shell); each entry should be unique within a file, otherwise later lines overwrite earlier ones.")
parser.add_argument("-e", "--entrycolumn",
					help="comma-separated 1-based columns (ranges like 1-3 allowed) used as the join entry; default is 1. Use -e 1,2,3 for bed files.")
parser.add_argument("-c", "--column",
					help="comma-separated 1-based columns (ranges like 4-6 allowed) whose values are reported for each file; default is 2.")
parser.add_argument("-f", "--fill",
					help="value used for fields of entries missing from a file; default is an empty string.")
parser.add_argument("-E", "--eval", action='store_true',
					help="evaluate -e and -c as Python expressions, eg. myJoin -E -e '[1,2,3]' -c 'list(range(11,20))+[50]' *.bedgraph")
args = parser.parse_args()

# parse range args
def parseRange(s, eval_flag, is_int=True):
	"""Parse a column specification string into a list.

	Examples (eval_flag false):
		"1,2,3"                  -> [1, 2, 3]
		"1-3,18,20-22"           -> [1, 2, 3, 18, 20, 21, 22]
		"a,b" with is_int=False  -> ["a", "b"]

	Args:
		s: the specification string taken from the command line.
		eval_flag: if true, ``s`` is evaluated as a Python expression and the
			result is returned unchanged, e.g. 'list(range(11,20))+[50]'.
		is_int: if true, parse comma-separated integers and inclusive
			``a-b`` ranges; otherwise return the raw comma-separated strings.

	Returns:
		A list of ints (or of strings when ``is_int`` is false). In eval mode,
		whatever the expression evaluates to.

	Raises:
		ValueError: if a non-empty part is neither an integer nor an ``a-b``
			range, or if a range is reversed (start greater than end).
	"""
	if eval_flag:
		# SECURITY: eval() executes arbitrary code. This is acceptable only
		# because the string comes from the local user's own command line;
		# never route untrusted input here.
		return eval(s)
	if not is_int:
		return s.split(",")
	ls = []
	for part in s.split(","):
		part = part.strip()
		if not part:
			# tolerate stray or trailing commas such as "1,2,"
			continue
		if "-" in part:
			a, b = part.split("-")
			start, end = int(a), int(b)
			if start > end:
				# previously this silently produced no columns at all
				raise ValueError("invalid range %r: start is greater than end" % part)
			ls.extend(range(start, end + 1))  # inclusive of both ends
		else:
			ls.append(int(part))
	return ls

# check args: resolve column specs, falling back to defaults when not given
def _asColumnList(value, option):
	"""Return the column spec ``value`` as a list of 1-based ints.

	Eval mode (-E) may yield an int, tuple or range instead of a list, which
	would break the later list concatenation and iteration, so it is
	normalized here. Any column that is not a positive int aborts via
	parser.error, because 0 or a negative number would silently select a
	column counted from the end of the line (``col-1`` becomes negative).
	"""
	if isinstance(value, int):
		value = [value]
	value = list(value)
	for col in value:
		if not isinstance(col, int) or col < 1:
			parser.error("%s: columns must be positive integers (1-based), got %r" % (option, col))
	return value

if args.column:
	cols = _asColumnList(parseRange(args.column, eval_flag=args.eval), "-c/--column")
else:
	cols = [2]
if args.entrycolumn:
	ecols = _asColumnList(parseRange(args.entrycolumn, eval_flag=args.eval), "-e/--entrycolumn")
else:
	ecols = [1]
fill = args.fill if args.fill else ""
inFiles = args.inFiles

# data: entry tuple (values of the entry columns) -> {input file: [selected column values]}
data = defaultdict(dict)

# read input -> data
for inFile in inFiles:
	with open(inFile) as f:
		# iterate lazily instead of readlines() so large files are not held in memory
		for lineNo, line in enumerate(f, 1):
			fields = line.split()  # split once; any whitespace separates columns
			if not fields:
				# skip blank lines (e.g. a trailing empty line) instead of crashing
				continue
			try:
				entry = tuple(fields[ecol-1] for ecol in ecols)
				items = [fields[col-1] for col in cols]
			except IndexError:
				raise SystemExit("%s:%d: expected at least %d columns, found %d"
								 % (inFile, lineNo, max(list(ecols) + list(cols)), len(fields)))
			# if an entry is not unique within a file, the last occurrence wins
			data[entry][inFile] = items

# print out values
# Header: the entry column numbers, then "<file>.<column>" for every selected
# column of every input file. list() guards against ecols being a range or
# tuple (possible in -E eval mode), which cannot be concatenated with a list.
headerList = list(ecols) + [name + "." + str(col) for name in inFiles for col in cols]
header = "\t".join(map(str, headerList))
print(header)

# placeholder values for an entry that does not occur in a given file
missing = [fill] * len(cols)
# Entries are tuples of strings, so they sort lexicographically
# (e.g. "chr10" before "chr2", "100" before "20").
for entry in sorted(data):
	valuesByFile = data[entry]
	toPrintList = list(entry)
	for inFile in inFiles:
		toPrintList += valuesByFile.get(inFile, missing)
	print("\t".join(map(str, toPrintList)))
