#!/usr/bin/env python
# Script to produce Figure 6 of the manuscript 'Building Archean cratonic roots'
# Written by Charitra Jain

import fnmatch
import getpass
import io
import math
import os
import shutil
import statistics as st
import sys

import h5py
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

from stagpy import stagyydata

# Name of the user running the script (not referenced elsewhere in this script).
# os.getlogin() raises OSError when the process has no controlling terminal
# (cron, nohup, some IDEs/containers), which would abort the whole run for an
# unused value; getpass.getuser() checks the environment and the password
# database instead, and we fall back to '' if even that fails.
try:
    local_system = getpass.getuser()
except (KeyError, OSError):
    local_system = ''

# Global figure styling: 10 pt Arial text, 1 pt lines and axes,
# bold 10 pt axes titles and 10 pt legend text.
mpl.rcParams.update({
    'font.family': "sans-serif",
    'font.sans-serif': "Arial",
    'lines.linewidth': 1,
    'font.size': 10,
    'axes.linewidth': 1,
    'axes.titlesize': 10,
    'axes.titleweight': "bold",
    'legend.fontsize': 10,
})
default_font_size = 10

here = os.getcwd()

# Case directories (appended to the working directory), in plotting order.
run_names_list = ['/cases/A1/','/cases/A7/','/cases/A13/','/cases/B5/','/cases/B8/','/cases/B12/','/cases/C1/','/cases/C8/']
# Derived from the list (it used to be hard-coded to 8) so the two cannot disagree.
total_runs = len(run_names_list)

# index of the run currently being processed
run = 0
# nz values written per core
nz_core = 128
# ny values written per core
ny_core = 64
# number of columns written per row in HDF dataset
ncolumns = ny_core*2
# number of rows in a HDF dataset
nrows = nz_core//2
# number of radial cells in the computational domain
nz = 128
# number of cores (needed for file names)
ncores = 8
numcores_files = 8
core = ['00', '01', '02', '03', '04', '05', '06', '07']
dim3 = len(core)


# empty lists used as scratch buffers while reading text files
timesteps = []
time_TTGmass = []
TTG1 = []
TTG2 = []
TTG3 = []
frames = []
steps = []
ny_raw = []
cthick = []

cvol = []
nztot = []
raw = []
time = []
radius = []

# time conversions (s)
seconds_in_year = 31556926
seconds_in_Myear = seconds_in_year*1e6
# TTG density (kg/m^3); this was previously assigned twice with the same value
density_TTG = 2700.
# radii and depths (m)
earth_radius = 6371.e3
core_radius = 3481.e3
mantle_depth = 2890.e3
# thickness of the near-surface window in which remaining TTG is counted
depth_window = 100.e3

### Specify fields to extract using stagpy ###
x_profile1='Tmean'
y_profile='r'
radial_profile1 = 'vrms'

# total number of ny points (ny_core per core x 8 cores)
nytot = 512

# volume of the spherical shell between core and surface
volume_mantle          = (4 * math.pi*(earth_radius**3-core_radius**3))/3
# analytical domain volumes; the "quad" volume is 1/16 of the "full" one
analytical_volume_full = volume_mantle * (math.pi)/(nytot)
analytical_volume_quad = volume_mantle * (math.pi)/(nytot*16)

area_earth             = 4 * math.pi * earth_radius**2
area_earth_quad        = area_earth * (math.pi)/(nytot*16)
area_core              = 4 * math.pi * core_radius**2

# scaling factor used to scale domain-integrated TTG masses/volumes up
factor=volume_mantle/analytical_volume_quad

# switches for the individual analysis steps
compute_TTGmass           = True
compute_RemainingTTGmass  = True
compute_mafic             = True
compute_felsic            = True
plot_everything           = True

debug = False


def _list_array(n):
    """Return a length-n object array whose elements are independent empty lists.

    Each element is later replaced by a per-run array, so runs may have
    different numbers of frames.
    """
    return np.frompyfunc(list, 0, 1)(np.empty(n, dtype=object))


# per-run storage (one entry per run, each a list/array of per-frame values)
tmean_run = _list_array(total_runs)
velocity_rms_run = _list_array(total_runs)

time_Myr_step_run = _list_array(total_runs)
TotalTTGvolume_run = _list_array(total_runs)
TotalTTGmass_run = _list_array(total_runs)
RemainingTTG_mass_run = _list_array(total_runs)
# one entry per radial cell (not per run)
radius_cell_top = _list_array(nz_core)
finaltime_run = _list_array(total_runs)
lowPTTG_final_volume_run = _list_array(total_runs)
medPTTG_final_volume_run = _list_array(total_runs)
highPTTG_final_volume_run = _list_array(total_runs)
allPTTG_final_volume_run = _list_array(total_runs)

crust_bs = _list_array(total_runs)
crust_TTG = _list_array(total_runs)
crust_bs_thickness = _list_array(total_runs)
crust_TTG_thickness = _list_array(total_runs)
time_Myr_run = _list_array(total_runs)

# per-run scalars: final crustal volume, relative low/med/high-P proportions (%),
# and final (scaled) TTG mass
finalvolume = np.zeros(total_runs)
lowP = np.zeros(total_runs)
medP = np.zeros(total_runs)
highP = np.zeros(total_runs)
finalmass = np.zeros(total_runs)

# Colour of each run in the figure (same order as run_names_list).
palette = ['#0A9396', '#0A9396','#0A9396','#EE9B00', '#EE9B00','#EE9B00','#BB3E03','#BB3E03']


def _read_lines(filename):
    """Return every line of the text file `filename`; the file is always closed."""
    with open(filename, 'r') as fh:
        return fh.readlines()


def _read_TTG_production(filename):
    """Read a TTGmass.dat file.

    Returns four float arrays built from columns 1-4 of every line: the time
    (seconds; the caller converts it to Myr) and the low-, medium- and
    high-pressure TTG masses, whose cumulative sums give the produced mass.
    """
    rows = [line.split() for line in _read_lines(filename)]
    return tuple(np.array([float(p[col]) for p in rows]) for col in range(1, 5))


def _depth_window_cells():
    """Read ./data/dvol.dat and ./data/rad.dat and locate the depth window.

    Returns (cell_volume, nrows_bottom):
    - cell_volume: first column of dvol.dat divided by 1e9 (km^3 if the file
      is in m^3 — presumably the case, TODO confirm);
    - nrows_bottom: the first HDF row (two radial cells per row) whose cell
      top exceeds (mantle_depth - depth_window) in km.
    Also fills the module-level `radius_cell_top`.
    Raises ValueError if no cell lies in the window.
    """
    cell_volume = np.array([float(line.split()[0]) for line in _read_lines(here+'/data/dvol.dat')])/1e9

    radius_grid = np.array([float(line.split()[0]) for line in _read_lines(here+'/data/rad.dat')])/1e3
    # every second entry of rad.dat is taken as the top of a radial cell
    tops = radius_grid[0:2*nz_core:2]
    radius_cell_top[:] = tops
    radius_cell_top_from_bottom = np.flip(tops)

    in_window = np.flatnonzero(radius_cell_top_from_bottom > (mantle_depth - depth_window)/1e3)
    if in_window.size == 0:
        # the original code hit a NameError on `layer_bottom_cell` here
        raise ValueError('No grid cell lies inside the depth window; check data/rad.dat')
    layer_bottom_cell = int(in_window[0])
    return cell_volume, layer_bottom_cell // 2


def _load_TTG_fields(path, ndatasets):
    """Load every dataset of the TTG_*.h5 files of one run.

    Returns an object array data[core, dataset] of shape
    (numcores_files, ndatasets); data[i-1, k] is the k-th dataset (counted
    across that core's files, in file order) of files TTG_00JJJ_00III.h5
    with III = i.
    """
    # number of files PER CORE starting from 00000 - 000xx
    nfiles = int(len(fnmatch.filter(os.listdir(path+'+hdf5/'), 'TTG_00*.h5'))/numcores_files)
    print('# files   : '+str(nfiles))

    # number of datasets PER FILE (may be less in the last file); informational
    with h5py.File(path+'+hdf5/TTG_00000_00001.h5', 'r') as fh:
        dataset_limit = len(fh.keys())
    print('data/file : '+str(dataset_limit))

    data = np.frompyfunc(list, 0, 1)(np.empty((numcores_files, ndatasets), dtype=object))
    for i in range(1, numcores_files+1):
        c = 0  # running dataset index across this core's files
        for j in range(nfiles):
            filename = path+'+hdf5/TTG_00'+'{:03d}'.format(j)+'_00'+'{:03d}'.format(i)+'.h5'
            # `with` closes each HDF5 file (they were previously left open)
            with h5py.File(filename, 'r') as fh:
                for key in fh.keys():
                    data[i-1, c] = fh[key][()]
                    c += 1
    return data


def _remaining_TTG_volume(data, ndatasets, cell_volume, nrows_bottom):
    """Return, per dataset, the cell-volume-weighted sum of the TTG field.

    The sum covers rows nrows_bottom..nrows-1 of every core.  Each dataset is
    indexed [0][row][column].  In row p, the first ncolumns/2 columns are
    weighted by cell_volume[2p] and the rest by cell_volume[2p+1].  The volumes
    are multiplied back by 1e9, so the result is in m^3 if the field is a
    volume fraction (presumably — TODO confirm).
    """
    half = int(ncolumns/2)
    rows = np.arange(nrows_bottom, nrows)
    vol_even = cell_volume[2*rows][:, None]*1e9
    vol_odd = cell_volume[2*rows+1][:, None]*1e9
    volume_remaining = np.zeros(ndatasets)
    for m in range(ndatasets):
        for n in range(numcores_files):
            field = np.asarray(data[n, m])[0][nrows_bottom:nrows]
            volume_remaining[m] += np.sum(field[:, :half]*vol_even) + np.sum(field[:, half:ncolumns]*vol_odd)
    return volume_remaining


def _domain_crustal_volume(path, basename, nframes_run):
    """Return the crustal volume of the whole domain for each of nframes_run frames.

    Reads <path>/+hdf5/<basename>000XX.dat for every core XX in `core`.  Each
    non-empty line holds an ny index and a crustal thickness; ny_core values
    make up one frame.  Element 0 of the result stays zero (presumably the
    initial, crust-free state — TODO confirm).  Element k >= 1 uses values
    ny_core*(k-1) .. ny_core*k - 1 of every core, averaged per core and then
    over cores.
    """
    thickness_avg_core = np.zeros((dim3, nframes_run))
    for numcore in range(dim3):
        filename = path+'/+hdf5/'+basename+'000'+core[numcore]+'.dat'
        if debug: print(filename)
        # second column of every non-blank line (the first is the ny index)
        thickness = np.array([float(line.split()[1]) for line in _read_lines(filename) if line.strip()])
        if debug: print(thickness.size)
        for numframe in range(1, nframes_run):
            thickness_avg_core[numcore, numframe] = st.mean(thickness[ny_core*(numframe-1):numframe*ny_core])

    # domain average, computed once after all cores have been read
    thickness_avg_domain = np.zeros(nframes_run)
    for numframe in range(1, nframes_run):
        thickness_avg_domain[numframe] = st.mean(thickness_avg_core[:, numframe])

    # area at mid-crust depth times thickness; 1e-9 converts m^3 to km^3
    return 4 * math.pi * 1e-9 * thickness_avg_domain * (earth_radius - thickness_avg_domain/2)**2


for run, batch in enumerate(run_names_list):

    path = here + batch
    print(path)

    sdat = stagyydata.StagyyData(path)

    # reading in # of frames: one value per line is appended to the global
    # `frames` list, which is indexed by run below (so presumably one line per file)
    for line in _read_lines(path+'/profiles/frames.dat'):
        frames.append(int(line.split()[0]))

    # extracting time (first value of every dataset) from time_botT.h5;
    # this is needed when plotting the time evolution of a field
    with h5py.File(path+'+hdf5/time_botT.h5', 'r') as f:
        ndatasets = len(f.keys())
        time = np.array([float(f[key][0]) for key in f.keys()])

    time_Myr = time/seconds_in_Myear
    print('# HDF datasets: '+str(ndatasets))

    time_Myr_run[run] = time_Myr
    print(time_Myr[-1])

    # time steps at which snapshots are saved: after the first line, every
    # block of nz+1 lines starts with a line whose second field is the step
    rprof_lines = _read_lines(path+'+op/earthDIM_rprof.dat')[1:]
    snaps = [int(rprof_lines[i].split()[1]) for i in range(0, len(rprof_lines), nz+1)]
    nframes = len(snaps)
    print('# frames (rprof.dat): ' + str(nframes))

    for i in range(nframes-1):
        if snaps[i+1] <= snaps[i]:
            print(i,snaps[i+1],snaps[i])
            print('ERROR: something wrong with steps in _rprof.dat')
            sys.exit()

    print('Time at last frame: '+str('{:.2f}'.format(sdat.tseries['t'][snaps[nframes-1]]*1e-6/seconds_in_year)))

    # Tmean from the time series, and rms velocity of the top cell (nz-1) from
    # the radial profiles converted to cm/yr, at every snapshot
    tmean_run[run] = np.array([float(sdat.tseries['Tmean'][s]) for s in snaps])
    velocity_rms_run[run] = np.array([float(sdat.rprof.loc[s, radial_profile1][nz-1]*100*seconds_in_year) for s in snaps])

    # computing total mass of TTG produced
    if compute_TTGmass:
        ttg_time, ttg_lowP, ttg_medP, ttg_highP = _read_TTG_production(path+'profiles/TTGmass.dat')

        time_Myr_step = ttg_time*1e-6/seconds_in_year
        finaltime = time_Myr_step[-1]

        # first sum med-P and high-P TTG, then add that to low-P TTG
        TTGmass = np.add(ttg_lowP, np.add(ttg_medP, ttg_highP))

        # getting cumulative mass
        TTGmass_cumulative = np.cumsum(TTGmass)
        lowP_cumulative = np.cumsum(ttg_lowP)
        medP_cumulative = np.cumsum(ttg_medP)
        highP_cumulative = np.cumsum(ttg_highP)

        print('Scaling factor: '+str(factor))
        # these values are scaled up
        print('TTG cumulative final mass: '+str(TTGmass_cumulative[-1]*factor))
        print('Cumulative lowP mass     : '+str(lowP_cumulative[-1]*factor))
        print('Cumulative medP mass     : '+str(medP_cumulative[-1]*factor))
        print('Cumulative highP mass    : '+str(highP_cumulative[-1]*factor))

        # getting cumulative volume (scaled, km^3)
        TTGvolume = TTGmass_cumulative * factor * 1e-9 / density_TTG
        lowPTTG_final_volume  = lowP_cumulative[-1]    * factor * 1e-9 / density_TTG
        medPTTG_final_volume  = medP_cumulative[-1]    * factor * 1e-9 / density_TTG
        highPTTG_final_volume = highP_cumulative[-1]   * factor * 1e-9 / density_TTG
        allPTTG_final_volume  = TTGmass_cumulative[-1] * factor * 1e-9 / density_TTG

        lowP[run] = lowPTTG_final_volume*100/allPTTG_final_volume
        medP[run] = medPTTG_final_volume*100/allPTTG_final_volume
        highP[run] = highPTTG_final_volume*100/allPTTG_final_volume
        print('Relative proportions: '+'{:.0f}'.format(lowP[run])+'% '+'{:.0f}'.format(medP[run])+'% '+'{:.0f}'.format(highP[run])+'%')
        print('Total cumulative TTG volume (scaled, km3): '+str('{:.2e}'.format(allPTTG_final_volume)))

        finalvolume[run] = allPTTG_final_volume
        # final cumulative mass has to be scaled up
        finalmass[run] = TTGmass_cumulative[-1] * factor

        time_Myr_step_run[run] = time_Myr_step
        TotalTTGvolume_run[run] = TTGvolume
        TotalTTGmass_run[run] = TTGmass_cumulative * factor
        finaltime_run[run] = finaltime
        lowPTTG_final_volume_run[run] = lowPTTG_final_volume
        medPTTG_final_volume_run[run] = medPTTG_final_volume
        highPTTG_final_volume_run[run] = highPTTG_final_volume
        allPTTG_final_volume_run[run] = allPTTG_final_volume

    # computing mass of remaining TTG (within the depth window)
    if compute_RemainingTTGmass:
        cell_volume, nrows_bottom = _depth_window_cells()
        if run == 0:
            print('Rows above (and including) '+str(nrows_bottom)+' are considered in the depth window')

        data = _load_TTG_fields(path, ndatasets)
        volume_remaining = _remaining_TTG_volume(data, ndatasets, cell_volume, nrows_bottom)

        RemainingTTG_mass_run[run] = volume_remaining*density_TTG*factor
        print(RemainingTTG_mass_run[run])
        print('# datasets: '+str(ndatasets))

    # domain-wide crustal volume per frame (mafic / felsic crust); these no
    # longer depend on compute_RemainingTTGmass being enabled
    if compute_mafic:
        crust_bs[run] = _domain_crustal_volume(path, 'CrustThickness', frames[run])

    if compute_felsic:
        crust_TTG[run] = _domain_crustal_volume(path, 'CrustThickness_felsic', frames[run])


if compute_mafic and compute_felsic:
    # Sanity check: each per-run series plotted against time_Myr_run should
    # have as many entries as the time axis; report any that do not.
    for i in range(total_runs):
        n_time = time_Myr_run[i].size
        n_bs = crust_bs[i].size
        n_ttg = crust_TTG[i].size
        n_tmean = tmean_run[i].size

        print(i, n_time, n_bs, n_ttg, n_tmean)
        for n_other in (n_bs, n_ttg, n_tmean):
            if n_other != n_time:
                print('Size mismatch for run '+str(i), n_time, n_other)


if plot_everything:

    fig, ax = plt.subplots(4, 1, figsize = (9,12), squeeze=False, dpi = 300)

    # marker symbol and marker spacing (mark every n-th point) for each run
    symbol = ['o','^','s','o','^','s','^','s']
    marker_frequency = [100, 20, 100, 20, 20, 20, 20, 20]
    marker_frequency_TTG = 2000

    def _run_style(i, markevery, alpha=1):
        """Line2D keyword arguments shared by every curve of run i.

        markeredgewidth is passed as a number (Matplotlib documents it as a float).
        """
        return dict(linestyle='-', linewidth=1, alpha=alpha, color=palette[i],
                    marker=symbol[i], markersize=3, mec=palette[i], mew=1,
                    markevery=markevery, mfc=palette[i])

    # A: mean mantle temperature
    for i in range(total_runs):
        ax[0][0].plot(time_Myr_run[i], tmean_run[i], **_run_style(i, marker_frequency[i]))

    ax[0][0].set_xlim(0., 1500.)
    ax[0][0].set_ylim(2050., 2550.)
    ax[0][0].set_yscale('linear')
    ax[0][0].set_ylabel('Mean mantle temperature (K)')
    ax[0][0].text(1450., 2502, 'A', fontweight='bold', fontsize=12)

    # B: rms velocity at the surface
    for i in range(total_runs):
        ax[1][0].plot(time_Myr_run[i], velocity_rms_run[i], **_run_style(i, marker_frequency[i]))

    ax[1][0].set_xlim(0., 1500.)
    ax[1][0].set_ylim(1e-4,1e+3)
    ax[1][0].set_xscale('linear')
    ax[1][0].set_yscale('log')
    ax[1][0].set_ylabel('RMS velocity at surface (cm/yr)')
    ax[1][0].text(1450., 2e2, 'B', fontweight='bold', fontsize=12)

    # C: total (mafic + felsic) crustal volume
    for i in range(total_runs):
        ax[2][0].plot(time_Myr_run[i], (crust_bs[i]+crust_TTG[i]), **_run_style(i, marker_frequency[i]))

    ax[2][0].set_xlim(0., 1500.)
    ax[2][0].set_ylim(0., 2e10)
    ax[2][0].set_xscale('linear')
    ax[2][0].set_yscale('linear')
    ax[2][0].set_ylabel('Crustal volume (km$^3$)')
    ax[2][0].text(1450., 1.8e10, 'C', fontweight='bold', fontsize=12)

    # D: TTG mass produced (opaque) and remaining near the surface (translucent);
    # plotted in this order per run so the panel legend picks up one of each
    for i in range(total_runs):
        ax[3][0].plot(time_Myr_step_run[i], TotalTTGmass_run[i], **_run_style(i, marker_frequency_TTG))
        ax[3][0].plot(time_Myr_run[i], RemainingTTG_mass_run[i], **_run_style(i, marker_frequency[i], alpha=0.5))

    ax[3][0].set_xlim(0., 1500.)
    ax[3][0].set_ylim(0., 5.e22)
    ax[3][0].set_xscale('linear')
    ax[3][0].set_yscale('linear')
    ax[3][0].set_ylabel('TTG mass (kg)')
    ax[3][0].set_xlabel('Time (Myr)')
    ax[3][0].text(1450., 4.6e22, 'D', fontweight='bold', fontsize=12)

    labels_a1 = ('A1','A7','A13','B5','B8','B12','C1','C8')
    labels_a3 = ('produced','remaining in top 100 km')

    ax[0][0].legend(ncol = 8, labels = (labels_a1), loc='upper center', labelspacing = 0.5, facecolor = 'none', edgecolor = 'none', fancybox = False)
    ax[3][0].legend(ncol = 2, labels = (labels_a3), loc='upper center', labelspacing = 0.5, facecolor = 'none', edgecolor = 'none', fancybox = False)

    file_name='fig6.pdf'
    fig.tight_layout()

    # save this figure explicitly (not via pyplot's "current figure") and release it
    fig.savefig(os.path.join(here, file_name))
    plt.close(fig)
