import os

import tifffile
import torch

import networks

## Build the model
gpu_ids = [0] if torch.cuda.is_available() else [-1]
# torch device strings are 'cuda' / 'cpu'; 'gpu' is not a valid torch device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): on CPU this passes gpu_ids=[-1]; some define_*/init_net helpers
# expect an empty list for CPU -- confirm networks.define_S accepts -1.
UNetlive = networks.define_S(1, 2, 64, 'unet_128', 'batch', False, 'normal', 0.02, gpu_ids)

## Load the pretrained weights
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine (and vice versa). Weights go into `.module` when the network is
# wrapped (e.g. by DataParallel), otherwise into the network itself.
state_dict = torch.load('pretrained_Unet_live.pth', map_location=device)
getattr(UNetlive, 'module', UNetlive).load_state_dict(state_dict)

## Test on a sample image
file = '20201216_Auto_Select_MultiRegioncs4_ROI2_initial.tif'
# Take index 1 along the first axis of the TIFF stack.
# NOTE(review): presumably the STED channel -- confirm against the acquisition format.
sted = tifffile.imread(file)[1]

# Normalize to [-1, 1]. For a constant image max - min is 0, and 0/0 would
# fill the input with NaN; fall back to a divisor of 1 so every pixel maps to -1.
sted = sted.astype('float32')
sted_min, sted_max = sted.min(), sted.max()
sted_range = float(sted_max - sted_min) or 1.0
sted = ((sted - sted_min) / sted_range) * 2 - 1

# Center-crop each dimension to the largest multiple of 128 (all dimensions
# must be multiples of 128). A [500 x 500] image becomes [58:442, 58:442],
# i.e. [384 x 384].
TILE = 128
height, width = sted.shape
crop_h, crop_w = (height // TILE) * TILE, (width // TILE) * TILE
if crop_h == 0 or crop_w == 0:
    raise ValueError(f'image of size {height}x{width} is smaller than {TILE} pixels in some dimension')
top, left = (height - crop_h) // 2, (width - crop_w) // 2
sted = sted[top:top + crop_h, left:left + crop_w]

# Convert to a [1, 1, H, W] float32 tensor.
sted = torch.tensor(sted).unsqueeze(0).unsqueeze(0)

# Predict. no_grad() skips building the autograd graph, which inference does not need.
# NOTE(review): no .eval() call is visible here, so unless define_S sets it,
# BatchNorm layers run in training mode and normalize with this image's own
# statistics. Confirm this is intended before adding UNetlive.eval().
with torch.no_grad():
    pred = UNetlive(sted)

# Save results: raw predictions and a thresholded binary mask. The mask rescales
# predictions as if they lie in [-1, 1] to [0, 255], keeps values above 60, and
# writes kept pixels as 255 (uint8).
pred = pred.cpu().detach().numpy()
stem = os.path.splitext(file)[0]
tifffile.imwrite(stem + '_seg.tif', pred)
tifffile.imwrite(stem + '_seg_thresh.tif', (((pred + 1.0) / 2.0 * 255.0) > 60).astype('uint8') * 255)
