import numpy as np
import matplotlib.pyplot as plt
import json
from tqdm import tqdm

# Mapping from trial number to subject identifier. Per the original note the
# file is organized as Dict[str, int]: the trial number (as a string) is the
# key and the subject identifier is the value.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's default encoding.
with open("trialsub.json", encoding="utf-8") as f:
    trialsub = json.load(f)


def newfig():
    """Create a new figure holding a 5x5 grid of subplots (25 axes).

    Returns the ``(figure, axes)`` pair produced by ``plt.subplots`` for a
    10x10 inch canvas rendered at 100 dpi.
    """
    figure, axes = plt.subplots(
        nrows=5,
        ncols=5,
        figsize=(10, 10),
        dpi=100,
    )
    return figure, axes


# Load the recordings. The loop below iterates over the first axis and plots
# each row as one trace (assumes a 2-D array of trials x samples - TODO confirm).
emg_data = np.load("dimep-dataset.npy")

# show here only the first 50 trials
trials_to_plot = emg_data[0:50, :]

fig, ax = newfig()
ax = ax.flatten()
aix = 0  # index of the next free subplot on the current figure
fix = 0  # how many additional figures have been started so far
# The progress-bar total must match the number of trials actually plotted,
# otherwise the bar never reaches 100% when the dataset holds more trials.
with tqdm(total=trials_to_plot.shape[0], desc="Plot trials") as pbar:
    for tix, trial in enumerate(trials_to_plot):
        if aix >= ax.size:
            # current figure is full: finalize its layout and start a new one
            fig.tight_layout()
            fig, ax = newfig()
            ax = ax.flatten()
            aix = 0
            fix += 1
        ax[aix].plot(trial, linewidth=0.5)
        # dotted red vertical marker at sample 500
        # NOTE(review): presumably the stimulus time - confirm
        ax[aix].plot([500, 500], [-1000, 1000], "r:")
        ax[aix].set_title(f"Trial:{tix} Subject:{trialsub[str(tix)]}")
        ax[aix].set_xticks([])
        ax[aix].set_yticks([])
        # Symmetric y-limit from the largest magnitude outside samples
        # 450-549 (the window around the marker is ignored when scaling),
        # rounded up to the next multiple of 10.
        lim = max(np.abs(trial[0:450]).max(), np.abs(trial[550:]).max())
        lim = np.ceil(0.1 * lim) * 10
        ax[aix].set_ylim((-lim, lim))
        ax[aix].set_xlim(300, 700)
        pbar.update(1)
        aix += 1

fig.tight_layout()
plt.show()
