import numpy as np
from .commons import nm2

def sample_positions_trivial(centers, ranges, num=1):
    """Draw ``num`` independent sets of positions, one position per key.

    For every key in ``centers`` a position is drawn with
    ``np.random.uniform`` between ``center - range`` and ``center + range``
    (element-wise), then cast to ``float32``.

    Args:
        centers: Mapping from a key to the centre of its sampling box
            (anything ``np.array`` accepts).
        ranges: Mapping with at least the same keys as ``centers``; each
            value is the half-width of the box and must broadcast against
            the matching centre.
        num: Number of position sets to draw.

    Returns:
        A list of ``num`` dicts; each maps every key of ``centers`` to the
        sampled ``float32`` array.
    """
    # The bounds do not depend on the sample index, so build them once.
    bounds = {}
    for key, center in centers.items():
        center = np.array(center)
        half_width = np.array(ranges[key])
        bounds[key] = (center - half_width, center + half_width)

    # Samples are drawn sample-by-sample, key-by-key (in ``centers`` order),
    # so the sequence of RNG calls matches a plain nested loop.
    return [
        {
            key: np.random.uniform(low, high).astype(np.float32)
            for key, (low, high) in bounds.items()
        }
        for _ in range(num)
    ]


def sample_positions_valid(centers, ranges, num=1, min_dist=0.3):
    """Draw ``num`` position sets whose goal is kept away from the start.

    The initial position is fixed at ``centers["initial"]``. A goal is drawn
    with ``np.random.uniform`` between ``centers["goal"] - ranges["goal"]``
    and ``centers["goal"] + ranges["goal"]``; it is kept only when
    ``nm2`` of the difference of the first two coordinates of goal and
    initial exceeds ``min_dist`` (``nm2`` comes from ``.commons`` and is
    presumably a 2-norm -- confirm there). For every accepted goal an
    obstacle is drawn uniformly around the midpoint of initial and goal,
    within ``+/- ranges["obstacle"]``.

    Note:
        This is rejection sampling with no attempt limit: if no goal inside
        the sampling box can satisfy the separation test, the loop never
        terminates.

    Args:
        centers: Mapping providing ``"initial"`` and ``"goal"`` centres.
        ranges: Mapping providing ``"goal"`` and ``"obstacle"`` half-widths.
        num: Number of accepted position sets to return.
        min_dist: Separation threshold that ``nm2`` of the first two
            coordinates of ``goal - initial`` must exceed. Defaults to
            ``0.3``.

    Returns:
        A list of ``num`` dicts with keys ``'initial'``, ``'goal'`` and
        ``'obstacle'``, each holding a ``float32`` array.
    """
    initial = np.array(centers["initial"])

    # Loop-invariant sampling bounds, computed once instead of per attempt.
    goal_center = np.array(centers["goal"])
    goal_half_width = np.array(ranges["goal"])
    goal_low = goal_center - goal_half_width
    goal_high = goal_center + goal_half_width
    obstacle_half_width = np.array(ranges["obstacle"])

    pos_list = []
    while len(pos_list) < num:
        goal = np.random.uniform(goal_low, goal_high)

        # Only the first two coordinates take part in the separation test.
        if nm2(goal[0:2] - initial[0:2]) > min_dist:
            midpoint = (initial + goal) / 2
            obstacle = np.random.uniform(midpoint - obstacle_half_width,
                                         midpoint + obstacle_half_width)
            pos_list.append({'initial': initial.astype(np.float32),
                             'goal': goal.astype(np.float32),
                             'obstacle': obstacle.astype(np.float32)})

    return pos_list