import pandas as pd

from plotter import HandballPlot


class Dataset:
    """Tracking coordinates plus possession metadata, sliceable by possession.

    The methods below read these columns:

    * ``coordinates``: ``game``, ``period``, ``time``, ``player``, ``x``, ``y``
    * ``possessions``: ``game``, ``period``, ``possession_id``,
      ``time_start``, ``time_end``
    """

    def __init__(self, coordinates, possessions):
        self.coordinates = coordinates
        self.possessions = possessions

    def get_player_positions_for_possession(self, game, period, possession_id):
        """Return every player's x/y positions during one possession.

        Coordinates of the given game and period are kept when their ``time``
        lies in the closed interval ``[time_start, time_end]`` of the matching
        possession row. If several possession rows match, the first is used.

        Returns:
            A ``pandas.pivot_table`` indexed by ``time`` whose columns are a
            MultiIndex ``(value, player)`` with ``value`` in ``{"x", "y"}``.
            Duplicate (time, player) samples are averaged (pivot_table's
            default ``aggfunc``).

        Raises:
            IndexError: if no possession row matches ``game``/``period``/
                ``possession_id``.
        """
        coords = self.coordinates
        game_coord = coords[(coords["game"] == game) & (coords["period"] == period)]

        possessions = self.possessions
        poss = possessions[
            (possessions["game"] == game)
            & (possessions["period"] == period)
            & (possessions["possession_id"] == possession_id)
        ]
        if poss.empty:
            # Same exception type that `.iloc[0]` would raise, but with context.
            raise IndexError(
                f"No possession {possession_id} found for game {game}, period {period}"
            )

        time_start = poss["time_start"].iloc[0]
        time_end = poss["time_end"].iloc[0]

        play_coord = game_coord[
            (game_coord["time"] >= time_start) & (game_coord["time"] <= time_end)
        ]

        return pd.pivot_table(play_coord, values=["x", "y"], columns="player", index="time")

    @staticmethod
    def _player_trajectories(positions):
        """Yield ``(player, xs, ys)`` numpy arrays for each player in *positions*.

        *positions* is a table as returned by
        ``get_player_positions_for_possession``. Columns are addressed through
        the ``(value, player)`` MultiIndex tuple rather than by flattening it to
        strings, so non-string player IDs (e.g. integers) work.
        """
        for player in positions["x"].columns:
            yield (
                player,
                positions[("x", player)].to_numpy(),
                positions[("y", player)].to_numpy(),
            )

    def get_image_for_possession(self, game, period, possession_id):
        """Build a HandballPlot showing each player's trajectory for one possession.

        Raises:
            IndexError: propagated from ``get_player_positions_for_possession``
                when the possession does not exist.
        """
        positions = self.get_player_positions_for_possession(
            game=game, period=period, possession_id=possession_id
        )

        handball_plot = HandballPlot().handball_plot(
            title=f"Possession {possession_id} for Game {game}, period {period}"
        )
        for player, xs, ys in self._player_trajectories(positions):
            handball_plot.add_trajectories(x=xs, y=ys, label=player)

        handball_plot.add_legend()

        return handball_plot


def builder(positions_file_path, possessions_file_path):
    """Read the positions and possessions CSV files into a ``Dataset``.

    Args:
        positions_file_path: path or buffer for the coordinates CSV.
        possessions_file_path: path or buffer for the possessions CSV.

    Returns:
        A ``Dataset`` wrapping the two loaded DataFrames.
    """
    # Keyword arguments are evaluated left to right, so the positions file is
    # still read before the possessions file.
    return Dataset(
        coordinates=pd.read_csv(filepath_or_buffer=positions_file_path),
        possessions=pd.read_csv(filepath_or_buffer=possessions_file_path),
    )
