from Bio import AlignIO 
import numpy as np
import argparse, os, sys
from datetime import datetime

# Command-line interface. Every option is mandatory and, because no type= is
# given, each value arrives as a string.
parser = argparse.ArgumentParser(
    description='remove gappy columns and sequences from fasta alignments')
_options = (
    ('-i', 'in directory'),
    ('-c', 'minimum column occupancy'),
    ('-l', 'minimum length after filtering'),
    ('-g', 'minimum percentage occupancy per individual'),
    ('-m', 'minimum sequences left necessary to write'),
    ('-o', 'out directory'),
)
for _flag, _text in _options:
    parser.add_argument(_flag, help=_text, required=True)
args = parser.parse_args()

# File extensions (the text after the last '.') that are treated as FASTA input.
fastaend = 'fasta fas fa'.split()
# Characters counted as missing data: alignment gaps and ambiguous bases (N/n).
gaps = list('-nN')

# Open the run log inside the output directory and record the start time plus
# every command-line argument. The output directory is created first so that
# neither this log nor the trimmed alignments fail with FileNotFoundError.
# The handle stays open for the rest of the script and is closed at the end.
os.makedirs(args.o, exist_ok=True)
log = open(os.path.join(args.o, 'alignmentfilter.log'), 'w')
log.write("Start time: " + datetime.now().strftime('%d-%m-%Y %H:%M') + "\n")
for arg, value in sorted(vars(args).items()):
    # str() keeps this line working even if an option is given a non-string type.
    log.write("Argument -" + arg + ": " + str(value) + "\n")

# Trim every FASTA alignment found in the input directory (-i).
#
# For each alignment:
#   1. drop sites (columns) whose missing-data percentage exceeds 100 - (-c);
#   2. discard the alignment if fewer than -l sites remain;
#   3. drop individuals (rows) whose missing-data percentage exceeds 100 - (-g);
#   4. write <stem>.trim.fasta into -o only if at least -m individuals remain.
# Characters listed in `gaps` count as missing data. Every decision is logged.
#
# Thresholds are parsed once here rather than on every column/row.
min_col_occ = float(args.c)
min_len = int(args.l)
min_row_occ = float(args.g)
min_seqs = int(args.m)

# sorted() makes the processing (and therefore log) order deterministic.
for file in sorted(os.listdir(args.i)):
    # stem = everything before the last '.', ext = everything after it
    # (ext is the whole name when there is no '.').
    stem, _, ext = file.rpartition('.')
    inpath = os.path.join(args.i, file)
    if ext not in fastaend or not os.path.isfile(inpath):
        continue

    try:
        fin = AlignIO.read(inpath, "fasta")
    except ValueError as err:
        # Bio.AlignIO.read raises ValueError for unusable input (e.g. an empty
        # file or sequences of unequal length); log it and keep processing the
        # remaining files instead of aborting the whole batch.
        log.write("alignment {} could not be read: {}\n".format(inpath, err))
        continue

    seqs = [record.id for record in fin]
    # Rows are individuals, columns are alignment sites; one character per cell.
    aln = np.array([list(str(record.seq)) for record in fin], dtype='U1')
    alnold = aln.shape[1]
    missing = np.isin(aln, gaps)

    # --- Column filter -------------------------------------------------------
    col_missing_pct = missing.sum(axis=0) / aln.shape[0] * 100
    drop_cols = col_missing_pct > 100 - min_col_occ
    n_sites_removed = int(drop_cols.sum())
    aln = aln[:, ~drop_cols]
    missing = missing[:, ~drop_cols]
    n_sites = aln.shape[1]
    sites_summary = "{}/{}".format(n_sites_removed, alnold)

    if n_sites < min_len:
        log.write("alignment {} removed as sequence length ({}) less than "
                  "minimum ({})\n".format(inpath, n_sites, args.l))
        continue

    # --- Row filter ----------------------------------------------------------
    # Percentages are taken over the surviving sites; a row with no sites left
    # (only possible when -l is 0) counts as entirely missing instead of
    # raising ZeroDivisionError.
    if n_sites:
        row_missing_pct = missing.sum(axis=1) / n_sites * 100
    else:
        row_missing_pct = np.full(len(seqs), 100.0)
    drop_rows = row_missing_pct > 100 - min_row_occ
    remseqs = [sid for sid, drop in zip(seqs, drop_rows) if drop]
    finalseqs = [sid for sid, drop in zip(seqs, drop_rows) if not drop]
    inds_summary = "{}/{}".format(len(remseqs), len(seqs))

    # -m is enforced whether or not any individual was removed.
    if len(finalseqs) < min_seqs:
        log.write("alignment {} removed as too few individuals after filtering: "
                  "{} ({}) prior to this: {} sites were removed\n".format(
                      inpath, inds_summary, ', '.join(remseqs), sites_summary))
        continue

    # The output file is created only once the alignment is known to pass,
    # and the with-block guarantees it is flushed and closed.
    outpath = os.path.join(args.o, stem + '.trim.fasta')
    with open(outpath, 'w') as alignout:
        for sid, row in zip(finalseqs, aln[~drop_rows]):
            alignout.write('>' + sid + '\n' + ''.join(row) + '\n')
    removed_names = " ({})".format(', '.join(remseqs)) if remseqs else ""
    log.write("alignment {}: {} sites removed: {} individuals removed{}: "
              "written to {}\n".format(inpath, sites_summary, inds_summary,
                                       removed_names, outpath))


# Record when processing finished, then release the log file handle.
finished = datetime.now().strftime('%d-%m-%Y %H:%M')
log.write("End time: " + finished + "\n")
log.close()
