#!/usr/bin/env python

import time, itertools, torch

import datetime, os, shutil, argparse, sys, pysam

from src import utils
    
def build_parser():
    """Build the DeepMod2 command-line parsers.

    Returns:
        tuple: ``(parser, detect_parser, merge_parser)``. The top-level parser
        is used for parsing; the two sub-parsers are returned so that callers
        can print subcommand-specific help or report subcommand usage errors.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--print_models", help='Print details of models available', default=False, action='store_true')
    main_subparsers = parser.add_subparsers(title="Options", dest="option")

    # Options shared by the "detect" and "merge" subcommands.
    parent_parser = argparse.ArgumentParser(add_help=False,)
    parent_parser.add_argument("--prefix", help='Prefix for the output files',type=str, default='output')
    parent_parser.add_argument("--output", help= 'Path to folder where intermediate and final files will be stored, default is current working directory', type=str)

    parent_parser.add_argument("--qscore_cutoff", help='Minimum cutoff for mean quality score of a read',type=float, default=0)
    parent_parser.add_argument("--length_cutoff", help='Minimum cutoff for read length',type=int, default=0)

    parent_parser.add_argument("--mod_t",  help=  'Probability threshold for a per-read prediction to be considered modified. Only predictiond with probability >= mod_t will be considered as modified for calculation of per-site modification levels.', default=0.5, type=float)

    parent_parser.add_argument("--unmod_t",  help=  'Probability threshold for a per-read prediction to be considered unmodified. Only predictiond with probability < unmod_t will be considered as unmodified for calculation of per-site modification levels.', default=0.5, type=float)

    parent_parser.add_argument("--include_non_cpg_ref",  help='Include non-CpG reference loci in per-site output where reads have CpG motif.',default=False, action='store_true')

    # ---- "detect" subcommand ----
    detect_parser = main_subparsers.add_parser("detect", parents=[parent_parser],
                                      add_help=True,
                                      help="Call methylation from Guppy or Dorado basecalled POD5/FAST5 files using move tables for signal alignment.",  formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Created before any detect-specific option so the help layout is unchanged.
    detect_required=detect_parser.add_argument_group("Required Arguments")

    detect_parser.add_argument("--threads", help='Number of threads to use for processing signal and running model inference. If a GPU is used for inference, then --threads number of threads will be running on GPU concurrently. The total number of threads used by DeepMod2 is equal to --threads plus --bam_threads. It is recommended to run DeepMod2 with mutliple cores, and use at least 4 bam_threads for compressing BAM file.',type=int, default=4)
    detect_parser.add_argument("--ref", help='Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used.', type=str)

    detect_parser.add_argument("--model", help='Name of the model. Default model is "guppy_R9.4.1" meant to be used for R9.4.1 flowcell reads, but you can provide a path to your own model. Use --print_models to display all models available. Use "guppy_R10.4.1" model for R10.4.1 flowcell reads.',type=str, required=True)

    detect_parser.add_argument("--supplementary", help='Analyse supplementary reads in addition to primary reads, recommended for RNA reads.', default=False, action='store_true')

    detect_required.add_argument("--bam", help='Path to aligned or unaligned BAM file. It is ideal to have move table in BAM file but move table from FAST5 fies can also be used. Aligned BAM file is required for reference anchored methylation calls, otherwise only the motif loci on reads will be called.', type=str, required=True)

    detect_required.add_argument("--file_type", help='Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file.',choices=['fast5','pod5'], type=str, default='fast5',required=True)

    detect_required.add_argument("--input", help='Path to POD5/FAST5 file or folder containing POD5/FAST5 files. If folder provided, then POD5/FAST5 files will be recusrviely searched', type=str, required=True)

    detect_parser.add_argument("--guppy_group", help='Name of the guppy basecall group',type=str, default='Basecall_1D_000')
    detect_parser.add_argument("--chrom", nargs='*',  help='A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If not list is provided then all chromosomes in the reference are used.')

    detect_parser.add_argument("--fast5_move", help='Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the FAST5 files contains move table.', default=False, action='store_true')

    detect_parser.add_argument("--skip_per_site", help='Skip per site output', default=False, action='store_true')
    detect_parser.add_argument("--device", help='Device to use for running pytorch models. you can set --device=cpu for cpu, or --device=cuda for GPU. You can also specify a particular GPU device such as --device=cuda:0 or --device=cuda:1 . If --device paramater is not set by user, then GPU will be used if available otherwise CPU will be used.', type=str)
    detect_parser.add_argument("--disable_pruning", help='Disable model pruning (not recommended for CPU inference). By default models are pruned to remove some weights with low L1 norm in linear layers. Pruning has little effect on model accuracy, it can signifcantly improve CPU inference time but not GPU inference time.', default=False, action='store_true')

    detect_parser.add_argument("--exclude_ref_features", help='Exclude reference sequence from feature matrix. By default, if a reference FASTA file is provided via --ref parameter, then the reference sequence is added as a feature for aligned reads, but not if a read is unmapped or if no reference is provided.', default=False, action='store_true')
    detect_parser.add_argument("--batch_size", help='Batch size to use for GPU inference. For CPU inference, batch size is fixed at 512.',type=int, default=1024)

    detect_parser.add_argument("--bam_threads", help='Number of threads to use for compressed BAM output. Setting it lower than 4 can significantly lower the runtime.',type=int, default=4)
    detect_parser.add_argument("--skip_unmapped", help='Skip unmapped reads from methylation calling', default=False, action='store_true')

    # ---- "merge" subcommand ----
    merge_parser = main_subparsers.add_parser("merge", parents=[parent_parser],
                                      add_help=True,
                                      help="Merge per-read calls into per-site calls",  formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    merge_parser.add_argument("--input",   nargs='*', help= 'List of paths of per-read methylation calls to merge. File paths should be separated by space/whitespace. Use either --input or --list argument, but not both.')
    # The help text previously pointed at a non-existent "--inputs" flag.
    merge_parser.add_argument("--list",  help=  'A file containing paths to per-read methylation calls to merge (one per line). Use either --input or --list argument, but not both.', type=str)

    return parser, detect_parser, merge_parser


def resolve_merge_inputs(args, merge_parser):
    """Return the per-read call files that the "merge" subcommand should combine.

    Paths given with ``--input`` take precedence; otherwise they are read, one
    per line, from the file given with ``--list`` (blank lines are skipped).

    Args:
        args: Parsed namespace providing the ``input`` and ``list`` attributes.
        merge_parser: Parser used to report a usage error.

    Returns:
        list: Non-empty list of file paths.

    Raises:
        SystemExit: Via ``merge_parser.error`` when no input path is available.
    """
    if args.input:
        return list(args.input)

    if args.list:
        with open(args.list, 'r') as file_list:
            stripped = (line.rstrip('\n') for line in file_list)
            # A blank line would otherwise be handed on as an empty file path.
            paths = [path for path in stripped if path.strip()]
        if paths:
            return paths
        merge_parser.error('no file paths found in --list file: %s' % args.list)

    # Without this check the merge step failed later with a NameError.
    merge_parser.error('no per-read files to merge: provide --input or --list')


def main(argv=None):
    """Run the DeepMod2 command-line interface.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: After printing help or the model list, and on usage errors.
    """
    cli_args = sys.argv[1:] if argv is None else list(argv)
    parser, detect_parser, merge_parser = build_parser()

    # With no arguments, or only a bare subcommand name, show the matching help.
    if not cli_args:
        parser.print_help()
        parser.exit()
    elif len(cli_args) == 1:
        if cli_args[0] == 'merge':
            merge_parser.print_help()
            merge_parser.exit()
        elif cli_args[0] == 'detect':
            detect_parser.print_help()
            detect_parser.exit()

    args = parser.parse_args(cli_args)

    if args.print_models:
        utils.get_model_help()
        parser.exit()

    # Defensive: the detect branch below reads detect-only attributes, so a
    # missing subcommand must not fall through to it.
    if args.option is None:
        parser.error('a subcommand is required: choose "detect" or "merge"')

    t = time.time()

    print('%s: Starting DeepMod2.' %str(datetime.datetime.now()), flush=True)

    if not args.output:
        args.output = os.getcwd()

    os.makedirs(args.output, exist_ok=True)

    if args.option == 'merge':
        input_list = resolve_merge_inputs(args, merge_parser)

        params = {'output':args.output, 'prefix':args.prefix, 'qscore_cutoff':args.qscore_cutoff,
                  'length_cutoff':args.length_cutoff, 'mod_t':args.mod_t,
                  'unmod_t':args.unmod_t,'include_non_cpg_ref':args.include_non_cpg_ref}

        utils.get_per_site(params, input_list)

    else:
        if args.chrom:
            chrom_list = args.chrom
        else:
            # Only the contig names from the header are needed; close the
            # BAM handle as soon as they have been read.
            with pysam.AlignmentFile(args.bam,'rb',check_sq=False) as bam_file:
                chrom_list = bam_file.references

        if args.device:
            dev = args.device
        elif torch.cuda.is_available():
            dev = "cuda"
        else:
            dev = "cpu"

        params = {'input':args.input, 'output':args.output, 'threads':args.threads,
                  'prefix':args.prefix, 'window':10, 'model':args.model,
                  'qscore_cutoff':args.qscore_cutoff, 'ref':args.ref,
                  'length_cutoff':args.length_cutoff, 'bam':args.bam,
                  'file_type':args.file_type, 'fast5_move':args.fast5_move,
                  'guppy_group':args.guppy_group, 'supplementary': args.supplementary,
                  'mod_t':args.mod_t, 'unmod_t':args.unmod_t, 'include_non_cpg_ref': args.include_non_cpg_ref,
                  'skip_per_site':args.skip_per_site, 'chrom_list':chrom_list, "dev":dev,
                  'disable_pruning':args.disable_pruning, 'batch_size':args.batch_size,
                  'exclude_ref_features':args.exclude_ref_features,'bam_threads':args.bam_threads,
                  'skip_unmapped':args.skip_unmapped}

        # Equals ' '.join(sys.argv) when the arguments come from the command line.
        command = ' '.join([sys.argv[0]] + cli_args)

        print('\n%s: \nCommand: python %s\n' %(str(datetime.datetime.now()), command), flush=True)

        # Record the command and every parsed option next to the results.
        with open(os.path.join(args.output,'args'),'w') as args_file:
            args_file.write('Command: python %s\n\n\n' %(command))
            args_file.write('------Parameters Used For Running DeepMod2------\n')
            for key, value in vars(args).items():
                args_file.write('{}: {}\n'.format(key, value))

        # Imported here so that it is only loaded when "detect" is run.
        from src import detect
        detect.call_manager(params)

    print('\n%s: Time elapsed=%.4fs' %(str(datetime.datetime.now()),time.time()-t), flush=True)


if __name__ == '__main__':
    main()
