from .spinup.ddpg import ReplayBuffer
import numpy as np


class DemoBuffer(ReplayBuffer):
    """Replay buffer that can be pre-filled with demonstration transitions.

    Demo files hold one flat transition per row, laid out as::

        [obs (obs_dim) | act (act_dim) | rew (1) | next_obs (obs_dim) | done (last column)]
    """

    def __init__(self, obs_dim, act_dim, size):
        # Kept on the instance because load_from_file_with_shuffle needs them
        # to split each flat demo row into its fields.
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super().__init__(obs_dim, act_dim, size)

    def load_from_file_with_shuffle(self, file_name, load_size=None):
        """Store the first ``load_size`` demo rows of ``file_name`` in shuffled order.

        Args:
            file_name: Path passed to ``np.load``; it must contain a numeric
                array with one transition per row (see the class docstring).
            load_size: Number of leading rows to load. ``None`` loads every
                row. Values larger than the number of rows are clamped to it,
                so no partial load followed by an ``IndexError`` can happen.

        Raises:
            ValueError: If the rows are too short to hold a full transition.
        """
        # atleast_2d lets a file holding a single 1-D transition be loaded too.
        demo_data = np.atleast_2d(np.load(file_name))

        # Column offsets of each field inside a flat row (loop invariants).
        obs_end = self.obs_dim
        act_end = obs_end + self.act_dim
        rew_idx = act_end
        next_obs_end = rew_idx + 1 + self.obs_dim
        min_width = next_obs_end + 1  # +1 for the trailing done flag

        if demo_data.shape[1] < min_width:
            raise ValueError(
                f"demo rows have {demo_data.shape[1]} columns, expected at "
                f"least {min_width} for obs_dim={self.obs_dim}, "
                f"act_dim={self.act_dim}"
            )

        n_rows = demo_data.shape[0]
        n_load = n_rows if load_size is None else min(load_size, n_rows)

        # Shuffle the order of the leading n_load rows (same RNG usage as
        # shuffling an arange in place), then store them one by one.
        demo_index = np.arange(n_load)
        np.random.shuffle(demo_index)

        for row in demo_data[demo_index]:
            self.store(row[:obs_end],
                       row[obs_end:act_end],
                       row[rew_idx],  # scalar reward, consistent with done
                       row[rew_idx + 1:next_obs_end],
                       row[-1])
