import glob, shutil
import os
from datetime import datetime
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.ticker import LogFormatter, LogFormatterSciNotation
import numpy as np
import pandas as pd
from itertools import zip_longest, cycle
from scipy.optimize import curve_fit
import scipy
from driftrate import scilabel
import driftrate
import corner
import arrivaltimes

# Matplotlib defaults shared by the figures below
plt.rcParams.update({
	'font.size': 12,
	'toolbar': 'toolbar2',
})
# nu*dtdnu vs t

### Figure options
annotate = False                  # write the tail of each burst name next to its point
markBurst = None                  # burst name (or list of names) to mark with red crosshairs
usesymlog = False                 # symlog y-axis on the slope-law plot
xscale = 'log'                    # duration scale for dtdnu plot
plotinset = True                  # ultra frbs
plotinset2 = False                # long frbs
plot_drifts = True                # overlay drift-rate measurements on the slope-law plot
crop_view = True                  # apply fixed axis limits to the slope-law plot
sourcefilter = 'frb20121102A'     # e.g. 'frb20121102A' or 'frb20220912A' or None
use_dmoptimized_duration = False  # Use a spreadsheet of dm corrected measurements to overwrite durations
plot_additional_data = False      # overlay published measurements from other papers

savefig = False     # Save slope law figure
savehist = False    # Save burst properties histogram
savecorner = False  # Save corner plot
savedrift = False   # Save drift rate figure
###

# All per-dataset measurement CSVs
files = glob.glob('measurements/*/*/*.csv')
# [print(f) for f in sorted(files)]
# exit()

# DM-of-Shortest (pc/cm^3). Apply the DM of B30 of Snelders+2023 (FRB 20121102A) or
# B1 of Hewitt+2023 (FRB 20220912A), then apply eq. B4 of Jahns to find deltaDM.
DM_FRB20121102A = 560.105
DM_FRB20220912A = 219.356

# Hex palette; the cycle uses its tail from the fourth entry on.
tagColors = [
	'#D97E9F',
	'#67618C',
	'#F2AB27',
	'#F29472',
	'#73544D',
	'#0000FF',
	'#FF0000'
]
colors = cycle(tagColors[3:])

def resetmarkers():
	"""Return a new endless iterator over the scatter marker styles, starting at 'o'."""
	marker_styles = ('o', 'd', 's', '*', 'p', '^', 'v', '<', '>', 'X', 'P', 'h')
	return cycle(marker_styles)
markers = resetmarkers()

# Load every sub-burst measurement (indexed by burst name) before quality filtering
measdf = pd.read_csv('allmeasurements_prefilter.csv').set_index('name')
print(f"# of measurements pre-filter: {len(measdf)}")

## Copy dm optimized durations to measurements sheet
# Durations of bursts present in both sheets are overwritten; bursts missing from the
# dm-optimized sheet are then dropped entirely.
if use_dmoptimized_duration:
	dmoptdf = pd.read_csv('allmeasurements_postfilter_dmoptimized.csv').set_index('name')
	commonidx = measdf.index.intersection(dmoptdf.index)
	olddurs = measdf.loc[commonidx, 'duration (ms)']  # only used by the commented comparison plot
	newdurs = dmoptdf.loc[commonidx, 'duration (ms)']
	# plt.scatter(olddurs, newdurs, edgecolor='k')
	# plt.plot(range(8), range(8), 'k--')
	# plt.show()
	measdf.loc[commonidx, 'duration (ms)'] = dmoptdf.loc[commonidx, 'duration (ms)']
	measdf = measdf.drop(measdf.index.difference(dmoptdf.index))

##### Measurement Filters ######
measdf = measdf[measdf['duration (ms)'] < 100]
measdf = measdf[measdf['center_f (MHz)'] != 1]  # NOTE(review): 1 MHz looks like a placeholder/failed-fit value — confirm

# Compute any drift rates after minimum neccessary filters
drifts = arrivaltimes.measuredrifts(measdf, show_plot=False, verbose=False)
driftdf = pd.DataFrame(
	data=drifts,
	columns=arrivaltimes.drift_columns
).set_index('name')

# Frequency-normalized drift, analogous to nu*dt/dnu for the sub-burst slopes
driftdf['nu_drift'] = driftdf['center_f (MHz)']*driftdf['drift (ms/MHz)']
driftdf.to_csv('alldrifts.csv')
numdrifts = len(driftdf)
driftdf = driftdf[driftdf['drift (ms/MHz)'] != -1] # -1 indicates no fit

## ACF drift measurements
# Load drift rates measured from the autocorrelation function and align their column
# names with the arrival-time drift table.
driftacfdf = pd.read_csv('alldrifts_acf.csv').set_index('name')
driftacfdf['drift (ms/MHz)'] = driftacfdf['dtdnu_ACF (ms/MHz)']
# copy durations from arrival times measurements
print(
	"Bursts with ACF drifts but no arrival times drift: \n",
	driftacfdf.index.difference(driftdf.index)
)
commonidx = driftacfdf.index.intersection(driftdf.index)
print(f"{len(driftdf) = }, {len(driftacfdf) = }, {len(commonidx) = }")
# Duplicate burst names would make the label-aligned copies below ambiguous; show any.
print(driftacfdf[driftacfdf.index.duplicated()])
print(driftdf[driftdf.index.duplicated()])
driftacfdf.loc[commonidx, 'duration (ms)'] = driftdf.loc[commonidx, 'duration (ms)']
# Copy sources
driftacfdf.loc[commonidx, 'source'] = driftdf.loc[commonidx, 'source']
##

# Keep drifts whose relative uncertainty is below 100% (a zero drift gives inf and is dropped)
driftdf = driftdf[abs(driftdf['drift_err'])/abs(driftdf['drift (ms/MHz)']) < 1]
numdrifts_post = len(driftdf)
print(f"Calculated {numdrifts} drift rates. Dropped {numdrifts - numdrifts_post}.")

## ACF
# Same quality cuts for the ACF drifts: fitted, and every relative uncertainty below 100%
driftacfdf['drift_err'] = driftacfdf['dtdnu_ACF_err']
driftacfdf = driftacfdf[driftacfdf['drift (ms/MHz)'] != -1] # -1 indicates no fit
driftacfdf = driftacfdf[abs(driftacfdf['drift_err'])/abs(driftacfdf['drift (ms/MHz)']) < 1]
driftacfdf = driftacfdf[abs(driftacfdf['center_f_err'])/abs(driftacfdf['center_f (MHz)']) < 1] # stragglers
driftacfdf = driftacfdf[abs(driftacfdf['duration_err'])/abs(driftacfdf['duration (ms)']) < 1]
##

# Copy waterfalls with drift rate measurements to separate folder for review
driftreview_dir = 'measurements/driftrates'
# shutil.copy does not create missing folders, so make sure the destination exists
os.makedirs(driftreview_dir, exist_ok=True)
print(f"Copying {len(driftdf.index)} drift rate figures to measurements/driftrates/ for review...")
for bid in driftdf.index:
	# The index may hold duplicate burst names; use the DM of the first matching row
	bdm = driftdf[driftdf.index == bid]['DM'].iloc[0]
	# glob.escape makes glob treat the burst name literally (e.g. '[' or '*' in a name)
	bfiles = glob.glob(f'measurements/*/*/{glob.escape(str(bid))}_DM{bdm:.3f}*')
	if bfiles:
		# os.path.basename splits on the platform's separators (glob returns '\\' paths on Windows)
		shutil.copy(
			bfiles[0],
			os.path.join(driftreview_dir, os.path.basename(bfiles[0])),
		)

# Quality cuts on the sub-burst slope measurements. Each relative-uncertainty cut keeps
# rows whose error is below 100% of the value.
measdf = measdf[measdf['dtdnu (ms/MHz)'] != 0]
# measdf = measdf[measdf['center_f (MHz)'] > 3000] # Frequency filter
# measdf = measdf[measdf['dtdnu (ms/MHz)'] < 0]
measdf = measdf[abs(measdf['duration_err'])/abs(measdf['duration (ms)']) < 1]
measdf = measdf[abs(measdf['dtdnu_err'])/abs(measdf['dtdnu (ms/MHz)']) < 1] # this can maybe be relaxed with the S/N filter. 2? 3?
measdf = measdf[measdf['num_arrtimes'] > 2]  # need more than two arrival times for a slope
measdf = measdf[abs(measdf['center_f_err'])/abs(measdf['center_f (MHz)']) < 1] # stragglers
# print(measdf.loc[measdf['center_f_err'] > 2000][['center_f (MHz)', 'center_f_err']])

### Manual drops
# Currently empty, so this drop is a no-op; uncomment names to exclude them.
# Sheikh B11 comp i,j, and k
measdf = measdf.drop([
	# 'burst_B31_b', # Low snr, pixelated, few arrival times, large uncertainties
	## Manual DMs. Drop when calculating δDM
	# 'FRB121102_tracking-M01_0264_a',
	# 'FRB121102_tracking-M01_0264_b',
	# 'FRB121102_tracking-M01_0264_c',
])
################################

## Post filters dataset
measdf.to_csv('allmeasurements_postfilter.csv')
print(f"# of measurements post-filter: {len(measdf)}")

# print(measdf[measdf['dtdnu (ms/MHz)'] > 0])

# Fade out positive slopes, but don't remove them
# measdf.loc[measdf['dtdnu (ms/MHz)'] > 0, 'alpha'] = 0.1
# ultra-FRBs table for paper
def ultraFRBs_table(df):
	"""Build, print and return the LaTeX appendix table of the shortest sub-bursts.

	For FRB 20121102A, FRB 20220912A and FRB 20200120E, the rows with the
	smallest 'duration (ms)' are listed (17, 16 and 25 rows respectively).
	Each source block ends with an ellipsis row and a ``\\midrule``. Durations
	and their errors are converted from ms to microseconds.

	Parameters
	----------
	df : pandas.DataFrame
		Indexed by burst name, with columns 'source', 'DM', 'center_f (MHz)',
		'center_f_err', 'duration (ms)', 'duration_err', 'bandwidth (MHz)',
		'bandwidth_err', 'dtdnu (ms/MHz)' and 'dtdnu_err'.

	Returns
	-------
	str
		The LaTeX source (also printed), so callers can write it to a file.
	"""
	rows_per_source = [
		('frb20121102A', 17),
		('frb20220912A', 16),
		('frb20200120E', 25),
	]
	print()
	table = (
		## preamble
		'\\begin{table*}\n'
		'\\begin{centering}\n'
		'\\begin{tabular}{lllrrrr}\n'
		'\\toprule\n'
		## columns
		r'\textbf{Source} & '
		r'\textbf{Burst ID} & '
		r'\textbf{DM pc/cm$^3$} & '
		r'\textbf{$\nu_0$ MHz} & '
		r'\textbf{$\sigma_t$ $\upmu$s} & '
		r'\textbf{$\sigma_\nu$ MHz} & '
		r'\textbf{$\text{d}t/\text{d}\nu$ ms/MHz} \tabularnewline' '\n'
		'\\midrule\n'
		'\\midrule\n'
	)
	for source, nrows in rows_per_source:
		sdf = df.loc[df.source == source].sort_values(by='duration (ms)').head(nrows)
		for rownum, (name, row) in enumerate(sdf.iterrows()):
			# Only the first row of each source block names the source
			source_cell = 'FRB ' + source[3:] if rownum == 0 else ''
			table += (
				f"{source_cell} & "
				f"{name.replace('_', '-')} & "  # '_' is special in LaTeX
				f"{row['DM']} & "
				f"{row['center_f (MHz)']:.0f} $\\pm$ {row['center_f_err']:.0f} & "
				f"{row['duration (ms)']*1000:.2f} $\\pm$ {row['duration_err']*1000:.1f} & "
				f"{row['bandwidth (MHz)']:.0f} $\\pm$ {row['bandwidth_err']:.0f} & "
				f"{scilabel(row['dtdnu (ms/MHz)'], row['dtdnu_err'])} \\tabularnewline\n"
			)
		table += (
			' & \\textbf{...} & \\textbf{...} & \\textbf{...} & \\textbf{...} & \\textbf{...} & \\textbf{...} \\tabularnewline\n'
			'\\midrule\n'
		)
	table += (
		"\\bottomrule\n"
		"\\end{tabular}\n"
		"\\par\\end{centering}\n"
		"\\caption{}\\label{tab:appultra}\n"
		"\\end{table*}\n"
	)
	print(table)
	return table

# ultraFRBs_table(measdf)
# exit()

# print nice info
print(f"Total measurements: {len(measdf) = }")
# Include # of bursts in title
figtitle = f'{len(measdf)} bursts'

##### Plot Inverse Slope vs. duration
# Reciprocal of the precomputed 'nudtdnu' column
measdf['slope'] = 1/measdf['nudtdnu'] #measdf['center_f (MHz)']1/measdf['dtdnu (ms/MHz)']

# Main panel (ax) plus a residual panel (axres) sharing the duration axis
fig, [ax, axres] = plt.subplots(
	2, 1,
	figsize=(12,8.5),
	height_ratios=[5,1] if not plot_additional_data else [100,1], # cheesily hide the bottom panel
	sharex=True,
 	layout='constrained',
	# gridspec_kw={'wspace': 0.0001, 'hspace': 0.001, 'h_pad':0}
)

# fig, axs = plt.subplot_mosaic(
# 	'AC;BC',
# 	figsize=(12,8),
# 	height_ratios=[4,1],
# 	width_ratios=[20,1],
# 	sharex=True
# )
# ax, axres, axc = axs['A'], axs['B'], axs['C']

# Color by frequency. Marker by source
# Log-scaled colormap over 0.1-7.5; callers pass center_f (MHz)/1000, i.e. GHz.
freqcmap = cm.ScalarMappable(
	# norm=mpl.colors.Normalize(
	# 	vmin=100/1000,
	# 	vmax=7500/1000,
	# ),
	norm=mpl.colors.LogNorm(
		vmin=100/1000,
		vmax=7500/1000,
		# vmin=measdf['center_f (MHz)'].min(),
		# vmax=measdf['center_f (MHz)'].max()
	),
	# cmap=cm.cividis,
	cmap=cm.viridis
)
# Main panel: nu*dt/dnu vs. duration, one marker shape per source, colored by center
# frequency. Per source, only the first dataset's scatter gets a legend label
# (`labelled`). Only the first drift-rate scatter overall gets the "Drift rates" label
# (`dlabel`). The legend reordering further down depends on this artist order.
dlabel = False
for source in measdf.source.unique():
	sourcedf = measdf[(measdf.source == source)]
	marker = next(markers)
	labelled = False
	for dataset in sourcedf.dataset.unique():
		df = sourcedf[sourcedf.dataset == dataset]
		print(f"{source}\t{dataset}\t{len(df)}")
		ax.scatter(
			df['duration (ms)'],
			df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
			# color=df['c'],
			color=freqcmap.to_rgba(df['center_f (MHz)']/1000),
			marker=marker,
			# label=f"{source.upper()} {dataset.capitalize()}",
			label=f"{source.upper()}" if not labelled else '' ,
			alpha=df['alpha'],  # per-row alpha column from the measurements sheet
			# s=20,
			zorder=1,
			edgecolor=(0,0,0,0.1)
		)
		labelled = True
		# y error: nu*dtdnu propagated from dtdnu_err and center_f_err in quadrature
		ax.errorbar(
			df['duration (ms)'],
			df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
			xerr=df['duration_err'],
			yerr=np.sqrt((df['center_f (MHz)']*df['dtdnu_err'])**2 + (df['center_f_err']*df['dtdnu (ms/MHz)'])**2),
			# ecolor=df['c'],
			ecolor=freqcmap.to_rgba(df['center_f (MHz)']/1000),
			alpha=0.5,
			zorder=0,
			linestyle='none'
		)
		for name, row in df.iterrows():
			if annotate:
				ax.annotate(
					name[-15:],
					(
						row['duration (ms)'],
						row['center_f (MHz)']*row['dtdnu (ms/MHz)']
					),
					xytext=(5,5),
					textcoords='offset points'
				)

	if plot_drifts: # plot drift rates
		# Larger, white-edged markers drawn on top (zorder 10) of the sub-burst slopes
		sdriftdf = driftdf[(driftdf.source == source)]
		ax.scatter(
			sdriftdf['duration (ms)'],
			sdriftdf['center_f (MHz)']*sdriftdf['drift (ms/MHz)'],
			color=freqcmap.to_rgba(sdriftdf['center_f (MHz)']/1000),
			marker=marker,
			label=f"Drift rates $\\nu (\\Delta t / \\Delta \\nu)$ ({len(driftdf)} total)" if not dlabel else '' ,
			# alpha=sdriftdf['alpha'],
			s=100,
			zorder=10,
			edgecolor=(1,1,1,0.75)
		)
		dlabel = True
		ax.errorbar(
			sdriftdf['duration (ms)'],
			sdriftdf['center_f (MHz)']*sdriftdf['drift (ms/MHz)'],
			xerr=sdriftdf['duration_err'],
			yerr=np.sqrt((sdriftdf['center_f (MHz)']*sdriftdf['drift_err'])**2 + (sdriftdf['center_f_err']*sdriftdf['drift (ms/MHz)'])**2),
			# ecolor=sdriftdf['c'],
			ecolor=freqcmap.to_rgba(sdriftdf['center_f (MHz)']/1000),
			alpha=0.5,
			zorder=0,
			linestyle='none'
		)

# Duration Range
for source in measdf.source.unique():
	sdf = measdf[(measdf.source == source)]
	print(f"{source} duration range: {min(sdf['duration (ms)']):.3e} -- {max(sdf['duration (ms)']):.3f} ms")

# Frequency color mapping (colormap values are GHz)
cb = fig.colorbar(
	freqcmap,
	ax=ax,
	label="Frequency (GHz)",
	fraction=0.02,
	ticks=[1, 2, 3, 4, 5, 6, 7],
	format=LogFormatter(labelOnlyBase=False, minor_thresholds=(10,0.4))
)

# Burst locator: red crosshairs through each named burst's plotted point.
# markBurst may be a single burst name or an iterable of names.
if markBurst:
	if isinstance(markBurst, str):
		markBurst = [markBurst]
	for mbs in markBurst:
		ax.axvline(x=measdf.loc[mbs, 'duration (ms)'], c='r', ls='--')
		# Use the same y expression as the scatter so the crosshair passes through the marker
		ax.axhline(y=measdf.loc[mbs, 'center_f (MHz)']*measdf.loc[mbs, 'dtdnu (ms/MHz)'], c='r', ls='--')

##### Fits
# Duration grid (ms) used to draw model lines. The bounds are fixed rather than taken
# from the data; periodically compare them with the data's min/max duration so no
# points are cut off.
t = np.linspace(0.001, 360, num=10000)

# Horizontal zero reference
ax.plot(t, 0*t, 'k--', lw=1, alpha=0.5, zorder=-2)
# "0.4t (visual fit)" guide. Drawn with alpha=0, so it is invisible but still counts
# toward autoscaling; set alpha back to 0.5 and add a label to show it.
visual_fit = 0.4*t
ax.plot(t, visual_fit, 'k-.', lw=1, alpha=0)

if plot_additional_data:
	# Overlay published FRB measurements. Each point is plotted as nu/(dnu/dt) = nu*dt/dnu
	# in ms. Its y error is therefore nu*|sigma(dt/dnu)|: the frequency factor must scale
	# the error as well as the value.
	additional_resids = []
	def splitdata(tabdata):
		"""Unpack rows of [dur, dur_err, drift (MHz/ms), drift_err (MHz/ms)].

		Returns durations, duration errors, the drift rates dnu/dt as an array, and
		the propagated error on the inverse slope dt/dnu = 1/(dnu/dt), computed as
		-drift_err/drift**2. The sign is kept; callers take the absolute value.
		"""
		dur          = [g[0] for g in tabdata]
		dur_err      = [g[1] for g in tabdata]
		slope_dnudt  = np.array([g[2] for g in tabdata])
		slope_err    = np.array([g[3] for g in tabdata])
		slope_dtdnu_err = - (slope_err / (slope_dnudt**2))

		print(f"{dur = } {dur_err=}")
		print(f"{slope_dnudt = } {slope_dtdnu_err = }")
		return dur, dur_err, slope_dnudt, slope_dtdnu_err

	## Gopinath+2024 FRB 20180916B (plotted at 150 MHz)
	# Data from Table 1, Table 2, and Table B1
	gopinath2024 = [ # dur, dur_err, drift (MHz/ms), drift_err (MHz/ms)
		[4.60e+01, 1.00e+00, -5.00e-01, 2.00e-01],
		[3.84e+01, 9.00e-01, -4.00e-01, 2.70e+00],
		[9.70e+01, 1.00e+00, -3.00e-03, 6.00e-02],
		[6.60e+01, 4.00e+00, -2.00e-02, 2.00e-01],
		[4.70e+01, 1.00e+00, -3.00e-01, 2.00e-01],
		[1.50e+01, 6.00e+00, -2.00e-01, 1.00e+00],
		[30.5, 0.6, -0.7, 0.3],
		[47.7, 2.1, -0.37, 1.1],
		[90.7, 1.7, -0.02, 3.8],
		[6.9, 0.6, 0, 1.6],
		[37.5, 2.7, -3.66, 7.2],
		[26.8,0.6,-0.51,0.2],
		[26.0,0.4,-1.21,0.2],
		[40.7,1.3,-1.36,2],
		[26.9,1,-2.02,0.2],
		[20.3,0.8,-0.92,1.7],
		[19.8,1.2,-3.63,0.9],
		[24.4,1,-1.58,0.9],
		[25.1,0.2,-1.19,0.2]
	]

	dur, dur_err, slope, slopeinv_err = splitdata(gopinath2024)
	ax.errorbar(
		dur,
		150*(1/slope),
		xerr=dur_err,
		yerr=np.abs(150*slopeinv_err),
		label='Gopinath+2024 (FRB 20180916B)',
		linestyle='none',
		marker='o',
		markeredgecolor='k',
		c=freqcmap.to_rgba(150/1000)
	)
	additional_resids.append(150*(1/slope))

	## Pastor-Marazuela+2021 FRB 20180916B (per-burst center frequencies)
	pmdf = pd.read_csv('PM2021_arts_r3_properties.csv').set_index("paper_name")
	pmdf = pmdf.loc[pd.notna(pmdf.drift_rate), ['fcen', 'width_ms', 'drift_rate', 'drift_rate_err']]
	dur, slope_dnudt, slope_err = pmdf['width_ms'], pmdf['drift_rate'], pmdf['drift_rate_err']
	fcen = pmdf['fcen']
	slope_dtdnu_err = - (slope_err / (slope_dnudt**2))
	ax.errorbar(
		dur,
		fcen*(1/slope_dnudt),
		# xerr=dur_err,
		yerr=np.abs(fcen*slope_dtdnu_err),
		label='Pastor-Marazuela+2021 (FRB 20180916B)',
		linestyle='none',
		marker='d',
		markeredgecolor='k',
		# c='y'
		c=freqcmap.to_rgba(fcen.mean()/1000)
	)
	additional_resids.append(fcen*(1/slope_dnudt))

	# Hessels+2019 FRB 20121102A (plotted at 1500 MHz; no uncertainties tabulated)
	hessels2019 = [ # dur, dur_err, drift (MHz/ms), drift_err (MHz/ms)
		[1.03, 0, -204, 0],
		[0.19, 0, -122, 0],
		[0.25, 0, -187, 0],
		[0.30, 0, -221, 0],
		[0.34, 0, -46, 0],
		[0.31, 0, -129, 0],
		[0.24, 0, -128, 0],
		[0.43, 0, -140, 0],
		[0.20, 0, -205, 0],
		[0.23, 0, -50, 0],
		[0.14, 0,  0, 0],
		[0.35, 0, -168, 0],
		[0.17, 0, -286, 0],
		[0.13, 0, -237, 0],
		[0.16, 0, -251, 0],
		[0.30, 0, -141, 0],
		[0.40, 0, -276, 0],
		[0.13, 0, -865, 0],
	]
	dur, dur_err, slope, slopeinv_err = splitdata(hessels2019)
	ax.errorbar(
		dur,
		1500*(1/slope),
		xerr=dur_err,
		yerr=np.abs(1500*slopeinv_err),
		label='Hessels+2019 (FRB 20121102A)',
		linestyle='none',
		markeredgecolor='k',
		marker='s',
		c=freqcmap.to_rgba(1500/1000)
	)
	additional_resids.append(1500*(1/slope))

	# Zhou 2022 (per-burst center frequencies, MHz)
	zhou2022freqs = np.array([
		1153.8,
		1188.8,
		1151.3,
		1124.9,
		1127.5,
		1250.7,
		1057.3,
		1146.2,
		1280.2,
		1367.0,
		1249.7,
		1383.2,
		1342.4,
		1151.2,
		1413.9,
		1204.3,
		1165.0,
		1129.1,
		1260.1,
		1174.1,
		1192.6,
		1269.4,
		1160.2,
		1217.6,
		1135.2,
		1333.2,
		1215.1,
		1256.5,
		1351.3,
		1398.0,
		1298.2,
	])
	zhou2022 = [
		[10.9, 0, -58, 7],
		[11.4, 0, -64, 8],
		[10.3, 0, -95, 12],
		[11.0, 0, -59, 7],
		[12.8, 0, -54, 7],
		[11.6, 0, -77, 10],
		[18.6, 0, -55, 7],
		[57.4, 0, -58, 7],
		[10.7, 0, -68, 9],
		[4.3, 0, -166, 18],
		[20.9, 0, -36, 4],
		[21.7, 0, -58, 7],
		[10.4, 0, -78, 10],
		[15.7, 0, -41, 5],
		[9.4, 0, -62, 7],
		[12.5, 0, -104, 13],
		[13.3, 0, -52, 6],
		[7.8, 0, -93, 12],
		[9.9, 0, -62, 8],
		[10.9, 0, -61, 5],
		[15.7, 0, -66, 8],
		[6.8, 0, -128, 17],
		[18.4, 0, -28, 3],
		[7.3, 0, -104, 14],
		[10.6, 0, -54, 7],
		[80, 0, -74, 9],
		[13.2, 0, -50, 6],
		[9.1, 0, -86, 11],
		[8.4, 0, -74, 9],
		[16.2, 0, -42, 5],
		[9.1, 0, -90, 12]
	]
	dur, dur_err, slope, slopeinv_err = splitdata(zhou2022)
	ax.errorbar(
		dur,
		zhou2022freqs*(1/slope),
		xerr=dur_err,
		yerr=np.abs(zhou2022freqs*slopeinv_err),
		label='Zhou+2022 (FRB 20201124A)',
		linestyle='none',
		markeredgecolor='k',
		marker='p',
		c=freqcmap.to_rgba(zhou2022freqs.mean()/1000)
	)
	additional_resids.append(zhou2022freqs*(1/slope))
#

def fitdata(model, x, y, p0=None, **kwargs):
	"""Fit *model* to (x, y) with `curve_fit` after scaling each axis by its maximum.

	Both axes are divided by their maximum (np.max, not max |.|) before fitting, so
	the returned parameters are in those normalized units. An all-negative axis
	is therefore sign-flipped.

	Parameters
	----------
	model : callable
		``model(x, *params)``.
	x, y : array-like
		Data. Lists, arrays and pandas Series are accepted.
	p0 : sequence, optional
		Initial parameter guess. The default ``None`` lets curve_fit infer the
		number of parameters from *model*'s signature. (The old default ``[]``
		meant zero parameters and made calls without p0 fail.)
	**kwargs
		Forwarded to `scipy.optimize.curve_fit` (e.g. ``sigma``).

	Returns
	-------
	popt, pcov
		As returned by `curve_fit`, in normalized units.
	"""
	x = np.asarray(x, dtype=float)
	y = np.asarray(y, dtype=float)
	popt, pcov = curve_fit(
		model,
		x/np.max(x),
		y/np.max(y),
		p0=p0,
		**kwargs
	)
	return popt, pcov

def line(x, m, b):
	"""Evaluate the straight line with slope *m* and intercept *b* at *x*."""
	rise = m*x
	return rise + b

def fitsource(df, ax=None, silent=False):
	"""Fit the sub-burst slope law nu*dt/dnu = A*duration + B with orthogonal distance regression.

	Parameters
	----------
	df : pandas.DataFrame
		Needs columns 'duration (ms)', 'duration_err', 'center_f (MHz)',
		'center_f_err', 'dtdnu (ms/MHz)' and 'dtdnu_err'.
	ax : matplotlib Axes, optional
		If given, the fitted line is drawn over the module-level duration grid ``t``,
		labelled with the parameters and their errors.
	silent : bool
		Suppress the fit report printed by ODR.

	Returns
	-------
	scipy.odr.Output
		``beta`` is [A, B].
	"""
	# `import scipy` alone does not guarantee the odr submodule is loaded on every SciPy version
	import scipy.odr

	if not silent: print("Normalized inverse slope:")
	# y = nu*dtdnu; its error combines the dtdnu and center frequency errors in quadrature
	nudtdnu = df['center_f (MHz)']*df['dtdnu (ms/MHz)']
	nudtdnu_err = np.sqrt((df['center_f (MHz)']*df['dtdnu_err'])**2 + (df['center_f_err']*df['dtdnu (ms/MHz)'])**2)
	odrjob = scipy.odr.ODR(
		scipy.odr.RealData(
			df['duration (ms)'],
			nudtdnu,
			sx=df['duration_err'],
			sy=nudtdnu_err,
		),
		scipy.odr.Model(lambda B, x: B[0]*x + B[1]),
		beta0=[-1, 0]
	)
	odrjob.set_job(fit_type=0)  # explicit ODR (errors in both x and y)
	odrfit = odrjob.run()
	if not silent: odrfit.pprint()

	# sd_beta is scaled by the residual variance; for very small values ("overfit"),
	# fall back to the unscaled covariance. https://stackoverflow.com/a/21406281/3133399
	fit_redchisq = odrfit.res_var
	if fit_redchisq > 0.1: # overfit cutoff
		fiterr = odrfit.sd_beta
	else:
		fiterr = np.sqrt(np.diag(odrfit.cov_beta))

	if ax is not None:
		ax.plot(
			t,
			odrfit.beta[0]*t + odrfit.beta[1],
			'k-.',
			label=f'{scilabel(odrfit.beta[0], fiterr[0])} $\\sigma_t$ + {scilabel(odrfit.beta[1], fiterr[1])}',
			zorder=-1
		)
	return odrfit

# Subset used for the main fit (all sources unless sourcefilter is set)
r1df = measdf
if sourcefilter:
	r1df = measdf.loc[measdf.source == sourcefilter]
# r1df = measdf.loc[measdf.source == 'frb20121102A']
# r1df = measdf.loc[measdf.source == 'frb20220912A']
# r1df = r1df[(r1df['dtdnu (ms/MHz)'] < 0)]

## Main fit for plot
# normfit is used later (insets, residual panel, drift plot). Stop here with a clear
# message instead of a NameError further down when the filter leaves nothing to fit.
if len(r1df) == 0:
	raise ValueError(f"No measurements match {sourcefilter = }; cannot fit the sub-burst slope law.")
odrfit = fitsource(r1df, ax)
normfit = odrfit
normfit_redchisq = odrfit.res_var

## Source fits for paper table
print("Source fits:")
for s in measdf.source.unique():
	sdf = measdf.loc[measdf.source == s]
	odrfit = fitsource(sdf, silent=True)
	# Same error choice as fitsource: residual-scaled sd_beta unless res_var is tiny
	fit_redchisq = odrfit.res_var # https://stackoverflow.com/a/21406281/3133399
	if fit_redchisq > 0.1: # overfit cutoff
		fiterr = odrfit.sd_beta
	else:
		fiterr = np.sqrt(np.diag(odrfit.cov_beta))
	fitlbl = f'{scilabel(odrfit.beta[0], fiterr[0])}$\\sigma_t$ + {scilabel(odrfit.beta[1], fiterr[1])}'
	# print(f'{s}: {fitlbl}')
	# print(f'{s.upper()}\t{fitlbl}\t{fit_redchisq:.2f}\t{len(sdf)} | {odrfit.beta[0] = } {fiterr[0] = }')

	# Print table rows: source & slope & offset & res_var & N
	print(
		f'{s.upper()} & '
		f'{scilabel(odrfit.beta[0], fiterr[0])} & '
		f'{scilabel(odrfit.beta[1], fiterr[1])} & '
		f'{fit_redchisq:.2f} & '
		f'{len(sdf)} \\tabularnewline'
	)

#### Gaussian/autocorrelation fits (i.e. earlier fits)
# Kept for reference; only `frange` is still assigned (the plotting below is commented out).
# ax.plot(
# 	t,
# 	-10*t,
# 	'g-.',
# 	lw=1.5,
# 	label='-10t (gauss/autocorr fit result)',
# 	zorder=-2
# )
frange = [-1/0.058, -1/0.122]
# ax.fill_between(
# 	t,
# 	(frange[0])*t, (frange[1])*t,
# 	alpha=0.2,
# 	color=tagColors[3],
# 	edgecolor='k',
# 	# label=f'Range of fits: {frange[0]:.1f} $\leq\,A_1\,\leq$ {frange[1]:.1f}',
# 	label='Earlier fit range',
# 	zorder=-10
# )
####

ax.set_xscale(xscale)
if usesymlog: ax.set_yscale('symlog', linthresh=0.8)

# ax.set_xlabel('Duration $\\sigma_t$ (ms)')
ax.set_ylabel('Normalized Sub-burst Slope $\\nu (\\text{d}t/\\text{d}\\nu) $ (ms)')
h, l = ax.get_legend_handles_labels()
if plot_drifts:
	# Assumes the first labelled source is h[0], the "Drift rates" entry is h[1] and the
	# fit line is h[-1] (see the plotting loop and fitsource). Moves the drift label to
	# just before the fit line.
	ax.legend(
		ncols=2,
		## Move drift rates label to the end
		handles=[h[0]]+h[2:-1]+[h[1],h[-1]],
		labels=[l[0]]+l[2:-1]+[l[1],l[-1]],
		## Poster legend labels:
		# handles=[h[1],h[8]]+h[10:],
		# labels=['FRB 20121102A (7 Datasets)', 'FRB 20220912A (3 Datasets)']+l[10:]
	)
else:
	ax.legend(ncols=2)

ax.set_title(figtitle)

# Inset zoom on ultraFRBs
# Restart the marker cycle so each source gets the same marker as in the main panel
# (both loops walk measdf.source.unique() in the same order).
markers = resetmarkers()
if plotinset:
	maxdur = 0.3 # 300 μs
	axins = ax.inset_axes(
		# [0.001, -80, 0.1, 60],
		[0.05, 0.35, 0.5, 0.4], # axes fraction [x0, y0, width, height]
		# [0.05, 0.1, 0.5, 0.4], # burst dm
		xlim=(0.0011, maxdur),
		ylim=(-2.5,2.5),
		# xticklabels=[],
		# yticklabels=[]
	)
	numufrbs = len(measdf[(measdf['duration (ms)'] < maxdur)])  # count across all sources
	for source in measdf.source.unique():
		sourcedf = measdf[(measdf.source == source)]
		marker = next(markers)
		for dataset in sourcedf.dataset.unique():
			df = sourcedf[sourcedf.dataset == dataset]
			print(f"{source}\t{dataset}\t{len(df)}")
			axins.scatter(
				df['duration (ms)'],
				df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
				# color=df['c'],
				color=freqcmap.to_rgba(df['center_f (MHz)']/1000),
				marker=marker,
				label=f"{source.upper()} {dataset.capitalize()}",
				alpha=df['alpha'],
				# s=20,
				zorder=1,
				edgecolor=(0,0,0,0.1)
			)
			axins.errorbar(
				df['duration (ms)'],
				df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
				xerr=df['duration_err'],
				yerr=np.sqrt((df['center_f (MHz)']*df['dtdnu_err'])**2 + (df['center_f_err']*df['dtdnu (ms/MHz)'])**2),
				# ecolor=df['c'],
				ecolor=freqcmap.to_rgba(df['center_f (MHz)']/1000),
				alpha=0.5,
				zorder=0,
				linestyle='none'
			)
	# Main fit line and zero reference, redrawn inside the inset
	axins.plot(
		t,
		normfit.beta[0]*t + normfit.beta[1],
		'k-.',
		zorder=-2
	)
	axins.plot(t, 0*t, 'k--', lw=1, alpha=0.5, zorder=-2)
	axins.set_xscale('log')
	axins.set_title(f"{numufrbs} ultra-FRBs (μ-second long)")
	r, c = ax.indicate_inset_zoom(axins, edgecolor="black")
	c[1].set_visible(True) # plotting drifts turns off these lines for some reason
	c[2].set_visible(True)

# Inset zoom on main cluster
# Restart the marker cycle so sources keep their main-panel markers.
markers = resetmarkers()
if plotinset2:
	axins2 = ax.inset_axes(
		# [0.001, -80, 0.1, 60],
		[0.55, 0.1, 0.4, 0.3], # [x0, y0, width, height] in axes fraction
		xlim=(0.5, 9),
		ylim=(-50, 10),
		# xticklabels=[],
		# yticklabels=[]
	)
	for source in measdf.source.unique():
		sourcedf = measdf[(measdf.source == source)]
		marker = next(markers)
		for dataset in sourcedf.dataset.unique():
			df = sourcedf[sourcedf.dataset == dataset]
			print(f"{source}\t{dataset}\t{len(df)}")
			axins2.scatter(
				df['duration (ms)'],
				df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
				# color=df['c'],
				color=freqcmap.to_rgba(df['center_f (MHz)']/1000),
				marker=marker,
				label=f"{source.upper()} {dataset.capitalize()}",
				alpha=df['alpha'],
				# s=20,
				zorder=1,
				edgecolor=(0,0,0,0.1)
			)
			axins2.errorbar(
				df['duration (ms)'],
				df['center_f (MHz)']*df['dtdnu (ms/MHz)'],
				xerr=df['duration_err'],
				yerr=np.sqrt((df['center_f (MHz)']*df['dtdnu_err'])**2 + (df['center_f_err']*df['dtdnu (ms/MHz)'])**2),
				# ecolor=df['c'],
				ecolor=freqcmap.to_rgba(df['center_f (MHz)']/1000),
				alpha=0.5,
				zorder=0,
				linestyle='none'
			)
	# Main fit line and zero reference inside the inset
	axins2.plot(
		t,
		normfit.beta[0]*t + normfit.beta[1],
		'k-.',
		zorder=-2
	)
	axins2.plot(t, 0*t, 'k--', lw=1, alpha=0.5, zorder=-2)
	axins2.set_xscale('log')
	axins2.set_title("low frequency")
	ax.indicate_inset_zoom(axins2, edgecolor="black")

# Fixed view window for the main slope-law panel. Earlier hand-picked limits, kept for
# reference: x in (0.001377009069009, 10.014339667758495),
# y in (-181.55538710690752, 14.120496647008299).
if crop_view:
	ax.set_xlim(0.001, 20)    # alternatively (min(t), max(t))
	ax.set_ylim(-150, 20.3)

## residuals
# Residuals of every measurement (all sources, not only the fitted subset r1df) about
# the main fit `normfit`. Errors include only the measurement's y error.
resmod = normfit.beta[0]*measdf['duration (ms)'] + normfit.beta[1]
residuals = measdf['center_f (MHz)']*measdf['dtdnu (ms/MHz)'] - resmod
res_err = np.sqrt((measdf['center_f (MHz)']*measdf['dtdnu_err'])**2 + (measdf['center_f_err']*measdf['dtdnu (ms/MHz)'])**2)
freqcolors = freqcmap.to_rgba(measdf['center_f (MHz)']/1000)

# Sort all per-point sequences together by duration. key= limits the comparison to the
# duration: without it, ties fall through to the other tuple fields and, if those tie
# too, to the RGBA numpy rows, whose comparison raises "truth value ... is ambiguous".
# Python's sort is stable, so tied points keep their original order.
durs, dur_err, residuals, res_err, freqcolors = zip(*sorted(
	zip(
		measdf['duration (ms)'],
		measdf['duration_err'],
		residuals,
		res_err,
		freqcolors
	),
	key=lambda point: point[0]
))

# axres.plot(durs, residuals, '-', linewidth=1, zorder=-1, c='gray')
axres.scatter(
	durs,
	residuals,
	marker='.',
	color=freqcolors,
	alpha=0.9,
	s=50,
	edgecolor=(0,0,0,0.1)
)
axres.errorbar(
	durs,
	residuals,
	xerr=dur_err,
	yerr=res_err,
	zorder=-2,
	ecolor=freqcolors,
	alpha=0.5,
	linestyle='none'
)
axres.plot(t, 0*t, 'k--', lw=1, alpha=0.5, zorder=-2)
axres.set_xlabel('Duration $\\sigma_t$ (ms)')
axres.set_ylabel('Residuals (ms)')

# The residual panel is hidden when external data is overlaid; move the x label up
if plot_additional_data:
	fig.delaxes(axres)
	ax.set_xlabel("Duration $\\sigma_t$ (ms)")
##

# Fit the un-normalized slopes dt/dnu = dtb*duration + dtc on the main-fit subset.
# A non-zero intercept dtc is read as residual dispersion and converted to a DM offset.
print("Unnormalized inverse slope:")
odrmodel = scipy.odr.Model(lambda B, x: B[0]*x + B[1])
odrdata = scipy.odr.RealData(
	r1df['duration (ms)'],
	r1df['dtdnu (ms/MHz)'],
	sx=r1df['duration_err'],
	sy=r1df['dtdnu_err']
)
odrjob = scipy.odr.ODR(odrdata, odrmodel, beta0=[-1, 0])
odrjob.set_job(fit_type=0)
odrfit = odrjob.run()
odrfit.pprint()

adm = driftrate.a_dm
nu0bar = np.mean(r1df['center_f (MHz)'])
nu0bar_err = 0.1  # NOTE(review): hard-coded uncertainty on the mean frequency (MHz) — confirm
beta_err = np.sqrt(np.diag(odrfit.cov_beta))
dtb, dtb_err = odrfit.beta[0], beta_err[0]
dtc = odrfit.beta[1] # residual drift from fitting linear model with all bursts
dtc_err = beta_err[1]

δDM_global = - dtc * (nu0bar**3) / (2*adm)
# Exact quadrature form:
# δDM_err = (1/(2*adm)) * np.sqrt(9*(nu0bar**4)*(dtc**2)*(nu0bar_err**2) + (nu0bar**6)*(dtc_err**2))
# Linear (upper-bound) propagation. The terms are added as magnitudes; with a signed dtc,
# a negative intercept would cancel the dtc_err term and understate (or negate) the error.
δDM_err = abs(1/(2*adm)) * ((nu0bar**3)*dtc_err + 3*(nu0bar**2)*nu0bar_err*abs(dtc))
DM_applied = r1df['DM'].mean()

print(f"\nDM Optimization:\n{nu0bar = }")
print(f"{dtb = :.4e} +/- {dtb_err:.4e} ms/MHz")
print(f"{dtc = :.4e} +/- {dtc_err:.4e} ms/MHz")
print(f"{δDM_global = :.4f} +/- {δDM_err} pc/cm3")
print(f"{DM_applied = :.4f}")

DM_real = DM_applied + δDM_global
print(f"{DM_real = :.4f} pc/cm3")

# plt.tight_layout() # not needed if layout='constrained'
if savefig:
	plt.savefig("fig_nudtdnu.pdf")
	print("Saved fig_nudtdnu.pdf.")
plt.show() # sub burst slope law plot
plt.close()
# exit()

#########################
### Drift plot
#########################
# Normalized drift rates vs. duration, compared with the sub-burst slope fit.
plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(
	1, 1,
	figsize=(7,7)
)

# Burst names (driftdf index values) whose drift-rate measurement is flagged as
# truncated; these get an open ring in the drift plot.
truncated_drifts = [
	'5-04',
	'9-02',
	'9-03',
	'9-04',
	'9-05',
	'9-08',
	'9-09',
	'9-11',
	'9-13',
	'9-14',
	'9-15',
	'9-16',
	'57638.486502332715',
	'57638.490931637_multi',
	'57642.455112897405',
	'57642.46957790253_multi',
	'57644.43017666884',
	'57644.430179356954',
	'57644.43017989982',
	'57644.448765269415_multi',
	'57644.46476354708_multi',
	'57645.43063245034_multi',
	'57645.44999348996_multi',
	'57648.39469526673',
	'57666.40086931443',
	'20201225B',
	'B02',
	'B04',
	'B05',
	'B06',
	'B15',
	'B19',
	'B20',
	'B21',
	'B26',
	'B27',
	'B29',
	'B31',
	'B34',
	'B38',
	'B41',
	'B45',
	'B57',
	'B65',
	'B66',
	'B71',
	'B72',
	'B73',
	'B87',
	'B89',
	'B003.1',
	'B007.1',
	'B048.1',
	'B109',
	'B121.1',
	'burst_4',
	'burst_11',
	'M007',
	'M010',
	'M013',
	'FRB121102_tracking-M01_0163',
	'FRB121102_tracking-M01_0415',
	'FRB121102_tracking-M01_0487',
	'FRB180301_20191008-M01_1509_sub',
	'M01_0098a',
	'M01_0098b',
	'M01_1463',
	'M01_1481',
	'FRB20180917A',
	'FRB20181028A',
	'FRB20190605B',
]

# Equivalent list for the ACF drift rates (used via the "Data Switch" below)
truncated_acfdrifts = [
	'5-04', '9-02', '9-03', '9-04', '9-05', '9-08', '9-09', '9-11', '9-13', #'9-15', '9-16',
	'57644.448765269415_multi', '57644.46476354708_multi', '57648.39469526673', '57666.40086931443',
	'20201225B',
	'B02', 'B04', 'B05', 'B06', 'B15', 'B19', 'B20', 'B21', 'B26', 'B27', 'B29', 'B31', 'B34',
	'B38', 'B41', 'B45', 'B57', 'B66', 'B71', 'B72', 'B73', 'B87',
	'B003.1', 'B007.1', 'B048.1', 'B109', 'B121.1',
	'burst_4', 'burst_11',
	'M007', 'M010', 'M013',
	'FRB121102_tracking-M01_0163', 'FRB121102_tracking-M01_0415', 'FRB121102_tracking-M01_0487',
	'M01_0098a', 'M01_1463', 'M01_1481',
	'FRB20180917A', 'FRB20181028A', 'FRB20190605B'
]
# driftdf = driftdf.drop(truncated_drifts)

# Manual drops (currently none, so this is a no-op)
driftdf = driftdf.drop([
	# 'B071',
	# '57638.486502332715', # dubious, low snr
	# 20220114_B5, # low snr but okay
	# FRB20190605B, # low snr, lots of rfi, may be ok
])

## Data Switch
# driftdf = driftacfdf
# truncated_drifts = truncated_acfdrifts
##

print(f"{len(driftdf) = } {len(truncated_drifts) = }")
print(f"Drift info: {min(driftdf['drift (ms/MHz)'].abs())} - {max(driftdf['drift (ms/MHz)'].abs())} ms/MHz")
# Start the marker sequence fresh, independent of which insets consumed it earlier
markers = resetmarkers()
for source in driftdf.source.unique():
	marker = next(markers)
	sdriftdf = driftdf[(driftdf.source == source)]
	# One scatter per source, so each gets its own legend entry. (Previously the label
	# depended on `labelled`, left True by an earlier loop, so no source was labelled.)
	ax.scatter(
		sdriftdf['duration (ms)'],
		sdriftdf['center_f (MHz)']*sdriftdf['drift (ms/MHz)'],
		color=freqcmap.to_rgba(sdriftdf['center_f (MHz)']/1000),
		marker=marker,
		label=f"{source.upper()}",
		s=100,
		zorder=10,
		edgecolor=(1,1,1,0.75)
	)
	# y error: nu*drift propagated from drift_err and center_f_err in quadrature
	ax.errorbar(
		sdriftdf['duration (ms)'],
		sdriftdf['center_f (MHz)']*sdriftdf['drift (ms/MHz)'],
		xerr=sdriftdf['duration_err'],
		yerr=np.sqrt((sdriftdf['center_f (MHz)']*sdriftdf['drift_err'])**2 + (sdriftdf['center_f_err']*sdriftdf['drift (ms/MHz)'])**2),
		ecolor=freqcmap.to_rgba(sdriftdf['center_f (MHz)']/1000),
		alpha=0.5,
		zorder=1,
		linestyle='none'
	)

# Annotate truncated points with an open black ring.
# Select by membership: names that earlier filters removed from driftdf (or that are
# absent from this dataset) are skipped. `.loc[list]` would raise KeyError for them.
lbl = 'Truncated drift rates'
for _, row in driftdf[driftdf.index.isin(truncated_drifts)].iterrows():
	ax.scatter(
		row['duration (ms)'],
		row['center_f (MHz)']*row['drift (ms/MHz)'],
		facecolors='none',
		edgecolors='k',
		linewidth=0.5,
		label=lbl,
		zorder=0,
		s=300,
	)
	lbl = ''  # only the first ring gets a legend entry

# Duration grid (ms) for the model lines on this figure
t = np.linspace(
	0.001,#min(driftdf['duration (ms)']),
	150,#max(driftdf['duration (ms)']),
	num=10000
)

ax.plot(
	t,
	normfit.beta[0]*t + normfit.beta[1],
	'k-.',
	label='Sub-burst slope fit',
	zorder=-1
)
# Faint mirror of the fit with the slope sign flipped
ax.plot(
	t,
	-normfit.beta[0]*t + normfit.beta[1],
	'k-.',
	zorder=-1,
	alpha=0.2
)
ax.plot(t, 0*t, 'k--', lw=1, alpha=0.5, zorder=-2)
ax.legend(loc=3)
ax.set_xscale('log')
ax.set_yscale('symlog', linthresh=0.1)

ax.set_ylim(-20_000, 15_000)
ax.set_xlim(0.007, 150)

## ACF measurements limits
# ax.set_ylim(-20000, 100_000)
# ax.set_xlim(0.001, 150)
# ax.set_title("ACF Drift Rates")

ax.set_xlabel("Duration (ms)")
ax.set_ylabel("Norm. Drift Rate $\\nu(\\Delta t/\\Delta \\nu)$ (ms)")
plt.tight_layout()

if savedrift:
	plt.savefig("fig_drift.pdf")
	print("Saved fig_drift.pdf")
plt.show() # Drift plot
plt.close()
# exit()

#########################
### Histogram plot
#########################

# 3 histograms of center frequency, duration, and bandwidth, stacked by source
plt.style.use('tableau-colorblind10')
# plt.style.use('seaborn-v0_8-colorblind')
fig2, axs = plt.subplot_mosaic(
	'ABC',
	figsize=(12,4.5)
)

# Stacking order of the sources (source names must match measdf.source exactly)
order = [
	'frb20220912A', 'frb20121102A', 'frb20200120E', 'frb20180916b',
	'frb20201124a', 'frb20180301a', 'frb20180814a', 'frb20190804e',
	'frb20190915d', 'frb20200223b', 'frb20200929c', 'frb20201130a',
]
# Sources absent from `order` are not histogrammed; report them rather than drop them silently
unlisted_sources = [s for s in measdf.source.unique() if s not in order]
if unlisted_sources:
	print(f"Sources missing from histogram order (not plotted): {unlisted_sources}")

fbins = []
durbins = []
μdurbins = []  # sub-ms durations, in microseconds, for the inset
bandbins = []
for source in order:
	sourcedf = measdf[(measdf.source == source)]
	fbins.append(sourcedf['center_f (MHz)'])
	durbins.append(sourcedf['duration (ms)'])
	μdurbins.append(sourcedf[sourcedf['duration (ms)'] < 1]['duration (ms)']*1000)
	bandbins.append(sourcedf['bandwidth (MHz)'])

nbins = (50, 50, 50)  # (frequency, duration, bandwidth)
labels = [s.upper() for s in order]

# cmap = mpl.colormaps['cividis_r']
# colors = cmap(np.linspace(0, 1, 12))
axs['A'].hist(
	fbins,
	nbins[0],
	histtype='stepfilled',
	stacked=True,
	label=labels,
	edgecolor='k',
	linewidth=0.5,
	# color=colors
)
axs['A'].set_xlim(0, 7.89e3)

axs['B'].hist(
	durbins,
	nbins[1],
	histtype='stepfilled',
	stacked=True,
	label=labels,
	edgecolor='k',
	linewidth=0.5,
	# color=colors
)
axs['B'].set_xlim(0, 9) # mainly to set left edge

# Inset on duration histogram for microbursts
axins = axs['B'].inset_axes(
	[0.43, 0.5, 0.5, 0.4], # axes fraction [x0, y0, width, height]
	xlim=(0, 1000),
	# ylim=(0,40),
)

μnbins = 25
axins.hist(
	μdurbins,
	μnbins,
	histtype='stepfilled',
	stacked=True,
	label=labels,
	edgecolor='k',
	linewidth=0.5,
	# color=colors
)
axins.set_title("ultra-FRBs", fontsize=12)
axins.set_xlabel("Duration (μs)")

axs['C'].hist(
	bandbins,
	nbins[2],
	histtype='stepfilled',
	stacked=True,
	label=labels,
	edgecolor='k',
	linewidth=0.5,
	# color=colors
)
axs['C'].set_xlim(0, 460) # mainly to set left edge


axs['A'].set_xlabel("Center Frequency (MHz)")
axs['A'].set_ylabel('Number of sub-bursts')
axs['B'].set_xlabel('Sub-burst Duration (ms)')
axs['C'].set_xlabel('Sub-burst Bandwidth (MHz)')

# Shared log-scaled y axis; only the leftmost panel shows tick labels
axs['B'].sharey(axs['A'])
axs['C'].sharey(axs['A'])

axs['B'].yaxis.set_tick_params(labelleft=False)
axs['C'].yaxis.set_tick_params(labelleft=False)
for p in axs.values():
	p.set_title('')
	p.tick_params(labelsize=12)
	p.set_yscale("log")

plt.legend(
	loc='best',
	fontsize=10,
	# bbox_to_anchor=(0.5, -0.05),
	# ncols=6,
	reverse=True,
	# draggable=True
)
plt.tight_layout()
plt.subplots_adjust(wspace=0)
if savehist:
	plt.savefig("fig_hist.pdf")
	print("Saved fig_hist.pdf")
# plt.show() # histo
plt.close()
# exit()

#########################
### Corner plot:
#########################

plt.rcParams.update({'font.size': 13})

# Measurement columns shown in the corner plot, in panel order.
cornercols = [
	'bandwidth (MHz)',
	'center_f (MHz)',
	'duration (ms)',
	'dtdnu (ms/MHz)',
	# 'nudtdnu',
	# 'tb (ms)'
]

### Measurements to include

# R1 only
# sdf = measdf.loc[measdf.source == 'frb20121102A']
# sdf = measdf.loc[measdf.source == 'frb20220912A']

#  ultra-FRBs
# ultradf = measdf.loc[measdf['duration (ms)'] < 0.3].loc[measdf.source != 'frb20121102A']
ultradf = None

# Everything
fluxdf = pd.read_csv('fluxdensities.csv')
# Treat zero flux densities as missing: they cannot sit on the LogNorm colour
# scale built below from this column's min/max. Assign the result back rather
# than calling `.replace(..., inplace=True)` on the selected column. That
# chained in-place call does not modify `fluxdf` under pandas Copy-on-Write.
fluxdf['flux density (Jy)'] = fluxdf['flux density (Jy)'].replace(0, np.nan)
sdf = measdf # NOTE: an alias, not a copy -- the column added below also lands in measdf
if len(sdf) == len(fluxdf):
	# Rows are matched by position; assumes fluxdensities.csv is in the same
	# row order as measdf -- TODO confirm
	sdf['flux density (Jy)'] = list(fluxdf['flux density (Jy)'])
else:
	print("Warning: Flux densities do not line up with measurements")
print(f"{len(measdf) = }")

fig = corner.corner(
	sdf[cornercols].to_numpy(),
	labels=[
		r"$\sigma_\nu$ (MHz)",
		r"$\nu_0$ (MHz)",
		r"$\sigma_t$ (ms)",
		r"$dt/d\nu$ (ms/MHz)",
	],
	axes_scale=[
		'log',
		'log',
		'log',
		'linear'
	],
	# quantiles=[0.16, 0.5, 0.84],
	plot_contours=False,
	title_kwargs={"fontsize": 12},
)
markers = resetmarkers()
axs = fig.axes
# fig.set_size_inches(10.5,9.7) # corner.corner makes 9.7 x 9.7

# Colour map keyed on duration (log scale, capped at 20 ms).
durcmap = cm.ScalarMappable(
	# norm=mpl.colors.Normalize(
	# 	vmin=sdf['duration (ms)'].min(),
	# 	vmax=sdf['duration (ms)'].max()
	# ),
	norm=mpl.colors.LogNorm(
		vmin=sdf['duration (ms)'].min(),
		vmax=20#sdf['duration (ms)'].max()
	),
	cmap=cm.viridis
)

# Colour map keyed on flux density (log scale over the non-NaN range).
fluxcmap = cm.ScalarMappable(
	norm=mpl.colors.LogNorm(
		vmin=fluxdf['flux density (Jy)'].min(),
		vmax=fluxdf['flux density (Jy)'].max(),
	),
	cmap=cm.viridis
)
print(
	f"{fluxdf['flux density (Jy)'].min() = }",
	' -- '
	f"{fluxdf['flux density (Jy)'].max() = }"
)
# fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.35, 0.02, 0.6]) # (left, bottom, width, height)
# fig.colorbar(durcmap, cax=cbar_ax)
cb = fig.colorbar(
	# durcmap,
	fluxcmap,
	cax=cbar_ax,
	label="Flux Density (Jy)",
)

s = 30
sdfbackup = sdf.copy()
edgecolor, lw = 'k', 0.5
# Zoomed inset of the nu_0 vs dt/dnu panel.
axins = axs[13].inset_axes(
	[0.35, 0.175, 0.6, 0.6], # axes fraction of axs[13]: [x0, y0, width, height]
	xlim=(420, 8600), # labels come on and off for some reason at lower values (2866)
	ylim=(-0.07,0.011),
	xticklabels=[],
	yticklabels=[]
)

# Off-diagonal panels to overlay, as (axes, x index, y index into `cornercols`,
# carries-legend-label). The flat axes indices follow corner.corner's 4x4 grid
# laid out row-major (index = 4*row + col).
cornerpanels = [
	(axs[4], 0, 1, False),
	(axs[8], 0, 2, True), # Hijacking these axes for the main legend
	(axs[9], 1, 2, False),
	(axs[12], 0, 3, False),
	(axs[13], 1, 3, False),
	(axins, 1, 3, False),
	(axs[14], 2, 3, False),
]

for source in sdfbackup.source.unique():
	marker = next(markers)
	sdf = sdfbackup.loc[sdfbackup.source == source]
	# cornercolors = durcmap.to_rgba(sdf['duration (ms)'])
	cornercolors = fluxcmap.to_rgba(sdf['flux density (Jy)'])
	for panel_ax, xi, yi, haslabel in cornerpanels:
		scatterkw = dict(
			marker=marker,
			s=s,
			color=cornercolors,
			edgecolor=edgecolor,
			linewidth=lw,
		)
		if haslabel:
			scatterkw['label'] = f"{source.upper()}"
		panel_ax.scatter(sdf[cornercols[xi]], sdf[cornercols[yi]], **scatterkw)

# Add some space between points and axes. These are set once, after all sources
# are drawn. Setting a limit also turns autoscaling off, so the final view is
# the same as re-setting the limits on every loop iteration.
for panel_ax in (axs[4], axs[8], axs[12]):
	panel_ax.set_xlim(3.6, 565) # Bandwidth
for panel_ax in (axs[8], axs[9]):
	panel_ax.set_ylim(0.001, 11.3) # Duration
for panel_ax in (axs[12], axs[13], axs[14]):
	panel_ax.set_ylim(-0.53, 0.02) # dt/dnu
for panel_ax in (axs[9], axs[13]):
	panel_ax.set_xlim(120, 9000) # Frequency
axs[4].set_ylim(120, 9000)
axs[14].set_xlim(0.0012,10.3)
axs[13].indicate_inset_zoom(axins, edgecolor="black")

freqs = np.linspace(100, 9000, num=500)

### duration vs frequency fits
cutoff = 0.3 # ms (as used above)
ultrafitdf = sdfbackup.loc[sdfbackup['duration (ms)'] < cutoff]
longfitdf  = sdfbackup.loc[sdfbackup['duration (ms)'] >= cutoff]


def _explicit_odr(xvals, yvals, xerrs, yerrs, fcn, start):
	"""Fit ``fcn(B, x)`` to (xvals, yvals) by explicit ODR and return the Output.

	``xerrs``/``yerrs`` are passed to ``scipy.odr.RealData`` as ``sx``/``sy``;
	``fit_type=0`` selects explicit orthogonal distance regression, and
	``start`` is the initial parameter guess (``beta0``).
	"""
	data = scipy.odr.RealData(xvals, yvals, sx=xerrs, sy=yerrs)
	job = scipy.odr.ODR(data, scipy.odr.Model(fcn), beta0=start)
	job.set_job(fit_type=0)
	return job.run()


# sigma_t = B0 / nu_0 for the short (ultra) population
print("Ultra FRBs σ_t vs nu_0 fit:")
odrfit = _explicit_odr(
	ultrafitdf['duration (ms)'],
	ultrafitdf['center_f (MHz)'],
	ultrafitdf['duration_err'],
	ultrafitdf['center_f_err'],
	lambda B, x: B[0]*x**-1,
	[25],
)
odrfit.pprint()
# NOTE(review): this is sqrt(diag(cov_beta)), while nubnu_err below uses sd_beta.
# In scipy.odr these differ by a residual-variance factor -- confirm which is intended.
fiterr = np.sqrt(np.diag(odrfit.cov_beta))
# print(f"{fiterr = }")

# Same functional form for the long population
print("Long FRBs σ_t vs nu_0 fit:")
odrfit2 = _explicit_odr(
	longfitdf['duration (ms)'],
	longfitdf['center_f (MHz)'],
	longfitdf['duration_err'],
	longfitdf['center_f_err'],
	lambda B, x: B[0]*x**-1,
	[1000],
)
odrfit2.pprint()
fiterr2 = np.sqrt(np.diag(odrfit2.cov_beta))
# print(f"{fiterr2 = }")
# exit()

# Linear relation between bandwidth and center frequency over all measurements
nubnu_fit = _explicit_odr(
	measdf['bandwidth (MHz)'],
	measdf['center_f (MHz)'],
	measdf['bandwidth_err'],
	measdf['center_f_err'],
	lambda B, x: B[0]*x,
	[25],
)
print(f"{nubnu_fit.res_var = } {nubnu_fit.sd_beta = }")
nubnu_fit.pprint()
nubnu_err = nubnu_fit.sd_beta#np.sqrt(np.diag(nubnu_fit.cov_beta))
# exit()
###

# axs[9].plot(
# 	freqs,
# 	1474/freqs, # Chamma 2023
# 	# alpha=0
# 	'-.',
# 	label=r"$C_1\nu_0^{-1}$"
# )

# axs[9].plot(
# 	freqs,
# 	0.75e6*freqs**-2, # Picked by eye I think
# 	'-.',
# 	label=r"$C_2\nu_0^{-2}$"
# )

# axs[9].plot(
# 	freqs,
# 	0.1e6*freqs**-2, # Picked by eye I think
# 	'-.',
# 	label=r"$C_2\nu_0^{-2}$"
# )

# Fits
bands = np.linspace(1,565, num=500)
# nu_0 vs sigma_nu: the linear ODR fit, with its uncertainty in parentheses
axs[4].plot(
	bands,
	(nubnu_fit.beta[0])*bands,
	'-.',
	label=rf'${nubnu_fit.beta[0]:.0f}({nubnu_err[0]:.0f}) \sigma_\nu$'

)
# Earlier slope for comparison: 1/0.14 (Chamma+2023) times sqrt(8 ln 2).
# The factor is presumably a FWHM-to-sigma conversion -- confirm.
prev_nubnu = np.sqrt(8*np.log(2))/0.14
axs[4].plot(
	bands,
	prev_nubnu*bands, # (1/0.14 is Chamma+2023 result)
	'-.',
	label=rf'${prev_nubnu:.1f} \sigma_\nu$',
	alpha=0.3

)
# A single legend after both lines are drawn. An earlier call here was
# redundant, since each Axes.legend call replaces the previous legend.
axs[4].legend(fontsize=10, frameon=False)

# sigma_t vs nu_0: the long- and ultra-population B0/nu_0 fits
axs[9].plot(
	freqs,
	odrfit2.beta[0]/freqs, # Chamma 2023
	# alpha=0
	'-.',
	# label=rf"${odrfit2.beta[0]:.0f}\nu_0^{{-1}}$"
	# label=rf'{scilabel(odrfit2.beta[0], fiterr2[0], prec=3)}$\nu_0^{{-1}}$'
	label=rf'${odrfit2.beta[0]:.0f}({fiterr2[0]:.0f}) \nu_0^{{-1}}$'

)
axs[9].plot(
	freqs,
	odrfit.beta[0]/freqs, # Chamma 2023
	# alpha=0
	'-.',
	# label=rf"${odrfit.beta[0]:.2f}\nu_0^{{-1}}$"
	# label=rf'{scilabel(odrfit.beta[0], fiterr[0], prec=2)}$\nu_0^{{-1}}$'
	# Value shown to 1 decimal, so the parenthetical error is in units of 0.1
	label=rf'${odrfit.beta[0]:.1f}({fiterr[0]*10:.0f}) \nu_0^{{-1}}$'
)

axs[9].legend(
	loc=0,
	fontsize=9,
	frameon=False,
	handlelength=1,
	handletextpad=0.4
)

# Invisible (alpha=0) line plotted against the sample index; the candidate
# curves are kept as comments. Presumably a placeholder / colour-cycle
# advance -- confirm.
sigmanus = np.linspace(0,400,num=400)
axs[12].plot(
	sigmanus,
	# -1500625/sigmanus**2,
	# -0.2e6/sigmanus**2,
	# -2570/sigmanus**1,
	alpha=0
)

# dt/dnu = -C / nu_0^2 reference curve on the panel and its inset
C = 1/6.1e-5 # fit result from chamma+2023
axs[13].plot(
	freqs,
	-C/(freqs**2),
	'-.',
	zorder=-2,
	label=r"$-C\nu_0^{-2}$"
)
axins.plot(
	freqs,
	-C/(freqs**2),
	'-.',
	zorder=-2,
)
axins.set_xscale('log')
axins.tick_params(labelbottom=False)
axs[13].legend(
	loc=4,
	frameon=False,
	fontsize=11,
	borderpad=-0.2
)

# Per-source legend (labels were attached to axs[8]'s scatters), anchored far
# outside that panel's own area.
axs[8].legend(
	ncols=2,
	bbox_to_anchor=(3.5,3)
)

# Trial power-law fits dt/dnu = B0 * sigma_nu**n for several exponents n.
# Each fit's amplitude B0 is appended to `betas`, in the same order as the exponents.
betas = []
for n in [-1/2, -1, -2, -3]:
	odrjob = scipy.odr.ODR(
		scipy.odr.RealData(
			measdf['bandwidth (MHz)'],
			measdf['dtdnu (ms/MHz)'],
			sx=measdf['bandwidth_err'],
			sy=measdf['dtdnu_err'],
		),
		# Bind the exponent now (default argument) rather than closing over the
		# loop variable, so the model can never see a later value of `n`.
		scipy.odr.Model(lambda B, x, n=n: B[0]*x**n),
		beta0=[-30]
	)
	odrjob.set_job(fit_type=0)
	slopebnu_fit = odrjob.run()
	slopebnu_fit.pprint()
	slopebnu_err = np.sqrt(np.diag(slopebnu_fit.cov_beta))
	betas.append(slopebnu_fit.beta[0])
# print(betas)

# Invisible (alpha=0) reference curve; alternative power laws kept as comments.
axs[12].plot(
	bands,
	# -0.2/bands**(1/2),
	# -3/bands,
	-30/bands**2,
	# -300/bands**3
	alpha=0
)

#### Ultra frbs ####
# Optional overlay of a separate ultra-FRB selection (`ultradf`, currently
# None), coloured by duration. Compare against None explicitly: a DataFrame has
# an ambiguous truth value, so `if ultradf:` would raise ValueError as soon as
# the selection is enabled.
if ultradf is not None:
	marker = 'd'
	ultracolors = durcmap.to_rgba(ultradf['duration (ms)'])
	# (axes, x index, y index into `cornercols`) for each off-diagonal panel
	for ultra_ax, xi, yi in [
		(axs[4], 0, 1),
		(axs[8], 0, 2),
		(axs[9], 1, 2),
		(axs[12], 0, 3),
		(axs[13], 1, 3),
		(axs[14], 2, 3),
	]:
		ultra_ax.scatter(
			ultradf[cornercols[xi]],
			ultradf[cornercols[yi]],
			marker=marker,
			s=s,
			color=ultracolors,
			edgecolor='k'
		)

# for ax in axs:
# 	print(ax.get_xlabel(), ax.get_ylabel())

if savecorner:
	plt.savefig("fig_corner.pdf")
	print("Saved fig_corner.pdf.")
plt.show() # corner
plt.close()
# exit()

#############################
### ultra FRBs vs FRBs plots
#############################

# Three panels sharing the frequency axis: nu*dt/dnu, its reciprocal, and the
# sub-burst duration, with every point coloured by duration via `durcmap`.
fig, axs = plt.subplots(
	3, 1,
	figsize=(7,9),
	sharex=True
)

print(f"{len(measdf) = }")
edgecolor, lw = 'k', 0.5

# Each panel restarts the marker cycle, so a given source gets the same marker
# in all three panels. Without the reset, markers would depend on how many the
# earlier figures consumed and on how many sources there are.
markers = resetmarkers()
for source in measdf.source.unique():
	marker = next(markers)
	sdf = measdf.loc[measdf.source == source]
	srccolors = durcmap.to_rgba(sdf['duration (ms)'])
	axs[0].scatter(
		sdf['center_f (MHz)'],
		sdf['nudtdnu'],
		marker=marker,
		color=srccolors,
		edgecolor=edgecolor,
		linewidth=lw,
		# label=f"{source.upper()}"# if not labelled else '' ,
	)
	axs[0].errorbar(
		sdf['center_f (MHz)'],
		sdf['nudtdnu'],
		xerr=sdf['center_f_err'],
		# Propagated error of nu*(dt/dnu) from the errors on nu and dt/dnu
		yerr=np.sqrt((sdf['center_f (MHz)']*sdf['dtdnu_err'])**2 + (sdf['center_f_err']*sdf['dtdnu (ms/MHz)'])**2),
		# ecolor=df['c'],
		ecolor=srccolors,
		alpha=0.5,
		zorder=0,
		linestyle='none'
	)
axs[0].set_yscale('symlog', linthresh=1e-3)
# axs[0].set_xlabel("Frequency (MHz)")
axs[0].set_ylabel('$\\nu (\\text{d}t/\\text{d}\\nu) $ (ms)')

markers = resetmarkers()
for source in measdf.source.unique():
	marker = next(markers)
	sdf = measdf.loc[measdf.source == source]
	axs[1].scatter(
		sdf['center_f (MHz)'],
		1/sdf['nudtdnu'],
		marker=marker,
		color=durcmap.to_rgba(sdf['duration (ms)']),
		edgecolor=edgecolor,
		linewidth=lw,
		# label=f"{source.upper()}"# if not labelled else '' ,
	)
	# axs[1].errorbar(
	# 	sdf['center_f (MHz)'],
	# 	1/sdf['nudtdnu'],
	# 	xerr=sdf['center_f_err'],
	# 	yerr=np.sqrt((sdf['center_f (MHz)']*sdf['dtdnu_err'])**2 + (sdf['center_f_err']*sdf['dtdnu (ms/MHz)'])**2),
	# 	# ecolor=df['c'],
	# 	ecolor=durcmap.to_rgba(sdf['duration (ms)']),
	# 	alpha=0.5,
	# 	zorder=0,
	# 	linestyle='none'
	# )
axs[1].set_yscale('symlog', linthresh=1e-3)
# axs[1].set_xlabel("Frequency (MHz)")
axs[1].set_ylabel('$\\nu^{-1} (\\text{d}\\nu/\\text{d}t) $ (ms$^{-1}$)')

markers = resetmarkers()
for source in measdf.source.unique():
	marker = next(markers)
	sdf = measdf.loc[measdf.source == source]
	srccolors = durcmap.to_rgba(sdf['duration (ms)'])
	axs[2].scatter(
		sdf['center_f (MHz)'],
		sdf['duration (ms)'],
		marker=marker,
		color=srccolors,
		edgecolor=edgecolor,
		linewidth=lw,
		# label=f"{source.upper()}"# if not labelled else '' ,
	)
	axs[2].errorbar(
		sdf['center_f (MHz)'],
		sdf['duration (ms)'],
		xerr=sdf['center_f_err'],
		yerr=sdf['duration_err'],
		# ecolor=df['c'],
		ecolor=srccolors,
		alpha=0.5,
		zorder=0,
		linestyle='none'
	)
axs[2].set_yscale('log')
axs[2].set_xscale('log')
axs[2].set_xlabel("Frequency (MHz)")
axs[2].set_ylabel('Duration (ms)')

# Fits
axs[2].set_prop_cycle(None) # resets the color cycle
axs[2].plot(
	freqs,
	odrfit2.beta[0]/freqs, # Chamma 2023
	'-.',
	label=rf'{scilabel(odrfit2.beta[0], fiterr2[0], prec=3)}$\nu_0^{{-1}}$'
)
axs[2].plot(
	freqs,
	odrfit.beta[0]/freqs, # Chamma 2023
	'-.',
	label=rf'{scilabel(odrfit.beta[0], fiterr[0], prec=2)}$\nu_0^{{-1}}$'

)
axs[2].legend(frameon=False)

# cbar_ax = fig.add_axes([0.85, 0.475, 0.175, 0.02]) # (left, bottom, width, height)
# fig.colorbar(durcmap, cax=cbar_ax)
cb = fig.colorbar(
	durcmap,
	# cax=cbar_ax,
	ax=axs[0],
	label="Duration (ms)",
	orientation='horizontal',
	location='top',
	fraction=0.05
)

# plt.subplots_adjust(hspace=0.01)
plt.tight_layout()
# plt.show() # Ultra frbs vs frbs plot
plt.close()

#############################
### nu_0 vs fractional bandwidth
#############################

# Observing bandwidth per dataset name (upper minus lower band edge, in MHz
# like the measurement columns it is divided into below).
obsbands = {
	"gajjar2018"           : 8000 - 4000,
	"michilli2018"         : 4900 - 4100,
	"oostrum2020"          : 1450 - 1250,
	"li2021"               : 1500 - 1000,
	"scholz2016"           : 2400 - 1600,
	"aggarwal2021"         : 1774 - 974,
	"snelders2023"         : 9300 - 3900,
	"20121102_Arecibo_1"   : 1730 - 1150,
	"hewitt2023"           : 1742 - 1230,
	"sheikh2024"           : 2334 - 900,
	"zhang2023"            : 1500 - 1000,
	"nimmo2023"            : 1600 - 1200,
	"201124_Effelsberg"    : 1520 - 1200,
	"20180301_FAST"        : 1500 - 1000,
	"20180814A_CHIME"      : 800 - 400,
	"20180916B_Effelsberg" : 6000 - 4000,
	"20180916B_LOFAR"      : 188 - 110,
	"20180916B_CHIME"      : 800 - 400,
	"20180916B_uGMRT"      : 750 - 550,
	"20190804E_CHIME"      : 800 - 400,
	"20190915D_CHIME"      : 800 - 400,
	"20200223B_CHIME"      : 800 - 400,
	"20200929C_CHIME"      : 800 - 400,
	"20201130A_CHIME"      : 800 - 400,
}

measdf['obsband'] = measdf.dataset.map(obsbands)
# Datasets with no entry above map to NaN and would silently drop out of the
# absolute-FBW panel; report them so the table can be extended.
unmapped_datasets = list(measdf.loc[measdf['obsband'].isna(), 'dataset'].unique())
if unmapped_datasets:
	print(f"Warning: no observing bandwidth for datasets {unmapped_datasets}")
fig, axs = plt.subplots(1, 2, figsize=(8.5,4.5))

xdata = 'center_f (MHz)'
xerr = 'center_f_err' # only used by the commented-out errorbar below
# Left: burst bandwidth relative to its own center frequency.
sc = axs[0].scatter(
	measdf[xdata],
	measdf['bandwidth (MHz)']/measdf['center_f (MHz)'],
	c=measdf['bandwidth (MHz)'],
)
# Both panels colour by the same bandwidth column, so one colorbar serves both.
fig.colorbar(sc, ax=axs[1], label='Bandwidth $\\sigma_\\nu$ (MHz)')
axs[0].set_xlabel("Center Frequency $\\nu_0$ (MHz)")
axs[0].set_ylabel("FBW of Burst (=$\\sigma_\\nu/\\nu_0$)")
# axs[0].set_title("Fractional Bandwidth (FBW)")

# Right: burst bandwidth relative to the observing bandwidth of its dataset.
axs[1].scatter(
	measdf[xdata],
	measdf['bandwidth (MHz)']/measdf['obsband'],
	c=measdf['bandwidth (MHz)'],
)
axs[1].set_xlabel("Center Frequency $\\nu_0$ (MHz)")
axs[1].set_ylabel("Absolute FBW of Burst (=$\\sigma_\\nu/\\Delta\\nu_\\text{obs}$)")
# axs[1].set_title("Absolute Fractional Bandwidth (FBW)")

# axs[0].errorbar(
# 	measdf[xdata],
# 	measdf['bandwidth (MHz)']/measdf['center_f (MHz)'],
# 	xerr=measdf[xerr],
# 	marker='',
# 	linestyle='',
# 	zorder=-1,
# 	linewidth=1
# )

plt.tight_layout()
# plt.savefig('fbw.pdf')
plt.show()




















