
import numpy as np
import os
import tifffile

import argparse

from skimage.transform import resize
from skimage import morphology

from time import perf_counter
from monai.networks.nets import SwinUNETR
import monai.transforms as T    
import torch
import yaml


def parse_input_shape(value):
    """Return the model's spatial input shape as a tuple of ints.

    ``INPUT_SHAPE`` in the config may be a string such as ``"(128,128,128)"``
    (surrounding whitespace, brackets and a trailing comma are tolerated) or
    an already-parsed sequence such as ``[128, 128, 128]``.

    Raises:
        ValueError: if an element cannot be converted to ``int``.
    """
    if isinstance(value, str):
        value = value.strip().strip('()[]').split(',')
    # Skip empty pieces so "(128,128,128,)" does not fail on int('').
    return tuple(int(dim) for dim in value if str(dim).strip())


def robust_normalize(volume):
    """Centre *volume* on its median and scale it by its interquartile range.

    Returns ``(volume - median) / IQR`` as a float array. If the IQR is zero
    (more than half of the central voxels share one value, e.g. a mostly
    empty background), the volume is only median-centred. This avoids
    producing inf/NaN network inputs.
    """
    q75, q25 = np.percentile(volume, [75, 25])
    iqr = q75 - q25
    centred = volume - np.median(volume)
    if iqr == 0:
        print('WARNING: interquartile range is 0; skipping IQR scaling')
        return centred
    return centred / iqr


def prediction_path(filename):
    """Return the output path ``PRED_<basename>`` in the input file's directory.

    Uses ``os.path`` so both Windows and POSIX separators are handled
    correctly on their respective platforms.
    """
    directory, base = os.path.split(filename)
    return os.path.join(directory, 'PRED_' + base)


def main():
    """Segment one TIFF volume with SnSNet and save the mask next to the input.

    Reads ``./SnSNet_config.yml`` and ``./SnSNet_best_loss_model.pth``
    relative to the current working directory. Writes ``PRED_<name>`` beside
    the input TIFF given on the command line.
    """
    parser = argparse.ArgumentParser(description="Input Path to dir")
    parser.add_argument('filename', help='input path')

    # Must be set before CUDA is initialised so device indices follow PCI order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('Inference on: ', device)

    args = parser.parse_args()
    filename = args.filename

    arr = tifffile.imread(filename)

    model_file = './SnSNet_best_loss_model.pth'
    # Binarises sigmoid probabilities at 0.6.
    # NOTE(review): depending on the installed MONAI version, `to_onehot_y`
    # may be ignored or rejected by AsDiscrete — confirm.
    post_pred = T.AsDiscrete(to_onehot_y=True, threshold=0.6)
    post_sigmoid = T.Activations(sigmoid=True)

    with open("./SnSNet_config.yml", 'r') as f:
        model_config = yaml.load(f, yaml.FullLoader)

    model_config['INPUT_SHAPE'] = parse_input_shape(model_config['INPUT_SHAPE'])
    model = SwinUNETR(in_channels=model_config['IN_CHANNELS'],
                      img_size=tuple(model_config['INPUT_SHAPE']),
                      out_channels=model_config['OUT_CHANNELS'],
                      feature_size=model_config['NUM_FILTERS'],
                      norm_name=model_config['NORMALIZATION_LAYER'],
                      use_checkpoint=True,
                      normalize=True).to(device)

    # map_location lets weights saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()
    orig_shape = arr.shape

    print("ROBUST NORMALIZATION AT LOAD")
    image = robust_normalize(arr)

    image = resize(image, (128, 128, 128), preserve_range=True)
    # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
    image = np.expand_dims(image, axis=[0, 1])

    start = perf_counter()
    totensor = T.ToTensor()
    # Normalisation/resizing can yield float64; the model weights are float32,
    # and the tensor must live on the selected device (not unconditionally CUDA).
    image = totensor(image).to(device=device, dtype=torch.float32)
    with torch.no_grad():  # inference only: do not build an autograd graph
        pred = model(image)
        pred = post_pred(post_sigmoid(pred))
    # Batch item 0, output channel 1 — presumably the foreground class; confirm.
    pred_1 = pred[0][1].cpu().detach().numpy()
    end = perf_counter()

    pred_1 = morphology.dilation(pred_1, morphology.ball(3))
    print('Mask Prediction #1: ', end - start)
    # NOTE(review): only the output is transposed (axes 0 and 1 swapped), not
    # the input; this presumably matches the orientation used in training —
    # confirm, since it matters whenever orig_shape[0] != orig_shape[1].
    pred_1 = np.transpose(pred_1, [1, 0, 2])
    # NOTE(review): linear interpolation followed by astype(int) truncates
    # fractional boundary values to 0; order=0 may be intended — confirm.
    pred_1 = resize(pred_1, orig_shape, anti_aliasing=False, preserve_range=True)
    tifffile.imwrite(prediction_path(filename), pred_1.astype(int))


if __name__ == '__main__':
    main()
