#!/usr/bin/env python3

import argparse
import numpy
import random
import os
import sys
sys.path.append('tools')
from readfasta import read_record

# Command-line interface. The three positionals name the output files and set
# how many reads are simulated; the options tune read geometry and placement.
# Help-text fixes: '<output_dir>' was misspelled, the control set only ever
# produces FASTQ files, and the main run also writes a .peaks file.
parser = argparse.ArgumentParser(
	description='Synthetic data maker for testing rocketchip')
parser.add_argument('name', type=str, metavar='<name>',
	help='base name of files to create (*.fa *.fastq *.peaks)')
parser.add_argument('stacks', type=int, metavar='<stacks>',
	help='number of read stacks')
parser.add_argument('rps', type=int, metavar='<rps>',
	help='number of reads per stack')
parser.add_argument('-o', type=str, metavar='<output_dir>', required=False,
	default='.', help='Output directory')
parser.add_argument('--genome', type=str, metavar='<string>', required=False,
	help='Path to existing genome. A random genome is generated otherwise.')
parser.add_argument('--control', type=str, metavar='<string>', required=False,
	help='base name of control files to create (*.fastq)')
parser.add_argument('--padding', type=int, metavar='<int>', required=False,
	default=1000, help='spacing between read stacks [%(default)i]')
parser.add_argument('--stddev', type=float, metavar='<float>', required=False,
	default=0.1, help='scatter of reads [%(default).3f]')
parser.add_argument('--width', type=int, metavar='<int>', required=False,
	default=0, help='width of read stack [%(default)i]')
parser.add_argument('--length', type=int, metavar='<int>', required=False,
	default=100, help='length of reads [%(default)i]')
parser.add_argument('--paired', type=int, metavar='<int>', required=False,
	default=0, help='generate paired reads spanning <int> nt')
parser.add_argument('--seed', type=int, metavar='<int>', required=False,
	help='use random seed')
parser.add_argument('--flank', type=int, metavar='<int>', required=False,
	default=1000, help='flanking sequence [%(default)i]')
# Parsed options; every later section of the script reads from `arg`.
arg = parser.parse_args()

# Seed both generators so a run can be reproduced. Compare against None
# instead of relying on truthiness: `--seed 0` is a valid seed but is falsy,
# so the old `if arg.seed:` test silently ignored it.
if arg.seed is not None:
	random.seed(arg.seed)
	numpy.random.seed(arg.seed)
# Create the output directory on demand; the default '.' already exists.
if arg.o != '.':
	os.makedirs(arg.o, exist_ok=True)
    
###############################
## Create positions of reads ##
###############################

# `coor` walks along the genome: it starts after the leading flank, advances
# by one padding per stack, and after the trailing flank is added it holds the
# minimum genome length needed.
coor = arg.flank
locs = [] # centre of each read (reads are later cut as pos-half:pos+half)
peaks = [] # centre of each stack, written to the .peaks file

for i in range(arg.stacks):
	peaks.append(arg.flank + i * arg.padding)
	# Offsets are drawn in units of read length, then scaled to nucleotides.
	for r in numpy.random.normal(0, arg.stddev, arg.rps):
		pos = int(r*arg.length)
		if arg.width > 0:
			# random.randint() needs integer bounds. `/` produced floats,
			# which Python 3.12+ rejects with TypeError, so use `//`.
			pos = random.randint(pos - arg.width//2, pos + arg.width//2)
		locs.append(coor + pos)
	coor += arg.padding
coor += arg.flank

###########################
## Create PCR duplicates ##
###########################

# Fraction of the stacked reads that is re-emitted as exact copies.
duplicate_rate = 0.2
duplicated_reads = int(arg.stacks * arg.rps * duplicate_rate)

# Each copy is drawn from the list as it currently stands, so a read that was
# already duplicated can be picked again.
remaining = duplicated_reads
while remaining > 0:
	locs.append(random.choice(locs))
	remaining -= 1

###################
## Create genome ##
###################

genome = ''
if arg.genome:
	# Every record overwrites the previous one, so when the file holds several
	# records the sequence of the LAST one is used.
	for idn, seq in read_record(arg.genome): genome = seq
	if len(genome) < coor:
		# A string argument makes sys.exit() print to stderr and exit with
		# status 1; the bare sys.exit() reported this failure as success.
		sys.exit('Error: Genome size too small!')
else:
	# Still one random.choice() call per base (so a seeded run yields the same
	# sequence as before), but joined once instead of growing a string with +=.
	genome = ''.join(random.choice('ACGT') for _ in range(coor))
	line = 80 # FASTA line width
	with open(f'{arg.o}/{arg.name}.fa', 'w') as fp:
		fp.write(f'>{arg.name}\n')
		# Every line, including a short final one, ends with a newline.
		for start in range(0, coor, line):
			fp.write(genome[start:start+line])
			fp.write('\n')

##################
## Create fastq ##
##################

def write_fastq(fp, n, seq):
	"""Write one four-line FASTQ record for `seq` to the open handle `fp`.

	The read is named `read.<n>` and every base gets the quality character 'F'.
	"""
	quality = 'F' * len(seq)
	fp.write(f'@read.{n}\n{seq}\n+\n{quality}\n')

# Complement table for the IUPAC nucleotide codes, built once instead of on
# every call. Lowercase codes are included so soft-masked sequence is
# complemented too (and keeps its case) instead of being passed through as-is.
_COMPLEMENT = str.maketrans(
	'ACGTRYMKWSBDHVNacgtrymkwsbdhvn',
	'TGCAYRKMWSVHDBNtgcayrkmwsvhdbn')

def revcomp(seq):
	"""Return the reverse complement of the nucleotide string `seq`.

	Characters that are not in the complement table are left unchanged
	(they are still reversed along with the rest of the sequence).
	"""
	return seq.translate(_COMPLEMENT)[::-1]


half = arg.length // 2 # offset from a read's centre back to its first base
n = 0 # running read number used in the FASTQ header

# Each read starts `half` bases before its centre and spans exactly
# arg.length bases. Slicing to centre+half instead produced reads one base
# short whenever --length was odd. The `with` blocks make sure the FASTQ
# handles are flushed and closed, which the bare open() calls never did.
if arg.paired:
	with open(f'{arg.o}/{arg.name}_1.fastq', 'w') as fp1, \
			open(f'{arg.o}/{arg.name}_2.fastq', 'w') as fp2:
		for pos in locs:
			n += 1
			# The two mates are centred arg.paired//2 either side of the
			# fragment centre; mate 2 is reported on the opposite strand.
			p1 = pos - arg.paired // 2
			p2 = pos + arg.paired // 2
			s1 = p1 - half
			s2 = p2 - half
			seq1 = genome[s1:s1+arg.length]
			seq2 = revcomp(genome[s2:s2+arg.length])
			write_fastq(fp1, n, seq1)
			write_fastq(fp2, n, seq2)
else:
	with open(f'{arg.o}/{arg.name}.fastq', 'w') as fp:
		for pos in locs:
			n += 1
			start = pos - half
			seq = genome[start:start+arg.length]
			# Single-end reads get a random strand: one coin flip per read.
			if not random.getrandbits(1):
				seq = revcomp(seq)
			write_fastq(fp, n, seq)

####################
## Create control ##
####################

# Control reads are placed uniformly at random instead of in stacks. Their
# number is stacks * rps, i.e. the stacked reads without the PCR duplicates.
if arg.control:
	totrds = arg.stacks * arg.rps
	n = 0

	# As for the treatment reads: `with` closes the handles (the bare open()
	# calls never did), and each read spans exactly arg.length bases even
	# when --length is odd.
	if arg.paired:
		with open(f'{arg.o}/{arg.control}_1.fastq', 'w') as fc1, \
				open(f'{arg.o}/{arg.control}_2.fastq', 'w') as fc2:
			for i in range(totrds):
				# Stay one padding away from either end of the simulated span.
				pos = random.randint(arg.padding, coor-arg.padding)
				n += 1
				p1 = pos - arg.paired // 2
				p2 = pos + arg.paired // 2
				s1 = p1 - half
				s2 = p2 - half
				seq1 = genome[s1:s1+arg.length]
				seq2 = revcomp(genome[s2:s2+arg.length])
				write_fastq(fc1, n, seq1)
				write_fastq(fc2, n, seq2)
	else:
		with open(f'{arg.o}/{arg.control}.fastq', 'w') as fc:
			for i in range(totrds):
				pos = random.randint(arg.padding, coor-arg.padding)
				n += 1
				start = pos - half
				seq = genome[start:start+arg.length]
				# Random strand: one coin flip per read.
				if not random.getrandbits(1):
					seq = revcomp(seq)
				write_fastq(fc, n, seq)

#####################
# Create Peaks File #
#####################

# One peak coordinate per line, in the order the stacks were generated.
peak_path = f'{arg.o}/{arg.name}.peaks'
with open(peak_path, 'w') as fp:
	fp.write(''.join(f'{peak}\n' for peak in peaks))
