import argparse
import torch
import pytorch_lightning as pl

from src.config import read_config
from src.pipeline import run_train_pipeline
from src.pipeline import run_train_inference_pipeline

def run(args: argparse.Namespace) -> None:
    """Run the pipeline selected by ``args.task_type``.

    Args:
        args: Parsed CLI arguments. Must provide ``task_type`` (one of
            ``"train"`` or ``"train_inference"``) and ``task_config_path``
            (forwarded to ``read_config``; the resulting config must contain
            a ``"seed"`` entry).

    Raises:
        ValueError: If ``args.task_type`` is not a supported task type.
    """
    pipelines = {
        "train": run_train_pipeline,
        "train_inference": run_train_inference_pipeline,
    }

    # Validate the task type before loading the config or mutating global
    # state (RNG seeds, matmul precision), so an unsupported value fails fast
    # with a clear ValueError instead of doing work first.
    pipeline = pipelines.get(args.task_type)
    if pipeline is None:
        supported = ", ".join(sorted(pipelines))
        raise ValueError(
            f"Task type {args.task_type!r} is not supported. "
            f"Expected one of: {supported}."
        )

    task_config = read_config(args.task_config_path)
    # Seed the RNGs from the config so runs are reproducible.
    pl.seed_everything(task_config["seed"])

    # Allow faster reduced-internal-precision float32 matmuls (e.g. TF32)
    # where the hardware supports them.
    torch.set_float32_matmul_precision("high")

    pipeline(task_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a training or train+inference pipeline from a task config."
    )
    parser.add_argument(
        "-t",
        "--task_type",
        type=str,
        required=True,
        # Keep in sync with the task types dispatched by run(); argparse then
        # rejects typos up front with a usage message.
        choices=["train", "train_inference"],
        help="Which pipeline to run.",
    )
    parser.add_argument(
        "-c",
        "--task_config_path",
        type=str,
        required=True,
        help="Path to the task config file (must define 'seed').",
    )
    args = parser.parse_args()
    run(args)
