{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Step 0: Import Necessary Libraries",
   "id": "f343d2332bfbdc0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import ee\n",
    "import requests\n",
    "import io\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "from torch import Tensor, nn\n",
    "import copy\n",
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "from tqdm import tqdm"
   ],
   "id": "bf55841506780334"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Step 1: Preparing RS/MO/DEM Data (GEE)",
   "id": "9569e224b33cd68"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# BERTH's Default Band Order\n",
    "columns_mo = ['ws_u_10', 'ws_v_10', 't_dewpoint', 't_air', 't_skin','pressure', 'down_radiation_short', 'down_radiation_long']\n",
    "columns_rs = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']\n",
    "columns_terrain = ['dem', 'aspect', 'slope']\n",
    "columns_hydro = ['precipitation', 'soilmoisture', 'evapotranspiration', 'runoff']\n",
    "columns_berth = columns_rs + columns_mo + columns_terrain\n",
    "\n",
    "# GEE\n",
    "ee.Authenticate()\n",
    "ee.Initialize(project='ee-zhaoyuan-yao')\n",
    "\n",
    "# Export Setting\n",
    "region = {\"west\": 38.2378, \"north\": 30.1577, \"south\": 30.0446, \"east\": 38.4369}\n",
    "region_name = 'Example_Hail_Arabia'\n",
    "region_ee = ee.Geometry.BBox(region['west'], region['south'], region['east'], region['north'])\n",
    "TIME_START = ee.Date('2017-01-01')\n",
    "TIME_END = ee.Date('2020-12-31')\n",
    "TIME_LENGTH = TIME_END.difference(TIME_START, 'day').getInfo() + 1"
   ],
   "id": "e0fd54de70caa70e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Download DEM via GEE\n",
    "def preprocess_terrain():\n",
    "    dem = ee.ImageCollection('COPERNICUS/DEM/GLO30').select('DEM').mosaic().setDefaultProjection('EPSG:3857', None, 30).resample('bilinear').rename('dem')\n",
    "    slope = ee.Terrain.slope(dem).rename('slope')\n",
    "    aspect = ee.Terrain.aspect(dem).rename('aspect')\n",
    "    return dem.unmask(-1000).rename('dem').addBands(aspect.unmask(-1)).addBands(slope.unmask(-1))\n",
    "\n",
    "url = ee.Image(preprocess_terrain()).select(columns_terrain).float().getDownloadURL({\"region\": region_ee, \"scale\": 30, \"crs\": \"EPSG:4326\", \"format\": \"NPY\"})\n",
    "response = requests.get(url)\n",
    "response.raise_for_status()\n",
    "terrain_data = np.load(io.BytesIO(response.content), allow_pickle=True)\n",
    "np.save(f'{region_name}_30m_Terrain_GLO30.npy', terrain_data.view((np.float32, len(columns_terrain))))\n",
    "print(terrain_data.shape)"
   ],
   "id": "2c8ab44d13fd33eb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Download Atmospheric Forcing via GEE\n",
    "def preprocess_era5land(date_offset):\n",
    "    date = TIME_START.advance(date_offset, 'day')\n",
    "    date_filter = ee.Filter.date(date.advance(-0.1, 'day'), date.advance(1, 'day'))\n",
    "    mo_image = ee.ImageCollection('ECMWF/ERA5_LAND/DAILY_AGGR').filter(date_filter).first()\n",
    "    ws_u_10 = mo_image.select('u_component_of_wind_10m').divide(20.0)\n",
    "    ws_v_10 = mo_image.select('v_component_of_wind_10m').divide(20.0)\n",
    "    t_dewpoint = mo_image.select('dewpoint_temperature_2m').subtract(200).divide(150.0)\n",
    "    t_air = mo_image.select('temperature_2m').subtract(200).divide(150.0)\n",
    "    t_skin = mo_image.select('skin_temperature').subtract(200).divide(150.0)\n",
    "    p = mo_image.select('surface_pressure').divide(1000).subtract(70).divide(80.0)\n",
    "    down_short = mo_image.select('surface_solar_radiation_downwards_sum').multiply(1.1574E-5).divide(1000)\n",
    "    down_long  = mo_image.select('surface_thermal_radiation_downwards_sum').multiply(1.1574E-5).divide(1000)\n",
    "    data = ws_u_10.rename('ws_u_10')\\\n",
    "        .addBands(ws_v_10.rename('ws_v_10'))\\\n",
    "        .addBands(t_dewpoint.rename('t_dewpoint'))\\\n",
    "        .addBands(t_air.rename('t_air'))\\\n",
    "        .addBands(t_skin.rename('t_skin'))\\\n",
    "        .addBands(p.rename('pressure'))\\\n",
    "        .addBands(down_short.rename('down_radiation_short'))\\\n",
    "        .addBands(down_long.rename('down_radiation_long'))\n",
    "    return data.resample('bilinear')\n",
    "if not os.path.exists(f'{region_name}_30m_MO'):\n",
    "    os.mkdir(f'{region_name}_30m_MO')\n",
    "for day_offset in range(TIME_LENGTH):\n",
    "    url = ee.Image(preprocess_era5land(day_offset)).select(columns_mo).float().getDownloadURL({\"region\": region_ee, \"scale\": 30, \"crs\": \"EPSG:4326\", \"format\": \"NPY\"})\n",
    "    response = requests.get(url, stream=True)\n",
    "    response.raise_for_status()\n",
    "    mo_data = np.load(io.BytesIO(response.content), allow_pickle=True)\n",
    "    np.save(f'{region_name}_30m_MO/{region_name}_30m_MO_{day_offset}.npy', mo_data.view((np.float32, len(columns_mo))))\n",
    "    print(f'day:{day_offset}, MO, {mo_data.shape}')"
   ],
   "id": "dac632be46368a15"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Download Surface Reflectance via GEE\n",
    "def preprocess_single_LS89(image):\n",
    "    dilatedCloudBitMask = (1 << 1)\n",
    "    cirrusBitMask = (1 << 2)\n",
    "    cloudBitMask = (1 << 3)\n",
    "    cloudShadowBitMask = (1 << 4)\n",
    "    qa = image.select('QA_PIXEL')\n",
    "    mask = qa.bitwiseAnd(cloudShadowBitMask).eq(0).And(qa.bitwiseAnd(cirrusBitMask).eq(0)).And(qa.bitwiseAnd(dilatedCloudBitMask).eq(0)).And(qa.bitwiseAnd(cloudBitMask).eq(0))\n",
    "    red  = ee.Image(image).select('SR_B4').multiply(2.75e-05).add(-0.2)\n",
    "    nir = ee.Image(image).select('SR_B5').multiply(2.75e-05).add(-0.2)\n",
    "    blue  = ee.Image(image).select('SR_B2').multiply(2.75e-05).add(-0.2)\n",
    "    green = ee.Image(image).select('SR_B3').multiply(2.75e-05).add(-0.2)\n",
    "    swir1 = ee.Image(image).select('SR_B6').multiply(2.75e-05).add(-0.2)\n",
    "    swir2 = ee.Image(image).select('SR_B7').multiply(2.75e-05).add(-0.2)\n",
    "    data = red.rename('red')\\\n",
    "        .addBands(blue.rename('blue'))\\\n",
    "        .addBands(green.rename('green'))\\\n",
    "        .addBands(nir.rename('nir'))\\\n",
    "        .addBands(swir1.rename('swir1'))\\\n",
    "        .addBands(swir2.rename('swir2'))\\\n",
    "        .updateMask(mask)\n",
    "    return data.float()\n",
    "def preprocess_single_S2(image):\n",
    "    qa = image.select('QA60')\n",
    "    scl = image.select('SCL')\n",
    "    mask = qa.bitwiseAnd(1 << 10).eq(0).And(qa.bitwiseAnd(1 << 11).eq(0)).And(scl.neq(3)).And(scl.neq(8)).And(scl.neq(9)).And(scl.neq(10))\n",
    "    red  = ee.Image(image).select('B4').multiply(0.0001).multiply(0.982).add(0.00094)\n",
    "    nir = ee.Image(image).select('B8').multiply(0.0001).multiply(1.001).add(-0.00029)\n",
    "    blue  = ee.Image(image).select('B2').multiply(0.0001).multiply(0.977).add(-0.00411)\n",
    "    green = ee.Image(image).select('B3').multiply(0.0001).multiply(1.005).add(-0.00093)\n",
    "    swir1 = ee.Image(image).select('B11').multiply(0.0001).multiply(1.001).add(-0.00015)\n",
    "    swir2 = ee.Image(image).select('B12').multiply(0.0001).multiply(0.996).add(-0.00097)\n",
    "    data = red.rename('red')\\\n",
    "        .addBands(blue.rename('blue'))\\\n",
    "        .addBands(green.rename('green'))\\\n",
    "        .addBands(nir.rename('nir'))\\\n",
    "        .addBands(swir1.rename('swir1'))\\\n",
    "        .addBands(swir2.rename('swir2'))\\\n",
    "        .updateMask(mask)\n",
    "    return data.float()\n",
    "def preprocess_remotesensing(date_offset):\n",
    "    date = TIME_START.advance(date_offset, 'day')\n",
    "    date_filter = ee.Filter.date(date, date.advance(1, 'day'))\n",
    "    day_images_ls89 = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filter(date_filter).merge(ee.ImageCollection('LANDSAT/LC09/C02/T1_L2').filter(date_filter)).filter(ee.Filter.equals('PROCESSING_LEVEL', 'L2SP')).map(preprocess_single_LS89)\n",
    "    day_images_s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filter(date_filter)\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190417T065631_20190417T070736_T38NRL'))\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190419T105621_20190419T105622_T41XML'))\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190119T210811_20190119T210805_T06VXK'))\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190117T010959_20190117T011000_T55TEN'))\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190117T061209_20190117T061411_T42TVK'))\\\n",
    "        .filter(ee.Filter.neq('system:index', '20190117T140051_20190117T141300_T20HND'))\\\n",
    "        .map(preprocess_single_S2)\n",
    "    day_images = ee.ImageCollection(day_images_ls89).merge(day_images_s2)\n",
    "    data = ee.Algorithms.If(condition=ee.ImageCollection(day_images).first(),\\\n",
    "                            trueCase=day_images.mean().unmask(-1),\\\n",
    "                            falseCase=ee.Image.constant([-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]).rename(['red', 'blue', 'green', 'nir', 'swir1', 'swir2']))\n",
    "    return ee.Image(data).float()\n",
    "if not os.path.exists(f'{region_name}_30m_RS'):\n",
    "    os.mkdir(f'{region_name}_30m_RS')\n",
    "for day_offset in range(TIME_LENGTH):\n",
    "    url = ee.Image(preprocess_remotesensing(day_offset)).select(columns_rs).float().getDownloadURL({\"region\": region_ee, \"scale\": 30, \"crs\": \"EPSG:4326\", \"format\": \"NPY\"})\n",
    "    response = requests.get(url, stream=True)\n",
    "    response.raise_for_status()\n",
    "    rs_data = np.load(io.BytesIO(response.content), allow_pickle=True)\n",
    "    np.save(f'{region_name}_30m_RS/{region_name}_30m_RS_{day_offset}.npy', rs_data.view((np.float32, len(columns_rs))))\n",
    "    print(f'day:{day_offset}, RS, {rs_data.shape}')"
   ],
   "id": "62c608f090c1d2ae"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Step-2: Run BERTH",
   "id": "288adb18e88be2f6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Define BERTH\n",
    "class HydroTrans(nn.Module):\n",
    "    def __init__(self, rs_dim=6, mo_dim=7, out_dim=4, max_len=500):\n",
    "        super().__init__()\n",
    "        self.model_hidden = 256\n",
    "        self.model_head = 4\n",
    "        self.model_layer = 6\n",
    "        single_encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_hidden, nhead=self.model_head,\n",
    "                                                          dropout=0, batch_first=True)\n",
    "        combined_encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_hidden * 3, nhead=self.model_head,\n",
    "                                                            dropout=0, batch_first=True)\n",
    "        self.rs_embedding = nn.Linear(rs_dim, self.model_hidden)\n",
    "        self.rs_encoder = nn.TransformerEncoder(single_encoder_layer, num_layers=3)\n",
    "        self.mo_embedding = nn.Linear(mo_dim, self.model_hidden)\n",
    "        self.mo_encoder = nn.TransformerEncoder(single_encoder_layer, num_layers=3)\n",
    "        self.topo_embedding = nn.Linear(3, self.model_hidden)\n",
    "        self.combined_encoder = nn.TransformerEncoder(combined_encoder_layer, num_layers=3)\n",
    "        self.position_embedding = nn.Embedding(max_len, self.model_hidden)\n",
    "        self.linear = nn.Linear(self.model_hidden * 3, out_dim)\n",
    "\n",
    "        for param in self.parameters():\n",
    "            param.requires_grad = True\n",
    "        self.max_length = max_len\n",
    "\n",
    "    def forward(self, rs, mo, topo):\n",
    "        position = torch.arange(self.max_length).unsqueeze(0).to(rs.device)\n",
    "        p_eb = self.position_embedding(position)\n",
    "        rs_features = self.rs_encoder(self.rs_embedding(rs) + p_eb, src_key_padding_mask=rs[:, :, 0] == -1)\n",
    "        mo_features = self.mo_encoder(self.mo_embedding(mo) + p_eb)\n",
    "        topo_features = self.topo_embedding(topo).unsqueeze(1).repeat(1, 500, 1)\n",
    "\n",
    "        combined_features = torch.concat([rs_features, mo_features, topo_features], dim=2)\n",
    "\n",
    "        hydro_features = self.combined_encoder.forward(combined_features)\n",
    "        return self.linear(hydro_features)\n",
    "\n",
    "# Load BERTH\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "model = HydroTrans(rs_dim=6, mo_dim=7, out_dim=4, max_len=500)\n",
    "model.load_state_dict(torch.load('BERTH_LS789S2_v1.pt', map_location='cpu'))\n",
    "model.to(device)\n",
    "model.eval()"
   ],
   "id": "8bf000d14c41a59f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load Data\n",
    "grid_topo = np.load(f'{region_name}_30m_Terrain_GLO30.npy')\n",
    "grid_x = grid_topo.shape[0]\n",
    "grid_y = grid_topo.shape[1]\n",
    "grid_topo[:, :, 0] = grid_topo[:, :, 0] / 4000\n",
    "grid_topo[:, :, 1] = grid_topo[:, :, 1] / 360\n",
    "grid_topo[:, :, 2] = grid_topo[:, :, 2] / 30\n",
    "grid_mo = np.zeros((500, grid_x, grid_y, 8)) - 1\n",
    "grid_rs = np.zeros((500, grid_x, grid_y, 6)) - 1\n",
    "grid_hydro = np.ones((500, grid_x, grid_y, 4)) * -9999\n",
    "pbar = tqdm(total=500, desc='Loading')\n",
    "for i in range(500):\n",
    "    pbar.update(1)\n",
    "    grid_rs[i, :, :, :] = np.load(os.path.join(f'{region_name}_30m_RS', f'{region_name}_30m_RS_{i}.npy'))\n",
    "    grid_mo[i, :, :, :] = np.load(os.path.join(f'{region_name}_30m_MO', f'{region_name}_30m_MO_{i}.npy'))\n",
    "grid_rs = np.where(grid_rs < 0, -1, grid_rs)\n",
    "pbar.close()\n",
    "\n",
    "# Run BERTH\n",
    "pbar = tqdm(total=grid_x, desc=f'Running')\n",
    "for x_id in range(grid_x):\n",
    "    pbar.update(1)\n",
    "    for y_id in range(grid_y):\n",
    "        mo = torch.as_tensor(grid_mo[:, x_id, y_id, [0, 1, 2, 3, 5, 6, 7]], dtype=torch.float).to(device).unsqueeze(0)\n",
    "        rs = torch.as_tensor(grid_rs[:, x_id, y_id, :], dtype=torch.float).to(device).unsqueeze(0)\n",
    "        poi_topo = torch.as_tensor(grid_topo[x_id, y_id, :], dtype=torch.float).to(device)\n",
    "        topo = poi_topo.unsqueeze(0)\n",
    "        if torch.max(rs) < 0:\n",
    "            continue\n",
    "        model_output = model.forward(rs, mo, topo)\n",
    "        model_hydro = model_output[0].detach().cpu().numpy()\n",
    "        model_hydro[:, 1] = model_hydro[:, 1] / 10 # soil moisture\n",
    "        grid_hydro[:, x_id, y_id, :] = model_hydro\n",
    "pbar.close()\n",
    "\n",
    "# Export\n",
    "if not os.path.exists(f'{region_name}_30m_Hydro'):\n",
    "    os.mkdir(f'{region_name}_30m_Hydro')\n",
    "for i in range(500):\n",
    "    np.save(os.path.join(f'{region_name}_30m_Hydro', f'{region_name}_30m_Hydro_{i}.npy'), grid_hydro[i])"
   ],
   "id": "39ebb1f765cf98e3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Visualize\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "from matplotlib import rcParams, colors as mcolors, cm, ticker as mticker, pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "\n",
    "et = None\n",
    "for i in range(0, 365):\n",
    "    print(i)\n",
    "    hydro = np.load(f'./Example_Hail_Arabia/DayOffset_{i}.npy')\n",
    "    if not isinstance(et, np.ndarray):\n",
    "        et = np.zeros((hydro.shape[0], hydro.shape[1]))\n",
    "    et += hydro[:, :, 2] + np.abs(np.min(hydro[:, :, 2]))\n",
    "\n",
    "config = {\n",
    "    'font.size': 12,\n",
    "    # 'font.weight': 'bold',\n",
    "    \"mathtext.fontset\": 'stix',\n",
    "    \"font.family\": 'Microsoft YaHei',\n",
    "    \"axes.unicode_minus\": False\n",
    "}\n",
    "rcParams.update(config)\n",
    "fig = plt.figure(figsize=(9, 6), dpi=300)\n",
    "grid = gridspec.GridSpec(1, 1)\n",
    "\n",
    "ax = fig.add_subplot(grid[0, 0])\n",
    "ax.imshow(et, cmap='coolwarm_r', vmin=0, vmax=500)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "axes = fig.get_axes()\n",
    "axins = inset_axes(axes[0], width=\"100%\", height=\"50%\", loc='lower center',\n",
    "               bbox_to_anchor=(0, -0.1, 1, 0.1), bbox_transform=axes[0].transAxes)\n",
    "cbar = plt.colorbar(cm.ScalarMappable(cmap=\"coolwarm_r\", norm=mcolors.Normalize(vmin=0, vmax=500)),\n",
    "                    orientation='horizontal', label=r'Total ET（mm·y${^{-1}}$）', cax=axins)\n",
    "\n",
    "plt.subplots_adjust(left=0.05,bottom=0.05,top=0.99,right=0.8,hspace=0.05, wspace=0.05)\n",
    "plt.show()\n"
   ],
   "id": "e9e7db648957e02b"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
