import os
import argparse

def get_cmd(task, sub_task, model_tag, gpu, data_num, bs, lr, source_length, target_length, patience, epoch, warmup,
            model_dir, summary_dir, res_fn, load_model_dir, tag_suffix):
    """Build the shell command line that invokes ``exp_with_args.sh``.

    The arguments are rendered positionally, in signature order and separated
    by single spaces: the first three with ``%s``, the next nine (``gpu``
    through ``warmup``) with ``%d``, and the last five with ``%s``.

    Returns:
        The complete command as a single string.
    """
    # One conversion specifier per positional argument passed to the script.
    specifiers = ['%s'] * 3 + ['%d'] * 9 + ['%s'] * 5
    template = ' '.join(['bash exp_with_args.sh'] + specifiers)

    leading_text = (task, sub_task, model_tag)
    numeric = (gpu, data_num, bs, lr, source_length, target_length, patience, epoch, warmup)
    trailing_text = (model_dir, summary_dir, res_fn, load_model_dir, tag_suffix)
    return template % (leading_text + numeric + trailing_text)


def get_args_by_task_model(task, sub_task, model_tag):
    """Return the fine-tuning hyper-parameters for ``task``.

    Args:
        task: Task name; ``'bigfixes_task'`` or ``'bigfixes_base'``.
        sub_task: Not used by the current settings; kept for the call interface.
        model_tag: Not used by the current settings; kept for the call interface.

    Returns:
        A tuple ``(bs, lr, src_len, trg_len, patience, epoch)``.

    Raises:
        ValueError: If ``task`` has no settings defined. Previously this
            surfaced as an ``UnboundLocalError`` at the ``return`` statement.
    """
    # Per-task settings: (src_len, trg_len, epoch, patience).
    task_settings = {
        'bigfixes_task': (512, 512, 30, 30),
        'bigfixes_base': (512, 512, 30, 30),
    }
    if task not in task_settings:
        raise ValueError('Unknown task %r; expected one of %s' % (task, sorted(task_settings)))
    src_len, trg_len, epoch, patience = task_settings[task]

    bs = 4
    # NOTE(review): lr is passed on as the integer 5; presumably the shell
    # script scales it (e.g. to 5e-5) -- confirm against exp_with_args.sh.
    lr = 5

    return bs, lr, src_len, trg_len, patience, epoch


def run_one_exp(args):
    """Build and run the shell command for a single experiment.

    Hyper-parameters come from ``get_args_by_task_model``; the remaining
    values are read from ``args`` (``task``, ``sub_task``, ``model_tag``,
    ``gpu``, ``data_num``, ``model_dir``, ``summary_dir``, ``res_dir``,
    ``load_model_dir``, ``tag_suffix``). The warmup value is fixed at 1000.
    The command is printed and then executed with ``os.system``.

    Args:
        args: Parsed command-line namespace carrying the attributes above.
    """
    bs, lr, src_len, trg_len, patience, epoch = get_args_by_task_model(args.task, args.sub_task, args.model_tag)
    print('============================Start Running==========================')
    cmd_str = get_cmd(task=args.task, sub_task=args.sub_task, model_tag=args.model_tag, gpu=args.gpu,
                      data_num=args.data_num, bs=bs, lr=lr, source_length=src_len, target_length=trg_len,
                      patience=patience, epoch=epoch, warmup=1000,
                      model_dir=args.model_dir, summary_dir=args.summary_dir,
                      res_fn='{}/{}_{}.txt'.format(args.res_dir, args.task, args.model_tag),
                      load_model_dir=args.load_model_dir, tag_suffix=args.tag_suffix)
    print('%s\n' % cmd_str)
    # Surface a failed run instead of silently discarding the status
    # that os.system reports for the command.
    status = os.system(cmd_str)
    if status != 0:
        print('Command returned non-zero status %d: %s' % (status, cmd_str))


def get_sub_tasks(task):
    """Return the list of valid sub-task names for ``task``.

    Args:
        task: Task name; ``'bigfixes_task'`` or ``'bigfixes_base'``.

    Returns:
        A new list of sub-task names (currently ``['none']`` for both tasks).

    Raises:
        ValueError: If ``task`` is not recognized. Previously this surfaced
            as an ``UnboundLocalError`` at the ``return`` statement.
    """
    if task in ['bigfixes_task', "bigfixes_base"]:
        return ['none']
    raise ValueError('Unknown task %r; no sub-tasks are defined for it' % (task,))


if __name__ == '__main__':
    # Command-line entry point: parse the options, validate them, make sure
    # the results directory exists, then launch one experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_tag", type=str, default='codet5_small')
    parser.add_argument("--task", type=str, default='bigfixes_task',
                        choices=['bigfixes_task', "bigfixes_base"])
    parser.add_argument("--sub_task", type=str, default='none')
    parser.add_argument("--res_dir", type=str, default='results', help='directory to save fine-tuning results')
    parser.add_argument("--model_dir", type=str, default='saved_models', help='directory to save fine-tuned models')
    parser.add_argument("--summary_dir", type=str, default='tensorboard', help='directory to save tensorboard summary')
    parser.add_argument("--data_num", type=int, default=-1, help='number of data instances to use, -1 for full data')
    parser.add_argument("--gpu", type=int, default=0, help='index of the gpu to use in a cluster')
    parser.add_argument("--load_model_dir", default='None', type=str, help="Path to trained model: Should contain the .bin files")
    parser.add_argument("--tag_suffix", default='finetune', type=str,
                        help="Experiment full model tag suffix")

    args = parser.parse_args()

    # Validate with parser.error rather than assert: assert statements are
    # removed under `python -O`, and parser.error prints a usage message.
    # Done before any directory is created so a bad invocation has no side effects.
    valid_sub_tasks = get_sub_tasks(args.task)
    if args.sub_task not in valid_sub_tasks:
        parser.error('invalid --sub_task %r for task %r (choose from %s)'
                     % (args.sub_task, args.task, ', '.join(valid_sub_tasks)))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.res_dir, exist_ok=True)

    run_one_exp(args)
