# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:hydrogen
#     text_representation:
#       extension: .py
#       format_name: hydrogen
#       format_version: '1.3'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Banner Charts and AWRF
#
# This notebook computes AWRF and produces the banner charts for the music sim paper.

# %%
from pathlib import Path
import re
import pickle
import pandas as pd
import numpy as np
import plotnine as pn
from tqdm.auto import tqdm

# %%
# Locations used throughout the notebook, relative to the working directory.
out_dir = Path("output")    # one sub-directory per simulation run
res_dir = Path("results")   # per-run metric CSV files
plot_dir = Path("plots")    # rendered banner charts are written here

# %%
# Create the plot directory on first use; an existing directory is fine.
plot_dir.mkdir(parents=False, exist_ok=True)


# %% [markdown]
# ## Presentation Functions
#
# Functions for parsing and re-encoding run information for consistent display.

# %%
def decode_run(name):
    """Split a run name into its ``(algo, rank, choice)`` components.

    Everything from the first underscore onward is discarded, and the
    remainder is split on ``+``.  When fewer than three components are
    present, ``'IA'`` is appended as the final component.

    Raises:
        ValueError: if the name does not resolve to exactly three components.
    """
    prefix = re.sub('_.*', '', name)
    fields = prefix.split('+')
    if len(fields) <= 2:
        fields += ['IA']
    # unpacking (rather than tuple(fields)) enforces exactly three components
    algo, rank, choice = fields
    return algo, rank, choice


# %%
def recode_for_display(df):
    """Return a copy of ``df`` with run labels recoded for presentation.

    The ``algo``, ``rerank`` and ``choice`` columns are converted to
    categoricals with display-friendly category names, ``rerank`` is put in a
    fixed display order, and ``algo`` is renamed to ``BaseAlgo``.  Category
    values that are not listed in a mapping are passed through unchanged.

    The input frame is NOT modified; all changes are made on the returned copy.

    Args:
        df: data frame with ``algo``, ``rerank`` and ``choice`` columns.

    Returns:
        A new data frame with the recoded columns.

    Raises:
        ValueError: if the recoded ``rerank`` categories are not exactly
            ``None``, ``MoveUp``, ``λ5``, ``λ7`` and ``FAIR`` (raised by
            ``reorder_categories``).
    """
    algo = df['algo'].astype('category').cat.rename_categories({
        'TorchSampledMF': 'BPR',
        'ials': 'IALS'
    })
    rerank = df['rerank'].astype('category').cat.rename_categories({
        'fair': 'FAIR',
        'movedRerank': 'MoveUp',
        'lambdaRerank0': 'None',
        'lambdaRerank7': 'λ7',
        'lambdaRerank5': 'λ5',
    })
    # fixed order so plot facets always appear in the same sequence
    rerank = rerank.cat.reorder_categories(['None', 'MoveUp', 'λ5', 'λ7', 'FAIR'])
    choice = df['choice'].astype('category').cat.rename_categories({
        'random': 'Rnd',
        'biased': 'Biased',
        'deterministic': 'Det'
    })
    # assign() builds a new frame, so the caller's data is left untouched
    recoded = df.assign(algo=algo, rerank=rerank, choice=choice)
    return recoded.rename(columns={
        'algo': 'BaseAlgo'
    })


# %% [markdown]
# ## Compute AWRF
#
# We need to load the data from all of our runs, and for each one compute AWRF.
#
# We will use log discounting for AWRF, for consistency with nDCG.  Because $\log_2 1 = 0$, we clamp the rank before taking the logarithm: the discount at 1-based rank $i$ is $1 / \log_2 \operatorname{max}\{i, 2\}$, also for consistency.

# %%
# Length of the recommendation list that is scored.
N = 10
# Logarithmic discount for 1-based ranks 1..N; the rank is clamped to at least
# 2 so that rank 1 does not divide by log2(1) == 0.
discount = 1.0 / np.log2(np.maximum(np.arange(1, N + 1), 2))

# %%
# Every sub-directory of the output directory is treated as one run.
paths = [entry for entry in out_dir.glob('*') if entry.is_dir()]
len(paths)

# %%
awrfs = {}
# Find each run's iteration directories up front so the progress-bar total
# matches what is actually on disk, instead of assuming a fixed number of
# iterations per run.  Sorting makes the processing order deterministic.
iter_dirs = {run: sorted(run.glob('[0-9]')) for run in paths}
# The context manager closes the progress bar even if a run fails to load.
with tqdm(total=sum(len(dirs) for dirs in iter_dirs.values())) as loop:
    for run, dirs in iter_dirs.items():
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # notebook at run directories you produced yourself.
        with open(run / 'items_gender.pkl', 'rb') as inf:
            genders = np.array(pickle.load(inf))
        # every item whose gender label is not exactly 'Male' counts as protected
        protected = genders != 'Male'

        for idir in dirs:
            loop.set_postfix_str(f'{run.stem}/{idir.stem}')
            # close the archive as soon as the array has been read
            with np.load(idir / 'predicted.npz') as npz:
                recs = npz['data']
            # one row per user; the entries are used below as indices into
            # `protected`.  Keep only the first N columns.
            recs = recs[:, :N]
            # logical: protected and unprotected recs
            r_prot = protected[recs]
            r_unprot = ~r_prot

            # compute exposure: discounted count of each group per user
            exp_prot = np.sum(r_prot * discount, axis=1)
            exp_unprot = np.sum(r_unprot * discount, axis=1)

            # save results, keyed by the decoded run labels plus the iteration
            df = pd.DataFrame({
                'exp_prot': exp_prot,
                'exp_unprot': exp_unprot,
            })
            df.index.name = 'user'
            algo, rank, choice = decode_run(run.stem)
            awrfs[(algo, rank, choice, idir.stem)] = df
            loop.update()

# %%
# Stack the per-run frames; the dictionary keys become the first four index
# levels, which are then turned into ordinary columns alongside `user`.
awrf_df = (
    pd.concat(awrfs, names=('algo', 'rerank', 'choice', 'iter'))
    .reset_index()
    .astype({'iter': 'i4'})
)
awrf_df

# %% [markdown]
# Now we normalize those values for computing AWRF:

# %%
# Total exposure per user list, used to turn the two group exposures into shares.
awrf_df['exp_tot'] = awrf_df['exp_prot'] + awrf_df['exp_unprot']
awrf = awrf_df[['algo', 'rerank', 'choice', 'iter', 'user']].assign(
    prot=awrf_df['exp_prot'] / awrf_df['exp_tot'],
    unprot=awrf_df['exp_unprot'] / awrf_df['exp_tot'],
    # signed distance of the protected share from an even 0.5 split
    dist=lambda shares: shares['prot'] - 0.5,
)
awrf.head()

# %%
awrf = recode_for_display(awrf)

# %%
# Average over users within each configuration and iteration.
# observed=True keeps only the category combinations that actually occur; the
# default for categorical keys differs between pandas versions and can
# otherwise add all-NaN rows for combinations that were never run.
awrf_means = (
    awrf
    .groupby(['BaseAlgo', 'rerank', 'choice', 'iter'], observed=True)[['prot', 'unprot', 'dist']]
    .mean()
    .reset_index()
)
awrf_means

# %% [markdown]
# ## AWRF Banner Image

# %%
# Mean distance from an even split per iteration (rows with iter > 0 only),
# one panel per choice model (rows) and reranker (columns).
plot = (
    pn.ggplot(
        awrf_means.loc[awrf_means['iter'] > 0],
        pn.aes(x='iter', y='dist', color='BaseAlgo'),
    )
    + pn.geom_line()
    # dashed reference line at a distance of zero
    + pn.geom_hline(yintercept=0, linetype='dashed')
    + pn.facet_grid('choice', 'rerank')
    + pn.xlab('Iteration')
    + pn.ylab('Distance (0 is fair)')
    + pn.scale_color_brewer(type="qual", palette="Dark2")
    + pn.theme_minimal()
    + pn.theme(
        axis_title=pn.element_text(size=8),
        legend_direction='horizontal',
        legend_position="top",
        legend_title=pn.element_blank(),
    )
)
plot.save(plot_dir / 'banner-awrf.pdf', width=9, height=4)
plot.show()

# %% [markdown]
# ## Integrated Gini Chart
#
# Now let's generate our Gini chart:

# %%
# Read every result CSV, keyed by the (algo, rerank, choice) decoded from its
# file name.  Files that decode to the same key replace one another.
res_dfs = {}
for res_f in res_dir.glob('*.csv'):
    run_key = decode_run(res_f.stem)
    res_dfs[run_key] = pd.read_csv(res_f)
# The keys supply the first three index levels; each file's own row index
# becomes the fourth level, labelled 'iter'.
results = pd.concat(res_dfs, names=['algo', 'rerank', 'choice', 'iter'])
results

# %%
assert results.index.is_unique
# Discard the innermost index level and turn the run labels into ordinary
# columns, leaving a plain integer index.
results = results.droplevel('iter').reset_index()
results

# %%
# Which rerankers are present in the results?
results['rerank'].unique()

# %% [markdown]
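# Why is there a `lambdaRerank10` run?  It is not one of the rerankers shown in the banner charts, so we drop it here, along with rows whose iteration is not positive.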

# %%
# Keep only rows that are not from lambdaRerank10 and have a positive iteration.
results = results.loc[
    (results['rerank'] != 'lambdaRerank10') & (results['iter'] > 0)
]

# %%
# Recode a copy so the filtered frame above is never written to.
results = recode_for_display(results.copy())

# %%
# Pull out the run labels plus the three Gini@10 columns and give the latter
# short names for the plot legend.
gini = (
    results
    .loc[:, ['BaseAlgo', 'rerank', 'choice', 'iter', 'GINI@10_all', 'GINI@10_female', 'GINI@10_male']]
    .rename(columns={
        'GINI@10_all': 'All',
        'GINI@10_female': 'Female',
        'GINI@10_male': 'Male',
    })
)
gini

# %%
# Long format: one row per (run, iteration, Gender) with its Gini value.
gini_tall = gini.melt(
    id_vars=['BaseAlgo', 'rerank', 'choice', 'iter'],
    var_name='Gender',
    value_name='Gini',
)

# %%
# Gini per iteration: colour distinguishes the base algorithm, line type the
# gender group; one panel per choice model (rows) and reranker (columns).
plot = (
    pn.ggplot(
        gini_tall,
        pn.aes(x='iter', y='Gini', color='BaseAlgo', linetype='Gender'),
    )
    + pn.geom_line()
    + pn.facet_grid('choice', 'rerank')
    + pn.xlab('Iteration')
    + pn.scale_color_brewer(type="qual", palette="Dark2")
    # a single theme_minimal() is enough: a later complete theme replaces an earlier one
    + pn.theme_minimal()
    + pn.theme(
        axis_title=pn.element_text(size=8),
        legend_direction='horizontal',
        legend_position="top",
        legend_title=pn.element_blank(),
    )
)
plot.save(plot_dir / 'banner-gini.pdf', width=9, height=4)
plot.show()

# %% [markdown]
# ### First Female

# %%
# The First_female column per iteration, one panel per choice model (rows)
# and reranker (columns).
plot = (
    pn.ggplot(
        results,
        pn.aes(x='iter', y='First_female', color='BaseAlgo'),
    )
    + pn.geom_line()
    + pn.facet_grid('choice', 'rerank')
    + pn.xlab('Iteration')
    + pn.ylab('Position of First Female Artist')
    + pn.scale_color_brewer(type="qual", palette="Dark2")
    # a single theme_minimal() is enough: a later complete theme replaces an earlier one
    + pn.theme_minimal()
    + pn.theme(
        axis_title=pn.element_text(size=8),
        legend_direction='horizontal',
        legend_position="top",
        legend_title=pn.element_blank(),
    )
)
plot.save(plot_dir / 'banner-first-female.pdf', width=9, height=4)
plot.show()
