# Notebook to process GIA files from cluster

In [2]:
import numpy as np
import xarray as xr
import dask
import pandas as pd
from scipy.io import loadmat
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib import colors as mcolors
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

import cartopy.crs as ccrs
import cartopy
import cartopy.feature as cfeature

import os
import glob
from time import time
from itertools import product
import datetime 
from pathlib import Path
import random
import shutil
import sys
import logging

now = datetime.datetime.now()
today = f'{now.year}{now.month}{now.day}'
PROJECT_ROOT = Path.cwd() / "data" # .parents[0]
DASKTMP_DIR = Path.cwd() / "data/output/dask_temp"

sys.path.append(str(PROJECT_ROOT))


import dask
import dask.dataframe as dd
import dask.array as da
import dask.bag as db
from dask import delayed, config
from dask_jobqueue import SLURMCluster
from distributed import Client
import distributed
from dask.distributed import progress
from dask.diagnostics import ProgressBar
dask.config.set({"distributed.comm.timeouts.connect": "60s"})

# from numba import float64, guvectorize
lazyload = dask.delayed(loadmat, pure=True)



def preprocess(dsold):

    ds = dsold.copy()
    name = '_'.join(ds.encoding["source"].split('_')[1:-1])
    print(name)

    times = ds.TIME

    ds = ds.drop(['LAT', 'LON', 'RSL', 'TIME']).rename({'ESL':'esl'})
    ds = ds.assign_coords(time=times, synthtests=name)
    ds = ds.expand_dims('synthtests')
    # ds = ds.interp(time=interptime, kwargs={"fill_value": "extrapolate"})

    return ds

def rnd(df):
    return round(df+1-1).astype(int).astype(str)  


def make_dflatlon(df):
    
    df_ll = df[['latrnd', 'lonrnd', 'sites']].copy()

    #kludge to get rid of negative zeros
    latrnd,  lonrnd = rnd(df.latrnd), rnd(df.lonrnd)
    df_ll['site'] = latrnd + '_' + lonrnd
    df_ll = df_ll.drop(['latrnd', 'lonrnd'],  axis=1)
    df_latlon = df_ll.sort_values(by='sites').drop_duplicates().reset_index(drop=True)
    return df_latlon

def add_features(ax):
    """ """
    ocean_50m = cfeature.NaturalEarthFeature('physical', 'ocean', '10m')
    land_50m = cfeature.NaturalEarthFeature('physical', 'land', '10m')
    ax.add_feature(ocean_50m, color='lightgray', zorder=0)
    ax.add_feature(land_50m, color='white', zorder=1)
    ax.gridlines(linewidth=1, color='white', alpha=0.5, draw_labels=True, zorder=4)
    ax.coastlines(resolution='10m', zorder=3)
    return ax

print('done')

done


In [3]:
cluster = SLURMCluster(cores=6, # Number of cores per job
                       processes=1,   # Number of Python processes to cut up each job
                       memory="128GB", # Amount of memory per job
                       project="jalab",
                       walltime="04:00:00",
                       header_skip=["-p "],
                       queue="normal",
                       silence_logs=logging.ERROR,
                       # silence_logs='error',
                       # interface='ib0',
                       local_directory= DASKTMP_DIR / 'spill',
                      )

# cluster = SLURMCluster()
print(cluster.job_script())

#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -A jalab
#SBATCH -n 1
#SBATCH --cpus-per-task=6
#SBATCH --mem=120G
#SBATCH -t 04:00:00

/rigel/jalab/users/rcc2167/miniconda3/envs/pangeo2/bin/python -m distributed.cli.dask_worker tcp://10.43.4.53:41398 --nthreads 6 --memory-limit 119.21GiB --name dummy-name --nanny --death-timeout 60 --local-directory /rigel/jalab/users/rcc2167/data/output/dask_temp/spill --protocol tcp://



In [4]:
workers=400
cluster.scale(workers)

client = Client(cluster)  # Connect this local process to remote workers
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.43.4.53:41398,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Load RSL data

In [5]:
# Read in ~11000 RSL data

df_data = pd.read_csv(PROJECT_ROOT / 'input/rsldata/gsl_230127_synthensemble_lowlats.csv').set_index('num')#.drop('Unnamed: 0', axis=1) # .drop(['num'], axis=1)

######## ADDING TEST WRSS DATA IN #########

standardized = ['vacchi2021a', 'Clemente2016', 'Engelhart2012', 'Engelhart2015', 'Hawkes2016', 
                 'Hibbert2016', 'Baranskaya18', 'Cooper18', 'Dean19','GarciaArtola18', 'Hijma19', 'Mann19',
                 'Shaw18', 'Tam18','Vacchi18a', 'Kemp2018',  'Woodroffe2009', 'Milne2005', 'Toth',
                 'Woodroffe2012', 'Xiong2018',  'Vacchi2021', 'Barnett2021', 'Hibbert2018', 'SEAMIS1p1',
                 'Bender', 'chua2020', 'kaniewski2021', 'vacchi2020', 'barlow2016', 'garrett2020']

df_data['standardized'] = df_data['compilation'].apply(lambda x: x in standardized)

########

df_data = df_data.dropna(how='any', axis=0, thresh=20)  ######## DROPPING NAN rows
df_data = df_data[df_data.age > 10] ######### DROP AGE=0 

# titles of synthetic columns in each saved .mat file
titles = ['_'.join(d.split('_')[:-3]) for d in df_data.columns if (len(d) > 15) ] ####### & ("ump2_lm40" in d)uusing only one viscosity


# Get model names for which we have both csv and esl
eslpath = PROJECT_ROOT / 'output/all/esls/'
esl_afpath = PROJECT_ROOT / 'output/all/esls_af/'
matpath = PROJECT_ROOT / 'output/all/mats/'
rslpath = PROJECT_ROOT / 'output/rsls/'


fs = [str(f).split('/')[9][:-4] for f in matpath.glob('*.mat') if 'l71C.ump2.lm40' not in str(f)]   ###### HARD CODED CHANGE TO 8 #######
fs2 = [str(f).split('/')[9][:-4] for f in eslpath.glob('*.mat') if 'l71C.ump2.lm40' not in str(f)]
fs3 = [str(f).split('/')[9][:-4] for f in esl_afpath.glob('*.mat') if 'l71C.ump2.lm40' not in str(f)]

# make all models synchronized b/w GMSL and WRSS
df_fs = pd.DataFrame(data=fs, columns=['name'])
df_fs['im'] = df_fs['name'].str.split('_um').str[0]
df_fs2 = df_fs[['im']].iloc[df_fs['im'].drop_duplicates().index]

fs = list(df_fs['name'].iloc[df_fs[['im']][df_fs2.isin(fs2)].dropna().index])
fs2 = [f[:-5] for f in fs]

# mods = fs                                                          ###### NOW DOES NOT INTERSECT WITH GMSLS NECESSARILY #######
mods = list(set([str(f[:-5]) for f in fs]) & set(fs2))

print(f'mat = {len(fs)}, esls = {len(fs2)}, intersection = {len(mods)}')

mats = [str(matpath / m) + '.mat' for m in fs]    ###### NOW DOES INTERSECT WITH GMSLS NECESSARILY #######
esls = [str(eslpath / m) + '.mat' for m in fs2]  
esls_af = [str(esl_afpath / m) + '.mat' for m in fs2]  


# get ages for ESL curves
p = PROJECT_ROOT / 'input/standardized/ais/6000.nc'
ages = xr.open_dataset(p).age.values[::-1]

by = 7 # degree rounding

# dataframe of only basic info, not WRSS
df = df_data[[d for d in df_data.columns if len(d) < 15]].copy()# .reset_index(drop=True)  ######## RESETTING INDEX HERE!!!! ###### 

# rounded lat/lon boxes and ages to nearest 100
df['latrnd'] = round(df.lat / by) * by
df['lonrnd'] = round(df.lon / by) * by
df['agernd'] = round(df.age, -2) 

# choose only slips ##### 
# _a denotes 'all' instead of just SLIPs
STANDARDIZED = False
df_a = df.copy()    
if STANDARDIZED:
    df_a = df_a[df_a.standardized==True]
    
    
IS3D = False
ONEVISC = False
viscnum = 'allvisc'
visc = 'l71C.ump3.lm9.'
visc_ul= 'l71C_ump3_lm9'
if ONEVISC:
    mats = [m for m in mats if visc in m]
    viscnum = 'onevisc'
    
# df = df[df.type == 0]# .reset_index(drop=True) ##### reset index so index is same as GIA mod index

# calculate number of data at each location 
count = df.groupby(['latrnd', 'lonrnd'])['type'].transform('count')
count_a = df_a.groupby(['latrnd', 'lonrnd'])['type'].transform('count')

# names of all gia models with GIA included
names = [f.split('midhol_')[1][:-4].replace('.', '_') for f in mats]

# column names of all GIA models with GIA parts removed
cols = ['_'.join(f.split('midhol_')[1][:-4].split('_')[:-1]) for f in mats]
uniquecols = np.unique(cols)


# assign each lat/lon box a unique id
sites = df.groupby(['latrnd', 'lonrnd']).ngroup()
sites_a = df_a.groupby(['latrnd', 'lonrnd']).ngroup()
unique_sites = np.unique(sites_a)

print(f'total SLIP sites = {sites.unique().shape[0]}')
print(f'total sites = {sites_a.unique().shape[0]}')
print('')


# assign sites as column to dataframes
df['sites'] = sites
df_a['sites'] = sites_a

#collect lat/lons from sites
df_latlon = make_dflatlon(df)
df_latlon_r = make_dflatlon(df_a)


##### NOTE: COULD ONLY SAVE ONE ESL PER VISCOSITY ###########

# Get GMSLs from every unique ice model for which we have a weight
# we can drop duplicates because GMSL will be same for each GIA model

eslmds = ['_'.join(m.split('/')[-1].split('_')[2:-1]) for m in esls]
matmds = ['_'.join(m.split('/')[-1].split('_')[2:-1]) for m in mats]

df_matnames = pd.DataFrame({'mods':matmds, 'mats':mats}).drop_duplicates('mods', keep='first')   
df_eslnames = pd.DataFrame({'mods':eslmds, 'esls':esls}).drop_duplicates('mods', keep='first')

eslnames = [str(eslpath / n.split('/')[-1][:-4]) + '.mat' for n in df_eslnames['esls']]


####### Preprocess PISM weights ##########
USEANTRSLS = False
print('processing pism')
antn = '09'
if USEANTRSLS:
    antn = '10'
    
p = PROJECT_ROOT / f"input/pismweights/scores_16km_{antn}datatypes_ed.txt"

df_pism = pd.read_csv(p, 
                         sep=",", 
                         encoding = "utf-8", 
                         header=None,
                        ).iloc[:,:2]

df_pism = df_pism.rename({0:'pism',1:'weight'}, axis=1).set_index('pism')


print('done')


mat = 103759, esls = 103759, intersection = 103759
total SLIP sites = 166
total sites = 166

processing pism
done


In [6]:
flist = []
for f in fs[:]:
    _, _, gis, ais, eis, lis, therm, mnt, um = f.split('_')
    flist.append([gis, ais, eis, lis, therm, mnt, um])
    
df_cnt = pd.DataFrame(data=flist, columns=['gis', 'ais', 'eis', 'lis', 'therm', 'mnt', 'um'])

if False:
    for col in df_cnt.columns:
        counts = df_cnt.groupby([col])[col].count().sort_values()
        print('\n')
        print(counts.name)
        for i, r in counts.iteritems():
            print(i, r)

            # print(i, sep = ' ')

        print('\n')


# Save .mat ESL files into zarr format for easier loading

In [7]:
allmodnames = np.unique([int(e.split('_')[3]) for e in sorted(eslnames)])

outpath = DASKTMP_DIR / 'esls/'
donenames = sorted([int(str(s)[-4:]) for s in outpath.glob('*')])
antmodlist = sorted(set(allmodnames).difference(donenames))


for i, model in enumerate(antmodlist):

    if not (outpath / f'{model}.nc').is_file():
        print(model)

        p = (PROJECT_ROOT / 'output/all/esls')
        fnames = np.array([f.as_posix() for f in p.glob('*') if str(model) in str(f)])
        lazyfiles = [lazyload(p)['ESL'] for p in fnames]
        sample = lazyfiles[0].compute()[0,:]


        def read_one_file(block_id, filenames=fnames, axis=1):

            # a function that reads in one chunk of data
            path = filenames[block_id[axis]]
            file = loadmat(path)['ESL'][0,:]

            return np.expand_dims(file, axis=axis)

        #load esls files
        esl_all = da.map_blocks(
            read_one_file,
            dtype=sample.dtype,
            chunks=(*sample.shape, (1,) * len(fnames))
        ).rechunk()

        ds_esl = xr.Dataset(
            {
                "esl":(['esls', 'time'],esl_all.T)
            }, 
            coords={ 
                "esls": ['_'.join(n.split('/')[-1].split('_')[-6:])[:-4] for n in fnames],
                "time": ages[ages < 13],
            },
        ).esl.drop_duplicates(dim='esls').to_dataset(name='esl') #.persist()

        ds_esl.to_zarr(DASKTMP_DIR / f'esls/ant{model}')
        
        [client.cancel(d) for d in [ds_esl, esl_all]]
print('done')

done


# Load saved .zarr ESL files into xarray

In [8]:
def prepro_zarr(ds):
    return ds.sortby('esls')
loadnames = np.sort([f.as_posix() for f in (DASKTMP_DIR / f'esls/').glob('*')])

ds_esls = xr.concat(
    [xr.open_mfdataset(l, 
                       preprocess=prepro_zarr, 
                       engine='zarr').persist() for l in loadnames],
    dim='esls')


# assign numbers to real gmsl ice models to ease processing
ds_real = xr.Dataset(
                {
                    "esl":(['icemods',  'time'],ds_esls.esl.sortby('esls').data)  #### CRUCIAL SORT BY HERE <----- XXXX 
                }, 
                coords={ 
                    "icemods": np.arange(len(ds_esls.esls.values)),
                    "time":ds_esls.time.values,

                },
            ).chunk(dict(icemods=-1)) # .sel(time=synth_esls.time.values)

# small enough to keep in memory
ds_reallgm = ds_real.chunk({'time':1}).persist()
print('done')            


done


# Process Real & Synthetic WRSS by site to get full ice model weights

In [9]:
if True:
    def arg_min(a):
        return a.idxmin(dim='giamod').astype('str')

    df_n = pd.DataFrame(mats).rename({0:'mod'}, axis=1)
    df_n['ant'] = df_n['mod'].apply(lambda x: x.split('_')[3])
    df_ngrps = df_n.groupby(['ant'])['mod'].apply(list).reset_index()

    # do same for esl_af names
    df_e_n = pd.DataFrame(esls_af).rename({0:'mod'}, axis=1)
    df_e_n['ant'] = df_e_n['mod'].apply(lambda x: x.split('_')[4])
    df_e_ngrps = df_e_n.groupby(['ant'])['mod'].apply(list).reset_index()


    VISCDIR = DASKTMP_DIR / f'latlon{by}/{viscnum}/'
    donemax = 0
    notdone_sites = [u for u in unique_sites] #  if u not in done_sites]


    # for site in df_n.gris.unique():
    #     p = VISCDIR / f's_{site}/'
    #     if not os.path.exists(p):
    #         os.makedirs(p)

    for site in range(len(unique_sites)):
        p = VISCDIR / f's{site}/'
        if not os.path.exists(p):
            os.makedirs(p)

    runnames = [# 'idx',
                'real',
               'mst_6249_mst_i6g_l71C_ump2_lm40',
               'i6g_6238_glc_i6g_l71C_ump2_lm40',
               'mst_6223_mst_i6g_l71C_ump2_lm40',
               'mst_6246_glc_i6g_l71C_ump2_lm40',
               'i6g_5980_glc_i6g_l71C_ump2_lm40',
               'mst_6245_glc_i6g_l71C_ump2_lm40',
               'mst_6237_mst_i6g_l71C_ump2_lm40',
               'i6g_6245_mst_i6g_l71C_ump2_lm40',
               'mst_6253_mst_i6g_l71C_ump2_lm40',
               'S40_3D_runs_p55_F3D_p55_S40',
               # 'rsl_ratestep'
               ]

    allmods = np.arange(6256, 6256)
    step=1
    n=0
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        # for n in notdone_sites[::step]:
            # print(n)
            # iterate through zarr-saved files, split up by Antarctica

        for _, (ant, fnames) in df_ngrps.iterrows():
        # for _, (gris, lau, fnames) in df_ngrps.iloc[1:].iterrows():
            # print('gris: ', gris, 'lau:', lau, end=' ')
            

            start = time()
            
            # Get file names for esls above floatation
            fnames_af = df_e_ngrps[df_e_ngrps.ant == ant]['mod'].iloc[0]

        #
            # get list of files in site directory, specifying 'mst' as it is last for each Antarctic model.  
            fglobs = [f.as_posix() for f in (VISCDIR / f's{n}').glob('*') if '/esl' in f.as_posix()]

            if len(fglobs) > 0:
                alreadydones = [int(f.split('/')[-1][4:8]) for f in fglobs if '/esl' in f]
                donemax = np.max(alreadydones)

            donemax=5977
            # choose only antarctic models not done yet
            antsnotdone = allmods[allmods > donemax] 

            if int(ant) in antsnotdone:

                print(f'ant: {ant}', end='| ')

                # get giamodel names once 
                giamodnames = [f.split('midhol_')[1][:-3] for f in fnames]
                giamodnames_af = [f.split('midhol_')[1][:-4] for f in fnames_af]


                lazyfiles = [lazyload(p)['mats'] for p in fnames]
                lazyfiles_af = [lazyload(p)['ESL_afs'] for p in fnames_af]

                sample = lazyfiles[0].compute()[:,:,:,1:-1] ####### DROP INDEX COLUMN & RSL COLUMN ####### [0,:]
                sample_af = lazyfiles_af[0].compute()[:,:,:] ####### DROP INDEX COLUMN & RSL COLUMN ####### [0,:]

                def read_one_file(block_id, filenames=fnames, axis=4): ### NOTE THAT AXIS MAY NEED ADJUSTING

                    # a function that reads in one chunk of data
                    path = filenames[block_id[axis]]
                    file = loadmat(path)['mats'][:,:,:,1:-1] ####### DROP INDEX COLUMN & RSL COLUMN ######

                    return np.expand_dims(file, axis=axis)
                
                def read_one_file_af(block_id, filenames=fnames_af, axis=3): ### NOTE THAT AXIS MAY NEED ADJUSTING

                    # a function that reads in one chunk of data
                    path = filenames[block_id[axis]]
                    file = loadmat(path)['ESL_afs'][:,:,:]

                    return np.expand_dims(file, axis=axis)


                #load esls files
                mats_all = da.map_blocks(
                    read_one_file,
                    dtype=sample.dtype,
                    chunks=(*sample.shape, (1,) * len(fnames))
                    # chunks=(*sample.shape, (1,) )

                ).rechunk()
                
                esl_af_all = da.map_blocks(
                    read_one_file_af,
                    dtype=sample_af.dtype,
                    chunks=(*sample_af.shape, (1,) * len(fnames_af))
                    # chunks=(*sample.shape, (1,) )

                ).rechunk()

                ds_esl_af = xr.Dataset(
                        {
                            "esl":(['lm', 'lith', 'age', 'icemod'], esl_af_all)
                        }, 
                        coords={ 
                            "lm": [3, 5, 7, 8, 9, 10, 15, 20, 30, 40, 50], 
                            "lith": [71, 96],
                            "age": ages[ages < 13],
                            "icemod":giamodnames_af, 

                        },
                    ).esl.persist()
                
                ds_mats = xr.Dataset(
                        {
                            "wrss":(['lm', 'lith', 'rsldata', 'run', 'icemod'], mats_all)
                        }, 
                        coords={ 
                            "lm": [3, 5, 7, 8, 9, 10, 15, 20, 30, 40, 50], 
                            "lith": [71, 96],
                            "rsldata": df_data.index.values,
                            "run": runnames,
                            "icemod": giamodnames, 

                        },
                    ).wrss.persist()

                # Iterate through sites
                for site in unique_sites[:]:

                    # try:
                    if site%10 == 0:
                        print(f'{site}', end=' ')

                    # get indices of specific site
                    sitevec = sites_a[sites_a == site].index.values


                    # Choose only data at site
                    ds_site = ds_mats.sel(rsldata=sitevec)# .persist()
                    # break

                    ####### REDUCE DATA #########

                    #take mean over RSL data & get best GIA model for each ice emodel
                    ds_min = ds_site.mean('rsldata').min(['lm', 'lith']).to_dataset(name='wrss')

                    #get best ESL_af model for each ice emodel
                    ds_site_it = ds_site.stack({'lmlith':['lm','lith']}).mean('rsldata')

                    ds_esl_af_it = ds_esl_af.stack({'lmlith':['lm','lith']})

                    # set esl icemod dim to have names including um visc to align with wrss dim
                    ds_esl_af_it['icemod'] = ds_site.icemod.values

                    # Replace lmlith multiindex for single index
                    idxmin = ds_site_it.chunk('auto')

                    lmlith = ds_site.stack({'lmlith':['lm', 'lith']}).lmlith.values
                    lmlithidx = [f'{str(lm)}_{str(lith)}' for lm, lith in lmlith]
                    idxmin['lmlith'] = lmlithidx
                    
                    idxmin = idxmin.idxmin('lmlith')
                
                    # Get minimum lm/lith index from wrsses

                    # Choose GMSL curves that fit lm/lith index.  Sum to remove nans.
                    ds_esl_min = ds_esl_af_it.where(ds_site_it.lmlith == idxmin).sum('lmlith').to_dataset(name='esl')

                    # change index back to no um visc
                    ds_esl_min['icemod'] = ds_esl_af.icemod
                    

                    
                    ##### CHOOSE WHICH OR BOTH OF THESE TO SAVE -- ESLS? OR WTS? ############

                    # Make sure single chunk
                    ds_min = ds_min.chunk(ds_min.sizes)
                    # ds_esl_min = ds_esl_min.chunk(ds_esl_min.sizes)
                    # idxmin = idxmin.to_dataset(name='lmlith').chunk(idxmin.sizes)

                    
                    save = ds_min.to_zarr(VISCDIR / f's{site}/tmp_{ant}', mode='w')
                    # esl_save = ds_esl_min.to_zarr(VISCDIR / f's{site}/esl_{ant}', mode='w')
                    # idxmin_save = idxmin.compute().to_zarr(VISCDIR / f's{site}/idx_{ant}', mode='w')

                    # except:
                    #     print(f'cannot save site {site}')
                #     break
                # break

                    
                    # [client.cancel(d) for d in [esl_save, save, ds_min, ds_esl_min, idxmin, ds_esl_af_it, ds_site_it, ds_site]]
                
                end = np.round(time() - start, 1)
                print(f'| time = {end} seconds')


print('done')

done


# Process saved zarr files for each site down to single file per site
## This file represents best GIA model fit for each site. 

# First Index

In [12]:
def split_ds_string(x, n):
    return int(str(x).split('_')[n])
    
def apply_ufunc(ds, n):
    return xr.apply_ufunc(split_ds_string,
                            ds,
                            n,
                            vectorize=True,
                            dask="parallelized",
                           output_dtypes=[int])

def preprocess_str(ds):
    
    um = xr.apply_ufunc(lambda x: int(x.split('_')[-1][3:-1]),
                        ds.icemod,
                        vectorize=True).to_dataset(name='um')
       
    lm = apply_ufunc(ds, 0).sel(run='real').rename({'lmlith':'lm'})
    lith = apply_ufunc(ds, 1).sel(run='real').rename({'lmlith':'lith'})

    
    return xr.merge([um, lm, lith]).chunk('auto')

idxpaths = glob.glob((DASKTMP_DIR / f'latlon{by}/{viscnum}/*/*').as_posix())
done_sites = sorted([int(s.split('/')[-1].split('_')[1][1:]) for s in idxpaths if 'allidx' in s])
started_sites = np.unique([int(s.split('/')[-2][1:]) for s in idxpaths if 'idx' in s and 'allidx' not in s and 'allsites_idx' not in s])

for site in started_sites:
    
    print(f'site {site}', end=' ')

    
    PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'
    im_files = glob.glob(PATH.as_posix() + '/allidx*')
    idxfiles = glob.glob(PATH.as_posix() + '/idx*')

    if len(idxfiles) > 0:

        name = sorted(idxfiles)[-1].split('_')[-1]

        if ('6255' in name) & (len(im_files) == 0) & (site >= 0):

            # Combine site esl_af files into single file
            ds = xr.open_mfdataset(idxfiles[:], engine='zarr',
                                   parallel=True, 
                                   combine='nested', #combine='by_coords', #
                                   concat_dim=['icemod'],
                                   preprocess=preprocess_str
                                  )

            ds = ds.chunk(ds.sizes).astype(str).persist()
            
            print('saving')

            save = ds.to_zarr(PATH / f'allidx_s{site}', mode='w')

        else:
            print(f' not saved', end = ', ') 
    
print('done')

done


### Then ESL

In [13]:
done_eslpaths = glob.glob((DASKTMP_DIR / f'latlon{by}/{viscnum}/*/*').as_posix())
done_sites = sorted([int(s.split('/')[-1].split('_')[1][1:]) for s in done_eslpaths if 'allesl' in s])
# started_sites = np.unique([int(s.split('/')[-2][1:]) for s in done_eslpaths if 'esl' in s and 's_allims' not in s ])
started_sites = np.unique([int(s.split('/')[-2][1:]) for s in done_eslpaths if 'esl' in s and 'allesl' not in s and 'allsites' not in s])

for site in started_sites:
    
    print(f'site {site}', end=' ')

    
    PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'
    im_files = glob.glob(PATH.as_posix() + '/s_allesl*')
    eslfiles = glob.glob(PATH.as_posix() + '/esl*')

    if len(eslfiles) > 0:

        name = sorted(eslfiles)[-1].split('_')[-1]

        if ('6255' in name) & (len(im_files) == 0) & (site > -1):

            # Combine site esl_af files into single file
            ds = xr.open_mfdataset(eslfiles[:], engine='zarr',
                                   parallel=True, 
                                   combine='nested', #combine='by_coords', #
                                   concat_dim=['icemod'],
                                   # preprocess=preprocess
                                  )

            ds = ds.chunk(ds.sizes).persist()
            
            print('saving')

            save = ds.to_zarr(PATH / f'allesl_s{site}', mode='w')

        else:
            print(f' not saved', end = ', ') 
    
print('done')

done


### Then WRSS

In [14]:
donepaths = glob.glob((DASKTMP_DIR / f'latlon{by}/{viscnum}/*/tmp_*').as_posix())
done_sites = sorted([int(s.split('/')[-1].split('_')[2][1:]) for s in donepaths if 's_allim' in s])

done_epaths = glob.glob((DASKTMP_DIR / f'latlon{by}/{viscnum}/*/*esl_*').as_posix())
done_esites = sorted([int(s.split('/')[-1].split('_')[1][1:]) for s in done_epaths if 'allesl' in s])

done_ipaths = glob.glob((DASKTMP_DIR / f'latlon{by}/{viscnum}/*/*idx_*').as_posix())
done_isites = sorted([int(s.split('/')[-1].split('_')[1][1:]) for s in done_ipaths if 'allidx' in s])

In [None]:
started_sites = np.unique([int(s.split('/')[-2][1:]) for s in donepaths if 'tmp' in s and 's_allims' not in s ])
whichruns = '0to206'

for site in started_sites:

    print(f'site {site}', end=' ')
    PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'
    im_files = glob.glob(PATH.as_posix() + '/s_allims*')
    tmpfiles = glob.glob(PATH.as_posix() + '/tmp*')


    # def preprocess(ds):
    #     ds = ds.swap_dims({'run':'icemod'})
    #     ds = ds.drop('visc')                  ###### DROPPING VISCOSITY
    #     # ds = ds.expand_dims('run')
    #     return ds

    if len(tmpfiles) > 0:
        
        name = sorted(tmpfiles)[-1].split('_')[-1]
        
        if ('6255' in name) & (len(im_files) == 0):
            
            ds = xr.open_mfdataset(tmpfiles[:], engine='zarr',
                                   parallel=True, 
                                   combine='nested', #combine='by_coords', #
                                   concat_dim=['icemod'],
                                   # preprocess=preprocess
                                  )
            ds = ds.chunk(ds.sizes).persist()

            print('saving')

            save = ds.to_zarr(PATH / f's_allims_s{site}_{whichruns}', mode='w')
            
        else:
            print(f' not saved', end = ', ') 
        

print('done')

site 0 

In [None]:
if True:
    for site in done_sites:

        PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'

        imfiles = glob.glob(PATH.as_posix() + '/s_allims*')
        tmpfiles = glob.glob(PATH.as_posix() + '/tmp*')

        efiles = glob.glob(PATH.as_posix() + '/allesl*')
        etmpfiles = glob.glob(PATH.as_posix() + '/esl*')
        
    


        # delete temp files once we have composite ice model file
        if imfiles:
            if tmpfiles:
                print(f'site {site} is already processed.  Huzzah!  Deleting temp files ...')
                [shutil.rmtree(f) for f in tmpfiles]
        
        if efiles:
            if etmpfiles:
                print(f'esl at site {site} is already processed.  Huzzah!  Deleting temp files ...')
                [shutil.rmtree(f) for f in etmpfiles]
            
     
    print('done')

    
    for site in done_isites:
        PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'
        
        ifiles = glob.glob(PATH.as_posix() + '/allidx*')
        itmpfiles = glob.glob(PATH.as_posix() + '/idx*')
        if ifiles:
            # if itmpfiles:
            print(f'idx at site {site} is already processed.  Huzzah!  Deleting temp files ...')
            [shutil.rmtree(f) for f in itmpfiles]

            
    for site in done_esites:

        PATH = DASKTMP_DIR / f'latlon{by}/{viscnum}/s{site}'

        efiles = glob.glob(PATH.as_posix() + '/allesl*')
        etmpfiles = glob.glob(PATH.as_posix() + '/esl*')

        # delete temp esl files once we have composite esl_af file

        if efiles:
            print(f'esl at site {site} is already processed.  Huzzah!  Deleting temp files ...')
            [shutil.rmtree(f) for f in etmpfiles]
    print('done')

# Load combined ice model weights/viscs & ESLs

In [None]:
def preprocess_esl_idx(ds):
    name = ds.encoding['source']
    ds['site'] = int(name.split('_')[-1][1:])
    return ds.chunk(ds.sizes)

def load_ds_batch(files, i, n, preprocess):

    
    print(i)
    ds = xr.open_mfdataset(files[i*n:i*n+n], 
                           engine='zarr', 
                           parallel=True, 
                           combine='nested',
                           concat_dim=['site'],
                           # compat='no_conflicts',

                           preprocess=preprocess_esl_idx
                          ).persist()
    return ds

ALLVISCPATH = f'/rigel/jalab/users/rcc2167/data/output/dask_temp/latlon{by}/{viscnum}/'

if True:
    p = 's*/allesl*'
    files = glob.glob(ALLVISCPATH + p)

    n = 12
    dses = [load_ds_batch(sorted(files), i, n, preprocess_esl_idx) for i in range(18)]

    ds_allsite_esl = xr.concat(dses, dim='site').esl.chunk('auto')

    # ds_allsite_esl = ds_allsite_esl.assign_coords({'site':range(207)}) ## NB these site numbers are diff from previous site numbeers 

    save = ds_allsite_esl.to_dataset(name='esl').to_zarr(ALLVISCPATH + f'allsites_esls', mode='w')


if True:
    n=12
    files = glob.glob(ALLVISCPATH + 's*/allidx*')

    dses = [load_ds_batch(files, i, n, preprocess_esl_idx) for i in range(17)]

    ds_allsite_idx = xr.concat(dses, dim='site').chunk('auto')

    save = ds_allsite_idx.to_zarr(ALLVISCPATH + f'allsites_idxs', mode='w')


   
# ds_allsite_esl = xr.open_mfdataset(ALLVISCPATH + 'allsites_esls', engine='zarr')


# ds_allsite_idx = xr.open_mfdataset(ALLVISCPATH + 'allsites_idxs', engine='zarr')

# # remove upper mantle viscosity from dataset
# ds_allsite_idx['icemod'] = xr.apply_ufunc(lambda x: '_'.join(x.split('_')[:-1]),ds_allsite_idx.icemod, vectorize=True,)

# ds_allsite_idx

# Load all wrss scores

In [None]:
def preprocess_wts(ds):
    name = ds.encoding['source']
    ds['site'] = int(name.split('_')[-2][1:])
    return ds.chunk(ds.sizes)

if True:
    
    p = 's*/s_allims*'
    files = glob.glob(ALLVISCPATH + p)

    n = 12
    dses = [load_ds_batch(files, i, n, preprocess_wts) for i in range(18)]

    ds_allsite = xr.concat(dses, dim='site').wrss.chunk('auto')

    # ds_allsite = ds_allsite.assign_coords({'site':range(207)}) ## NB these site numbers are diff from previous site numbeers 

    save = ds_allsite.to_dataset(name='wrss').to_zarr(ALLVISCPATH + f'allsites_ims', mode='w')
    
    
ds_allsite = xr.open_mfdataset(ALLVISCPATH + 'allsites_ims', engine='zarr')
ds_allsite

# process WRSS models to align with GMSL models

In [None]:
df_allsite = ds_allsite[['icemod']].to_pandas()
df_allsite['icemod_novisc'] = ds_allsite.icemod.to_pandas().str.split('_um').str[0].values
df_allsite = df_allsite['icemod_novisc'].drop_duplicates()
df_allsite = df_allsite[df_allsite.to_frame().isin(ds_esls.esls.values)['icemod_novisc']]

d = df_allsite.isin(ds_allsite_idx.icemod.to_pandas().index)
df_allsite_viscs = df_allsite[d].index # .to_xarray()
df_allsite_noviscs = df_allsite[d].values

In [None]:
ds_allsite = ds_allsite.sel(icemod=df_allsite_viscs).chunk('auto')
ds_allsite

# Process only real (need to make for synthetics)

In [None]:
# Model 6088 failed -- need to rerun

df_allsite = (ds_allsite.wrss.sel(run='real').T.to_pandas() / 10).reset_index().dropna(thresh=10)# .drop('site')
df_allsite['icemod'] = df_allsite.icemod.str.split('_um').str[0]
df_allsite.head()

# Get mean WRSS for each ice model over all sites

In [None]:
icemod_wrss = df_allsite.copy().set_index('icemod').mean(axis=1).to_frame('wrss')

In [None]:
df = df_allsite.copy()

df['thermosteric'] = df.icemod.str.split('_').str[4].str[1:]

thermodict = df.drop('icemod', axis=1).groupby('thermosteric').mean().mean(axis=1).to_dict()

thermo_wrss = df.replace({'thermosteric':thermodict})[['icemod', 
                                                      'thermosteric']].rename({'thermosteric':
                                                                               'wts'}, axis=1)
thermo_wrss = thermo_wrss.set_index('icemod')


In [None]:
df = df_allsite.copy()

df['mountains'] = df.icemod.str.split('_').str[5].str[1:]

mountdict = df.drop('icemod', axis=1).groupby('mountains').mean().mean(axis=1).to_dict()

mount_wrss = df.replace({'mountains':mountdict})[['icemod', 
                                                      'mountains']].rename({'mountains':
                                                                            'wts'}, axis=1)
mount_wrss = mount_wrss.set_index('icemod')

In [None]:
df = df_allsite.copy()

df['ais'] = df.icemod.str.split('_').str[1]

df_dd = dd.from_pandas(df, npartitions=10)

aisdict = df_dd.drop('icemod', axis=1).groupby('ais').mean().mean(axis=1).compute().to_dict()

ais_wrss = df.replace({'ais':aisdict})[['icemod',
                                       'ais']].rename({'ais':'wts'}, axis=1).set_index('icemod')

ais_wrss.head()

# Transform WRSS into weights by inverting and normalizing 

In [None]:
def normalize(df):
    return np.true_divide(df, df.sum())

In [None]:
ds_icemod_norm = normalize(1/icemod_wrss).to_xarray()# .to_dataset(name='wt') 
ds_mountain_norm = normalize(1/mount_wrss).to_xarray()
ds_thermo_norm = normalize(1/thermo_wrss).to_xarray()
ds_ais_norm = normalize(1/ais_wrss).to_xarray()
ds_esls_final = ds_esls.sel(esls=ds_icemod_norm.icemod).chunk('auto').rename({'time':'age'}).persist()
ds_esls_af_final = ds_allsite_esl.sel(icemod=ds_icemod_norm.icemod).chunk('auto').persist()
ds_esls_af_final

# Process ESL_af on cluster b/c too big for local

In [None]:
# weight GMSL curves by icemodel weight
ds_esls_af_wted = ds_esls_af_final.weighted(ds_icemod_norm.wrss).mean('site').persist()

# Load PISM weights

In [None]:
df_pism_wt = ds_icemod_norm.to_pandas()
df_pism_wt['ais'] = ds_icemod_norm.icemod.to_pandas().str.split('_').str[1].astype('int')
df_pism_wt = df_pism_wt[['ais']]
df_pism_wt['wt'] = df_pism_wt['ais'].map(df_pism['weight'].to_dict())

ds_pism = normalize(df_pism_wt[['wt']].to_xarray())
ds_pism

# Make flat distribution wts

In [None]:
import scipy.stats as stats

# make 40 % of distribution below uniform distribution
xmin = ds_reallgm.esl.sel(time=6.5, method='nearest').min().values
xlow = np.arange(xmin, -10, 0.1)

x2 = np.linspace(stats.truncexpon.ppf(0.001, b=50),
                stats.truncexpon.ppf(0.99999999, b=10), len(xlow)) 
ylow = stats.truncexpon.pdf(x=x2, loc=0.00, b=10)[::-1] 


########## MAKE ZERO PART OF DISTRIBUTION ##########

mu = 1
xcenter = np.arange(-10, 2, 0.1)

x = np.linspace(stats.uniform.ppf(0.01, mu),
                stats.uniform.ppf(0.99, mu), len(xcenter)) 
ycenter = stats.uniform.pdf(x, loc=1, scale=1 )[::-1]

####### MAKE FULL DISTRIBUTION #########

gmsl = np.concatenate([xlow, xcenter])
post_pdf = np.concatenate([ylow, ycenter])
post_pdf_box = post_pdf 

gmsl_nonans_len = len(ds_reallgm.sel(time=6).esl[ds_reallgm.sel(time=6).esl.notnull()])

gmsl_interp = np.linspace(np.nanmin(ds_reallgm.esl.sel(time=6, method='nearest').values),
                          np.nanmax(ds_reallgm.esl.sel(time=6, method='nearest').values),
                          gmsl_nonans_len,)

post_pdf_box = np.interp(gmsl_interp, gmsl,  post_pdf_box) 
post_pdf_box = post_pdf_box / np.trapz(post_pdf_box)

plt.plot(post_pdf_box, gmsl_interp, 'k-', lw=2, label='frozen pdf')


In [None]:
t=6
esl_6k = ds_esls_final.sel(age=t).sortby('esls').esl.chunk('auto').persist()
esl_6k_sort = esl_6k.to_pandas().sort_values().to_xarray()
esl_6k_sort

In [None]:
#calculate prior density of models 
density, bins = np.histogram(esl_6k_sort, bins=170, density=True)

prior_density_interp = np.interp(esl_6k_sort.values, bins[:-1], density)
prior_density_interp  = prior_density_interp / np.trapz(prior_density_interp)

# interpolate desired posterior density distribution onto GMSL grid
post_pdf_box_real = np.interp(esl_6k_sort, gmsl_interp, post_pdf_box)
post_pdf_box_real = post_pdf_box_real / np.trapz(post_pdf_box_real)

# calculate necessary scale factor to get GMSLs to posterior density
scale_factor = post_pdf_box_real / prior_density_interp

# make dataset
ds_scale_factor = xr.Dataset(
            {
                "wt":(['icemod'],scale_factor)
            }, 
            coords={ 
                "icemod":esl_6k_sort.icemod.values,
            },
        )

#plot figure
fig, ax = plt.subplots(1, 3, figsize=(20, 6))
ax1, ax2, ax3 = ax.flatten()

# ax1.scatter(prior_density_interp, eslplt_sort.values, s=0.3)
ax1.plot(prior_density_interp, esl_6k_sort.values)
ax1.set_title('GMSL model prior density', fontsize=20) 
ax1.set_ylabel('GMSL (m)', fontsize=20)
ax1.set_xlabel('Probability density', fontsize=20)

ax2.plot(scale_factor, esl_6k_sort, color='g', lw=2)
ax2.axvline(1, color='r')
ax2.set_title('weighting factor', fontsize=20)
ax2.set_xlabel('multiplier', fontsize=20)


ax3.plot(scale_factor * prior_density_interp, esl_6k_sort, lw=3)
ax3.set_title('posterior model density', fontsize=20)
ax3.set_xlabel('Probability density', fontsize=20)

[a.grid(alpha=0.5) for a in ax.flatten()]

plt.suptitle('Left * Center = Right', fontsize=24, y=1)

# Save site weights, weighted by overall model weights

In [None]:
def normalize(ds, dim='icemod'):
    return np.true_divide(ds, ds.sum(dim))

In [None]:
ds_allsite_novisc_r = ds_allsite.sel(run='real').copy()
ds_f = ds_allsite.icemod.to_pandas().str.split('_um').str[0].to_frame().reset_index()
ds_allsite_idx_r = ds_f.drop('icemod', axis=1).rename({0:'icemod'}, axis=1).set_index('icemod').to_xarray()
ds_allsite_novisc_r['icemod'] = ds_allsite_idx_r.icemod

# Get total weights 
ds_rslpismflat = normalize((ds_icemod_norm.wrss + ds_pism.wt) * normalize(ds_scale_factor.wt))
# ds_rslpismflat = ds_rslpismflat.fillna(ds_rslpismflat.mean())

# Get AIS weights
ds_rslpismflat_ais = normalize((ds_ais_norm.wts + ds_pism.wt) * normalize(ds_scale_factor.wt))
ds_scalefactor_ais = normalize(ds_scale_factor.wt)


ds_allsite_novisc_r = ds_allsite_novisc_r.sel(icemod=ds_rslpismflat.icemod)

In [None]:
ds_sitewt = normalize(ds_allsite_novisc_r.weighted(ds_rslpismflat).mean('icemod').wrss, dim='site')
ds_sitewt

In [None]:
plt.plot(ds_sitewt.values)

# Compute posterior weights

In [None]:
ds_rslpismflat_wts = normalize( (normalize(ds_icemod_norm.wrss) + normalize(ds_pism.wt) ) * normalize(ds_scale_factor.wt) )

# Save models

In [None]:
COLLDIR = DASKTMP_DIR / f'latlon{by}/collated/'

ds_sitewt.to_netcdf(COLLDIR / 'site_wts.nc')
ds_icemod_wt = ds_icemod_norm.to_netcdf(COLLDIR / 'rsl_wts.nc')
ds_pism.to_netcdf(COLLDIR / 'pism_wts.nc')
ds_esls_final.to_netcdf(COLLDIR / 'icemod_esls.nc')
ds_scale_factor.to_netcdf(COLLDIR / 'scalefactor_wts.nc')
ds_ais_norm.to_netcdf(COLLDIR / 'ais_wts.nc')

ds_thermo_norm.to_netcdf(COLLDIR / 'thermo_wts.nc')
ds_mountain_norm.to_netcdf(COLLDIR / 'mountain_wts.nc')
ds_esls_af_wted.to_netcdf(COLLDIR / 'icemod_esls_af_wted.nc')

ds_scalefactor_ais.to_dataset(name='wt').to_netcdf(COLLDIR / 'scalefactor_ais_wts.nc')
ds_rslpismflat_ais.to_dataset(name='wt').to_netcdf(COLLDIR / 'rslpismflat_ais_wts.nc')
ds_rslpismflat_wts.to_netcdf(COLLDIR / 'rslpismflat_wts.nc') #ds_im_pism_scale_wts = 

# Repeat processing step for synthetics 

#### NB:don't need scale factor or pism for this; just need RSL weights

In [None]:
def normalize(df):
    return np.true_divide(df, df.sum())
normlist =[]
for run in ds_allsite.run[ds_allsite.run != 'real']:
    
    print(run.values)
    # Choose one run; convert WRSS to dataframe, dropping nans
    df_allsite = (ds_allsite.wrss.sel(run=run).T.to_pandas() / 10).reset_index().dropna(thresh=10)# .drop('site')
    
    # Add icemod name
    df_allsite['icemod'] = df_allsite.icemod.str.split('_um').str[0]
    
    # Get mean WRSS for each ice model over all sites
    icemod_wrss = df_allsite.copy().set_index('icemod').mean(axis=1).to_frame('wrss')
    
    # Normalize as wts rather than WRSS, then Convert back into xarray
    ds_icemod_norm = normalize(1/icemod_wrss).to_xarray()# .to_dataset(name='wt') 
    
    ds_icemod_norm = ds_icemod_norm.expand_dims({'run':[run.values]})

    normlist.append(ds_icemod_norm.chunk('auto'))

    
    # break
ds_synths_norm = xr.concat(normlist, dim='run')

In [None]:
COLLATEDPATH = f'/rigel/jalab/users/rcc2167/data/output/dask_temp/latlon{by}/collated/'
ds_synths_norm.to_netcdf(COLLATEDPATH + 'rsl_synth_wts.nc')

# Load Synthetic GMSLs

In [None]:
def load_esl(path):
    name = path.split('midhol_')[1][:-4] 
    esl = loadmat(path)['ESL']
    return xr.DataArray(esl, dims=['run', 'age'], 
                        coords=dict(
                            run=[name],
                            age=ages[ages < 13])
                       )
    

ds_synth_esls = xr.concat([load_esl(p) for p in files], dim='run')

ds_synth_esls.to_netcdf(COLLATEDPATH + 'synth_esls.nc')

# Get viscosity weights that are composite of global and local

In [73]:
ds_allsite_umlmlith = ds_allsite_idx.astype(int).weighted(ds_rslpismflat).mean('icemod')
# ds_allsite_umlmlith.to_netcdf(COLLDIR / 'visc_wts.nc')

In [253]:
df_wts_ = ds_icemod_norm.wrss.squeeze().drop('run').to_dataset(name='wts').to_pandas()
df_wts_['run'] = ds_icemod_norm.wrss.squeeze().drop('run').icemod.to_pandas().index.str[:16]
minum_idx = df_wts_.groupby('run')['wts'].idxmin().values

ds_rslpismflat_um = normalize((ds_icemod_norm.wrss.sel(icemod=minum_idx) +
                               ds_pism.wt.sel(icemod=minum_idx)
                              ) * normalize(ds_scale_factor.sel(icemod=minum_idx))
                             )
ds_rslpismflat_um

array(1.)

In [None]:
def normalize(ds):
    return np.true_divide(ds, ds.sum('icemod'))

ds_allsite_novisc = ds_allsite.sel(run='real').wrss.copy()
ds_allsite_novisc['icemod'] = ds_rslpismflat_wts.icemod
ds_siteglobe_viscwts = normalize(normalize(ds_allsite_novisc) + ds_rslpismflat_wts)

In [263]:
df_wts_ = ds_rslpismflat_wts.to_dataset(name='wts').to_pandas()# .reset_index()
df_wts_['run'] = ds_rslpismflat_wts.icemod.to_pandas().index.str[:16]
minum_idx = df_wts_.groupby('run')['wts'].idxmin().values

In [266]:
df_siteglobe_umviscwts = normalize(
        normalize(ds_allsite_novisc.sel(icemod=minum_idx)) + 
        normalize(ds_rslpismflat_wts.sel(icemod=minum_idx))
    ).squeeze()

In [267]:
df_siteglobe_umviscwts

Unnamed: 0,Array,Chunk
Bytes,39.89 MiB,1.73 MiB
Shape,"(207, 25256)","(9, 25256)"
Count,393 Tasks,23 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 39.89 MiB 1.73 MiB Shape (207, 25256) (9, 25256) Count 393 Tasks 23 Chunks Type float64 numpy.ndarray",25256  207,

Unnamed: 0,Array,Chunk
Bytes,39.89 MiB,1.73 MiB
Shape,"(207, 25256)","(9, 25256)"
Count,393 Tasks,23 Chunks
Type,float64,numpy.ndarray


In [272]:
ds_allsite_idx_wted = ds_allsite_idx.sel(icemod=minum_idx).astype(int).weighted(df_siteglobe_umviscwts).mean('icemod')

ds_allsite_idx_wted.to_netcdf(COLLDIR / 'visc_wts.nc')

In [148]:

# For each site:

# Get site weights
ds_wrss_it = normalize(ds_allsite.sel(site=0).sel(run='real').wrss)

# Add viscosity to global weights
ds_rslpismflat_wts_viscs = ds_rslpismflat_wts.copy()
ds_rslpismflat_wts_viscs['icemod'] = df_allsite_viscs


# Add site weights to global weights and normalize
ds_wrss_siteglobe = normalize(ds_wrss_it + ds_rslpismflat_wts_viscs)

df_wt_siteglobe = ds_wrss_siteglobe.to_dataset(name='wt').to_pandas()[['wt']]

df_wt_siteglobe['um'] = df_allsite_viscs.str[-1].astype(int)

df_wt_siteglobe

Unnamed: 0_level_0,wt,um
icemod,Unnamed: 1_level_1,Unnamed: 2_level_1
bzv_5978_anu_anu_t1417_m161_ump3,0.000008,3
bzv_5978_anu_anu_t1885_m195_ump5,0.000005,5
bzv_5978_anu_anu_t2337_m195_ump4,0.000005,4
bzv_5978_anu_anu_t2920_m161_ump2,0.000014,2
bzv_5978_anu_glc_t1283_m161_ump2,0.000013,2
...,...,...
mst_6255_mst_i6g_t3842_m102_ump2,0.000005,2
mst_6255_mst_mst_t1692_m70_ump3,0.000003,3
mst_6255_mst_mst_t3077_m68_ump4,0.000002,4
mst_6255_mst_mst_t314_m21_ump5,0.000002,5


In [92]:
%%bash
cd /rigel/jalab/users/rcc2167/data/output/dask_temp/latlon5/collated
rm icemod_esls_af_wts.nc
rm im_pism_scale_wts.nc



In [78]:
%%bash
cd /rigel/jalab/users/rcc2167/data/output/dask_temp/latlon5/collated
du -sh

724M	.
