import itertools
import multiprocessing
import os
import sys
import tqdm
import random

from absl import flags
from sympy import *
import numpy as np


# Worker count for the multiprocessing pools in this module. The NUM_CPUS
# environment variable (parsed as int) overrides the machine's CPU count.
if "NUM_CPUS" in os.environ:
    NUM_CPUS = int(os.environ["NUM_CPUS"])
else:
    NUM_CPUS = os.cpu_count()

# Module-level alias for absl's global flag registry. init_args() feeds
# sys.argv into it; filter_ops_by_target() reads FLAGS.target from it.
FLAGS = flags.FLAGS


def init_args():
    """Parse the process command line into absl's global flag registry.

    Call this before reading flag values such as ``FLAGS.target``. The value
    returned by the parse (the leftover arguments) is discarded.
    """
    command_line = sys.argv
    flags.FLAGS(command_line)


def run(func_py, inputs):
    """Execute generated source that defines ``foo`` and return ``foo(*inputs)``.

    ``func_py`` is executed in a private *copy* of this module's globals. It can
    therefore use every name imported here (e.g. the ``from sympy import *``
    names and ``np``) without adding or overwriting anything in the real module
    namespace. Because each call gets a fresh namespace, a ``foo`` left over
    from an earlier call that raised can never be picked up by mistake.

    Args:
        func_py: Python source that must define a callable named ``foo``.
        inputs: positional arguments passed to ``foo``.

    Returns:
        Whatever ``foo(*inputs)`` returns.

    Raises:
        NameError: if ``func_py`` does not define ``foo``.
        Any exception raised while executing ``func_py`` or calling ``foo``
        propagates unchanged.

    NOTE(review): ``exec`` runs arbitrary code; only pass trusted, internally
    generated source here, never external input.
    """
    namespace = dict(globals())
    # Only a ``foo`` defined by func_py itself may be called.
    namespace.pop("foo", None)
    exec(func_py, namespace)
    try:
        foo = namespace["foo"]
    except KeyError:
        # Same exception type/message callers saw when foo was looked up globally.
        raise NameError("name 'foo' is not defined") from None
    return foo(*inputs)


def generate_random_inputs(arg_shapes, N):
    """Draw one uniform random input in [0, 1] per argument shape.

    Only the *rank* (``len``) of each shape is used; its dimension values are
    ignored.

    - A rank-0 shape yields a Python float from ``random.uniform(0, 1)``.
    - A rank-k shape (k >= 1) yields a float32 ndarray of shape ``(N,) * k``
      from ``np.random.rand``.

    Args:
        arg_shapes: iterable of shapes (anything supporting ``len``), one per
            argument.
        N: size used for every dimension of the generated arrays.

    Returns:
        A list with exactly one entry per element of ``arg_shapes``, in order.
    """
    random_inputs = []
    for shape in arg_shapes:
        rank = len(shape)
        if rank == 0:
            random_inputs.append(random.uniform(0, 1))
        else:
            # Any rank is supported so the result always lines up one-to-one
            # with arg_shapes; ranks 1-3 draw exactly as before.
            random_inputs.append(np.random.rand(*([N] * rank)).astype(np.float32))
    return random_inputs


def print_green(s):
    """Print ``str(s)`` wrapped in the ANSI bold/green-background escape, then reset."""
    highlighted = "".join(("\033[1;42m", str(s), "\033[0m"))
    print(highlighted)


def print_yellow(s):
    """Print ``str(s)`` wrapped in the ANSI bold/yellow-background escape, then reset."""
    highlighted = "".join(("\033[1;43m", str(s), "\033[0m"))
    print(highlighted)


def filter_ops_by_target(ops, target=None):
    """Select the ops that belong to a compilation target.

    Args:
        ops: iterable of op-name strings such as ``"stablehlo.add"`` or
            ``"jnp.sin"``.
        target: ``"hlo"`` or ``"numpy"``. Defaults to ``FLAGS.target``, so
            existing callers that rely on the command-line flag are unaffected.

    Returns:
        For ``"hlo"``: every ``chlo.``/``stablehlo.`` op in input order, with
        ``"stablehlo.dot"`` appended if it is not already among them.
        For ``"numpy"``: every ``jnp.`` op in input order.

    Raises:
        ValueError: if the target is not recognised. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if target is None:
        target = FLAGS.target

    if target == "hlo":
        selected = [op for op in ops if op.startswith(("chlo.", "stablehlo."))]
        # stablehlo.dot is always offered for HLO; appending it unconditionally
        # would list it twice whenever ops already contains it.
        if "stablehlo.dot" not in selected:
            selected.append("stablehlo.dot")
        return selected

    if target == "numpy":
        return [op for op in ops if op.startswith("jnp.")]

    raise ValueError(f"Unsupported target: {target!r}")


def smart_map(
    func,
    iterable,
    desc="",
    spawn=False,
    chunksize=1,
    parallel_threshold=1000,
    single_threaded=False,
):
    """Apply ``func`` to every item, either in-process or across a process pool.

    Inputs with fewer than ``parallel_threshold`` items, or calls with
    ``single_threaded=True``, run serially in this process. Anything larger is
    distributed over a pool of ``NUM_CPUS`` workers using the "spawn" or "fork"
    multiprocessing start method. A tqdm progress bar is shown in both modes.

    Args:
        func: callable applied to each item. It must be picklable in parallel
            mode.
        iterable: sized collection of items (``len`` is always taken).
        desc: label for the progress bar.
        spawn: use the "spawn" start method instead of "fork".
        chunksize: items sent to a worker per task in parallel mode.
        parallel_threshold: minimum item count before a pool is used.
        single_threaded: force the serial path regardless of size.

    Returns:
        A list of ``func(item)`` results in the same order as ``iterable``.
    """
    n_items = len(iterable)
    if n_items < parallel_threshold or single_threaded:
        return [func(item) for item in tqdm.tqdm(iterable, desc=desc)]

    start_method = "spawn" if spawn else "fork"
    ctx = multiprocessing.get_context(start_method)
    with ctx.Pool(NUM_CPUS) as pool:
        # imap yields in input order; tqdm advances as each result arrives.
        progress = tqdm.tqdm(
            pool.imap(func, iterable, chunksize=chunksize),
            total=n_items,
            desc=desc,
        )
        return list(progress)


def _apply_star(func, args):
    """Return ``func(*args)``; module-level so pool workers can unpickle it."""
    return func(*args)


def smart_starmap(
    func,
    iterable,
    length,
    desc="",
    spawn=False,
    chunksize=1,
    parallel_threshold=1000,
    single_threaded=False,
):
    """Star-apply ``func`` to each argument tuple and concatenate the results.

    Each ``func(*args)`` call must return an iterable. The returned iterables
    are flattened into a single list, preserving input order.

    Args:
        func: callable invoked as ``func(*args)`` for each element of
            ``iterable``. It must be picklable in parallel mode.
        iterable: iterable of argument tuples (may be a generator).
        length: number of elements in ``iterable``. It selects serial vs.
            parallel mode and sizes the progress bar.
        desc: label for the progress bar.
        spawn: use the "spawn" start method instead of "fork".
        chunksize: argument tuples sent to a worker per task in parallel mode.
        parallel_threshold: minimum ``length`` before a pool is used.
        single_threaded: force the serial path regardless of ``length``
            (mirrors ``smart_map``).

    Returns:
        A flat list of every element produced by every ``func(*args)`` call.
    """
    from functools import partial

    if length < parallel_threshold or single_threaded:
        results = []
        for args in tqdm.tqdm(iterable, total=length, desc=desc):
            results.extend(func(*args))
        return results

    ctx = multiprocessing.get_context("spawn" if spawn else "fork")
    with ctx.Pool(NUM_CPUS) as pool:
        # Wrapping the *input* of pool.starmap only tracked task submission, so
        # the bar finished almost immediately. imap yields results as workers
        # complete them, so the bar now reflects finished calls.
        per_call = tqdm.tqdm(
            pool.imap(partial(_apply_star, func), iterable, chunksize=chunksize),
            total=length,
            desc=desc,
        )
        # Materialize inside the `with`: imap needs the pool alive to deliver results.
        return list(itertools.chain.from_iterable(per_call))
