In [None]:
import xarray as xr, numpy as np, pandas as pd
import torch, pathlib, yaml

In [None]:
# Parse the run configuration with the safe YAML loader.
config_path = '../input.yml'
with open(config_path) as config_file:
    input_data = yaml.load(config_file, Loader=yaml.loader.SafeLoader)

In [None]:
# Study-area identifier; also used as a sub-folder name under the output directory.
study_area = input_data['study_area']

# Year ranges include the configured end year (hence the +1).
training_period = input_data['training_period']
testing_period = input_data['testing_period']
year_range_train = range(training_period['start'], training_period['end'] + 1)
year_range_test = range(testing_period['start'], testing_period['end'] + 1)

gpu_enabled = input_data['gpu_enabled']

# GAN hyperparameters, all read from one config section.
gan_hyperparameters = input_data['gan_hyperparameters']
num_epochs = gan_hyperparameters['num_epochs']
batch_size = gan_hyperparameters['batch_size']
num_warmup_batches = gan_hyperparameters['num_warmup_batches']  # batches trained on pixel loss only
nc = gan_hyperparameters['nc']      # default channel count of the networks
ngf = gan_hyperparameters['ngf']    # feature-map width in the generator
ndf = gan_hyperparameters['ndf']    # feature-map width in the discriminator
lrG = gan_hyperparameters['lrG']    # Adam learning rate, generator
lrD = gan_hyperparameters['lrD']    # Adam learning rate, discriminator
beta1 = gan_hyperparameters['beta1']
beta2 = gan_hyperparameters['beta2']

# <output>/<study_area>/gan_gp_physics_sr receives the downscaled NetCDF output ...
out_dir = pathlib.Path(input_data['data_directories']['output']).joinpath(study_area, 'gan_gp_physics_sr')
out_dir.mkdir(parents=True, exist_ok=True)

# ... and its 'model1' sub-folder receives the saved network states.
model_out_dir = out_dir.joinpath('model1')
model_out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Use the GPU only when it is both requested in the config and actually available.
# Logical `and` on a bool is used instead of bitwise `&`: with `&`, a truthy but
# non-bool config value (e.g. 2) would give `True & 2 == 0` and silently disable the GPU.
use_cuda = bool(gpu_enabled) and torch.cuda.is_available()

# GPU ordinal. Defaults to 5 (the previously hard-coded value); an optional
# 'gpu_index' entry in the config overrides it.
cuda_index = int(input_data.get('gpu_index', 5))
if use_cuda and cuda_index >= torch.cuda.device_count():
    # The requested ordinal does not exist on this machine; fall back to the first GPU
    # instead of failing later with an invalid-device error.
    print("cuda:%d is not available, falling back to cuda:0" % cuda_index)
    cuda_index = 0

device = torch.device(("cuda:%d" % cuda_index) if use_cuda else "cpu")

# Legacy typed-tensor constructor, kept because other cells refer to `Tensor`.
Tensor = torch.cuda.FloatTensor if use_cuda else torch.Tensor

In [None]:
# Array of extreme days (used as timestamps further down), stored under
# <output>/<study_area>/extreme_days.
extreme_days_in_dir = pathlib.Path(input_data['data_directories']['output']).joinpath(input_data['study_area'], 'extreme_days')
extreme_days = np.load(extreme_days_in_dir/'extreme_days_era.npy')

In [None]:
# All three inputs live in sub-folders of <output>/<study_area>.
output_root = pathlib.Path(input_data['data_directories']['output']).joinpath(input_data['study_area'])

# xarray reads lazily, so each `rainfall` array is loaded into memory explicitly
# *before* its dataset is closed; otherwise later cells would read from a closed file.

# First generator input channel: rainfall from the 'gp_sr' folder.
data_in_dir = output_root/'gp_sr'
with xr.open_dataset(data_in_dir/'era5_rainfall_sr.nc') as ds:
    era_rain_gp = ds.rainfall.load()

# Second generator input channel: orographic rainfall.
data_in_dir = output_root/'orographic_rainfall'
with xr.open_dataset(data_in_dir/'era5_orographic_rainfall.nc') as ds:
    era_rain_topo = ds.rainfall.load()

# Training target: ERA5-Land rainfall.
# NOTE: `data_in_dir` is left pointing at the 'input_data' folder; the export cell reuses it.
data_in_dir = output_root/'input_data'
with xr.open_dataset(data_in_dir/'era5_land_rainfall.nc') as ds:
    era_land_rain = ds.rainfall.load()

In [None]:
# Extreme days whose calendar year falls inside the training period.
train_year_mask = np.isin(pd.DatetimeIndex(extreme_days).year, year_range_train)
train_days = extreme_days[train_year_mask]

# Keep only those time steps in each field.
era_rain_gp_train = era_rain_gp.where(era_rain_gp.time.isin(train_days), drop=True)
era_rain_topo_train = era_rain_topo.where(era_rain_topo.time.isin(train_days), drop=True)
era_land_rain_train = era_land_rain.where(era_land_rain.time.isin(train_days), drop=True)

In [None]:
# Z-score each training field with its own NaN-aware global mean and standard deviation.
era_rain_gp_train_std, era_rain_topo_train_std, era_land_rain_train_std = (
    (field - np.nanmean(field)) / np.nanstd(field)
    for field in (era_rain_gp_train, era_rain_topo_train, era_land_rain_train)
)

In [None]:
# Drop the xarray wrapper and insert a channel axis at position 1.
era_rain_gp_train_std, era_rain_topo_train_std, era_land_rain_train_std = (
    field.data[:, np.newaxis, :, :]
    for field in (era_rain_gp_train_std, era_rain_topo_train_std, era_land_rain_train_std)
)

# Two-channel generator input: channel 0 = GP rainfall, channel 1 = orographic rainfall.
era_rain_train_std = np.concatenate([era_rain_gp_train_std, era_rain_topo_train_std], axis=1)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Paired input/target samples backed by two 4-D arrays.

    Args:
        x: Input array, indexed as ``x[index, :, :, :]``.
        y: Target array, indexed the same way.

    Each item is a dict ``{"x": ..., "y": ...}`` of ``torch.Tensor`` samples.
    The dataset length is the size of the first axis of ``x``.
    """
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return np.shape(self.x)[0]

    def __getitem__(self, index):
        sample = {}
        for key, array in (("x", self.x), ("y", self.y)):
            # torch.Tensor(...) converts the slice to the default float tensor type
            sample[key] = torch.Tensor(array[index, :, :, :])
        return sample

In [None]:
# Shuffled mini-batches of (two-channel input, ERA5-Land target) pairs.
train_dataset = Dataset(x=era_rain_train_std, y=era_land_rain_train_std)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
class DenseResidualBlock(torch.nn.Module):
    """Dense Residual Block, the building block of the Residual Dense Network.
    https://arxiv.org/abs/1802.08797

    Five 3x3 convolutions are chained; each one receives the block input
    concatenated with the outputs of all previous convolutions. The first four
    are followed by LeakyReLU(0.2), the last one is linear. The output of the
    last convolution is scaled and added back to the block input.

    Args:
        nf (int): Channel number of intermediate features.
            Default: ngf (size of feature map in generator)
        res_scale (float): Residual scale. Default: 0.2
    """
    def __init__(self, nf=ngf, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale

        n_stages = 5
        stages = []
        for depth in range(1, n_stages + 1):
            # Stage `depth` sees the input plus (depth - 1) earlier outputs, nf channels each
            layers = [torch.nn.Conv2d(depth * nf, nf, 3, 1, 1, bias=True)]
            if depth < n_stages:
                layers.append(torch.nn.LeakyReLU(0.2, inplace=True))
            stages.append(torch.nn.Sequential(*layers))
        self.blocks = torch.nn.ModuleList(stages)

    def forward(self, x):
        features = x
        for stage in self.blocks:
            out = stage(features)
            # Grow the feature stack along the channel axis
            features = torch.cat([features, out], 1)
        # Only the last stage's output enters the residual connection
        return out.mul(self.res_scale) + x

In [None]:
class ResidualInResidualDenseBlock(torch.nn.Module):
    """Residual in Residual Dense Block (RRDB), as used in the ESRGAN generator.
    https://arxiv.org/abs/1809.00219

    Three DenseResidualBlocks in sequence, wrapped in an outer scaled residual
    connection. The inner blocks are built with their own default residual scale.

    Args:
        nf (int): Channel number of intermediate features.
            Default: ngf (size of feature map in generator)
        res_scale (float): Scale of the outer residual connection. Default: 0.2
    """
    def __init__(self, nf=ngf, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        self.dense_blocks = torch.nn.Sequential(
            *[DenseResidualBlock(nf) for _ in range(3)]
        )

    def forward(self, x):
        residual = self.dense_blocks(x)
        return residual.mul(self.res_scale) + x

In [None]:
class Generator(torch.nn.Module):
    """Generator network built from Residual in Residual Dense Blocks.
    Used in ESRGAN. https://arxiv.org/abs/1809.00219

    All convolutions are 3x3 with stride 1 and padding 1, so the spatial size
    of the input is preserved (there is no upsampling stage).

    Args:
        num_in_ch (int): Channel number of inputs. Default: nc
        num_out_ch (int): Channel number of outputs. Default: nc
        nf (int): Channel number of intermediate features. Default: ngf
        num_res_blocks (int): Number of Residual in Residual Dense blocks. Default: 16
    """
    def __init__(self, num_in_ch=nc, num_out_ch=nc, nf=ngf, num_res_blocks=16):
        super().__init__()

        # Input convolution: lift the input channels to nf features
        self.conv1 = torch.nn.Conv2d(num_in_ch, nf, 3, 1, 1)

        # Trunk of RRDBs
        trunk = [ResidualInResidualDenseBlock(nf) for _ in range(num_res_blocks)]
        self.res_blocks = torch.nn.Sequential(*trunk)

        # Convolution applied to the trunk output
        self.conv2 = torch.nn.Conv2d(nf, nf, 3, 1, 1)

        # Output head: conv -> LeakyReLU -> conv down to the output channels
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(nf, nf, 3, 1, 1),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(nf, num_out_ch, 3, 1, 1),
        )

    def forward(self, x):
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Long skip connection around the whole trunk
        fused = torch.add(shallow, deep)
        return self.conv3(fused)

In [None]:
class Discriminator(torch.nn.Module):
    """VGG style discriminator, used for training ESRGAN.

    Takes no arguments: channel widths come from the module-level `nc` and `ndf`,
    and the first linear layer expects a (ndf*4) x 3 x 3 feature map. The size
    comments below assume a 10 x 10 input. The output is one raw logit per sample.
    """
    def __init__(self):
        super().__init__()

        def norm_stage(in_ch, out_ch, kernel, stride):
            # bias-free conv (padding 1) -> instance norm -> LeakyReLU
            return [
                torch.nn.Conv2d(in_ch, out_ch, kernel, stride, 1, bias=False),
                torch.nn.InstanceNorm2d(out_ch),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = [
            # (nc) x 10 x 10 -> (ndf) x 10 x 10
            torch.nn.Conv2d(nc, ndf, 3, 1, 1),
            torch.nn.LeakyReLU(0.2, inplace=True),
        ]
        # (ndf) x 10 x 10 -> (ndf) x 5 x 5
        layers += norm_stage(ndf, ndf, 4, 2)
        # (ndf) x 5 x 5 -> (ndf*2) x 5 x 5
        layers += norm_stage(ndf, ndf*2, 3, 1)
        # (ndf*2) x 5 x 5 -> (ndf*2) x 3 x 3
        layers += norm_stage(ndf*2, ndf*2, 3, 2)
        # (ndf*2) x 3 x 3 -> (ndf*4) x 3 x 3
        layers += norm_stage(ndf*2, ndf*4, 3, 1)
        # Classifier head: flatten -> 100 hidden units -> 1 logit
        layers += [
            torch.nn.Flatten(),
            torch.nn.Linear(ndf*4*3*3, 100),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(100, 1),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:
# Loss functions: adversarial (BCE on raw logits) and pixel-wise (L1).
criterion_GAN, criterion_pixel = (
    criterion.to(device)
    for criterion in (torch.nn.BCEWithLogitsLoss(), torch.nn.L1Loss())
)

# Per-batch loss histories, filled during training.
G_losses, D_losses = [], []

In [None]:
# Build both networks on the training device. The generator takes two input
# channels; its output channel count is left at the default.
netG, netD = Generator(num_in_ch=2).to(device), Discriminator().to(device)

In [None]:
# One Adam optimiser per network; the betas are shared, the learning rates are separate.
adam_betas = (beta1, beta2)
optimizerG = torch.optim.Adam(netG.parameters(), lr=lrG, betas=adam_betas)
optimizerD = torch.optim.Adam(netD.parameters(), lr=lrD, betas=adam_betas)

In [None]:
# Adversarial training. Runs only when the config asks for it.
# Per batch: (1) generator step; during the first `num_warmup_batches` batches this
# uses the pixel loss only and skips the discriminator; (2) discriminator step.
# Both adversarial losses are relativistic average GAN losses.
if input_data['train']:
    for epoch in range(1, num_epochs+1):
        for i, d in enumerate(train_dataloader):

            # 1-based count of batches processed so far across all epochs
            batches_done = (epoch-1) * len(train_dataloader) + i + 1

            # Move the batch straight to the training device as float32. The legacy
            # typed constructor (`Tensor`) would first allocate on the default GPU.
            lr = d["x"].to(device=device, dtype=torch.float32)
            hr = d["y"].to(device=device, dtype=torch.float32)
            b_size = hr.shape[0]

            # The real sample label is 1, and the generated sample label is 0.
            real_label = torch.ones((b_size, 1), dtype=torch.float32, device=device)
            fake_label = torch.zeros((b_size, 1), dtype=torch.float32, device=device)

            # ------------------
            #  Train Generator
            # ------------------
            # Generate a super-resolution image from low resolution input
            sr = netG(lr)

            # Measure pixel-wise loss against ground truth
            loss_pixel = criterion_pixel(sr, hr)

            optimizerG.zero_grad()
            if batches_done <= num_warmup_batches:
                # Warm-up (pixel-wise loss only)
                loss_pixel.backward()
                optimizerG.step()
                continue

            # Run Discriminator on real and fake samples. The real scores are only a
            # constant reference for the generator, so no graph is recorded for them.
            with torch.no_grad():
                real_out = netD(hr)
            fake_out = netD(sr)

            # Adversarial loss (relativistic average GAN)
            loss_GAN = criterion_GAN(fake_out - real_out.mean(0, keepdim=True), real_label) #fake is real for Generator

            # Total Generator loss
            lossG = 0.01 * loss_GAN + 0.99 * loss_pixel

            lossG.backward()
            optimizerG.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Regenerate the fake samples with the just-updated generator. The
            # discriminator step never back-propagates into the generator, so the
            # forward pass runs without autograd.
            with torch.no_grad():
                sr = netG(lr)
            real_out = netD(hr)
            fake_out = netD(sr)

            # Adversarial loss for real and fake images (relativistic average GAN)
            lossD_real = criterion_GAN(real_out - fake_out.mean(0, keepdim=True), real_label)
            lossD_fake = criterion_GAN(fake_out - real_out.mean(0, keepdim=True), fake_label)

            # Total Discriminator loss
            lossD = lossD_real + lossD_fake

            # Update Discriminator. zero_grad() also discards the discriminator
            # gradients accumulated by lossG.backward() above.
            optimizerD.zero_grad()
            lossD.backward()
            optimizerD.step()

            # Save Losses for plotting later (warm-up batches are not recorded)
            G_losses.append(lossG.item())
            D_losses.append(lossD.item())

        # Log progress every 10 epochs, once warm-up is over (the logged values are
        # those of the last batch of the epoch)
        if epoch % 10 == 0 and batches_done > num_warmup_batches:
            print("[Epoch %d/%d] [D loss: %f] [G loss: %f, adv: %f, pixel: %f]" % (epoch, num_epochs, lossD.item(), lossG.item(), loss_GAN.item(), loss_pixel.item()))

        # Save model states every 100 epochs
        if epoch % 100 == 0:
            print("Saving model states...")
            torch.save(netG.state_dict(), model_out_dir/('G' + str(epoch).zfill(4) + '.pt'))
            torch.save(netD.state_dict(), model_out_dir/('D' + str(epoch).zfill(4) + '.pt'))

In [None]:
# Fresh generator (kept on the CPU) that receives the trained weights.
netG_trained = Generator(num_in_ch=2)

# After a training run, use the checkpoint written at epoch 3000; otherwise use the
# weights file in the working directory.
# NOTE(review): 'G3000.pt' is hard-coded and only exists if num_epochs >= 3000 - confirm
# this matches the configured num_epochs.
if input_data['train']:
    generator_weights = torch.load(model_out_dir/'G3000.pt')
else:
    generator_weights = torch.load('gan1_gp_upslope.pt', map_location=device)
netG_trained.load_state_dict(generator_weights)
del generator_weights

In [None]:
# Standardise the complete series with the mean and std dev of the training set.
era_rain_gp_std, era_rain_topo_std, era_land_rain_std = (
    (field - np.nanmean(train_field)) / np.nanstd(train_field)
    for field, train_field in (
        (era_rain_gp, era_rain_gp_train),
        (era_rain_topo, era_rain_topo_train),
        (era_land_rain, era_land_rain_train),
    )
)

In [None]:
# Drop the xarray wrapper and insert a channel axis at position 1.
era_rain_gp_std, era_rain_topo_std, era_land_rain_std = (
    field.data[:, np.newaxis, :, :]
    for field in (era_rain_gp_std, era_rain_topo_std, era_land_rain_std)
)

# Two-channel generator input: channel 0 = GP rainfall, channel 1 = orographic rainfall.
era_rain_total_std = np.concatenate([era_rain_gp_std, era_rain_topo_std], axis=1)

In [None]:
# A single unshuffled batch holding every time step of the full series.
full_dataset = Dataset(x=era_rain_total_std, y=era_land_rain_std)
full_dataloader = torch.utils.data.DataLoader(full_dataset, batch_size=np.shape(era_rain_gp)[0], shuffle=False)

In [None]:
# The loader yields the whole series as one batch: fetch it once and reuse it.
full_batch = next(iter(full_dataloader))
lr = full_batch["x"]
hr = full_batch["y"]

# Inference only, so no autograd graph is needed for the full-series forward pass.
with torch.no_grad():
    sr = netG_trained(lr)

# Undo the standardisation with the same NaN-aware training statistics that were used
# to standardise the target (a plain np.mean would turn everything into NaN as soon as
# the training target contains a single NaN).
land_train_mean = float(np.nanmean(era_land_rain_train.data))
land_train_std = float(np.nanstd(era_land_rain_train.data))
sr = sr.mul(land_train_std).add(land_train_mean)

# Rainfall cannot be negative
sr[sr<0]=0
sr_np = sr.detach().numpy()

In [None]:
# Write the downscaled full series to NetCDF, reusing the coordinates of the ERA5-Land file.
# The path is built explicitly so this cell does not depend on whichever value
# `data_in_dir` was last assigned in an earlier cell.
era_land_path = pathlib.Path(input_data['data_directories']['output']).joinpath(input_data['study_area'], 'input_data', 'era5_land_rainfall.nc')
with xr.open_dataset(era_land_path) as ds1:

    ds = xr.Dataset(
        # Select the single output channel explicitly; np.squeeze would also drop the
        # time axis (or a spatial axis) whenever that axis happens to have length 1.
        # NOTE(review): dims are declared as (time, lon, lat) - confirm this matches the
        # axis order of the arrays the generator was trained on.
        data_vars=dict(rainfall=(["time", "lon", "lat"], sr_np[:, 0, :, :])),
        coords=dict(lon=ds1.lon.data, lat=ds1.lat.data, time=ds1.time),
        attrs=dict(description="ERA5 daily rainfall downscaled with GP+Physics+GAN"),
    )

    ds.to_netcdf(out_dir/"era5_sr1.nc")

In [None]:
# Extreme days whose calendar year falls inside the testing period.
test_year_mask = np.isin(pd.DatetimeIndex(extreme_days).year, year_range_test)
test_days = extreme_days[test_year_mask]

# Keep only those time steps in each field.
era_rain_gp_test = era_rain_gp.where(era_rain_gp.time.isin(test_days), drop=True)
era_rain_topo_test = era_rain_topo.where(era_rain_topo.time.isin(test_days), drop=True)
era_land_rain_test = era_land_rain.where(era_land_rain.time.isin(test_days), drop=True)

In [None]:
# Standardise the test fields with the mean and std dev of the training set.
era_rain_gp_test_std, era_rain_topo_test_std, era_land_rain_test_std = (
    (field - np.nanmean(train_field)) / np.nanstd(train_field)
    for field, train_field in (
        (era_rain_gp_test, era_rain_gp_train),
        (era_rain_topo_test, era_rain_topo_train),
        (era_land_rain_test, era_land_rain_train),
    )
)

In [None]:
# Drop the xarray wrapper and insert a channel axis at position 1.
era_rain_gp_test_std, era_rain_topo_test_std, era_land_rain_test_std = (
    field.data[:, np.newaxis, :, :]
    for field in (era_rain_gp_test_std, era_rain_topo_test_std, era_land_rain_test_std)
)

# Two-channel generator input: channel 0 = GP rainfall, channel 1 = orographic rainfall.
era_rain_total_test_std = np.concatenate([era_rain_gp_test_std, era_rain_topo_test_std], axis=1)

In [None]:
# A single unshuffled batch holding every test-period extreme day.
test_dataset = Dataset(x=era_rain_total_test_std, y=era_land_rain_test_std)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=np.shape(era_rain_gp_test)[0], shuffle=False)

In [None]:
# The loader yields the whole test set as one batch: fetch it once and reuse it.
test_batch = next(iter(test_dataloader))
lr = test_batch["x"]
hr = test_batch["y"]

# Inference only, so no autograd graph is needed for the forward pass.
with torch.no_grad():
    sr = netG_trained(lr)

# Undo the standardisation with the same NaN-aware training statistics that were used
# to standardise the target (a plain np.mean would turn everything into NaN as soon as
# the training target contains a single NaN).
land_train_mean = float(np.nanmean(era_land_rain_train.data))
land_train_std = float(np.nanstd(era_land_rain_train.data))
sr = sr.mul(land_train_std).add(land_train_mean)

# Rainfall cannot be negative
sr[sr<0]=0
sr_np = sr.detach().numpy()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Side-by-side comparison for one randomly chosen test day:
# GP input, GAN output, and ERA5-Land ground truth.
fig = plt.figure(figsize=(14,6))
n = np.random.randint(0, np.shape(lr)[0])
# To reproduce a specific day, replace the random draw above with a fixed index, e.g. n = 36

panels = (
    (131, era_rain_gp_test[n,:,:], 'GP SR'),
    (132, sr_np[n,0,:,:], 'GP+Physics+GAN SR'),
    (133, era_land_rain_test[n,:,:], 'HR Ground Truth'),
)
for position, field, title in panels:
    ax = fig.add_subplot(position)
    image = ax.imshow(field, vmin=0, cmap=plt.get_cmap("gist_earth_r"))
    fig.colorbar(image, ax=ax, orientation="horizontal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

plt.tight_layout()