import numpy as np
from mpi4py import MPI
from qutip import *
import pypluto as pp  # For plasma turbulence

# MPI bookkeeping for the world communicator: `rank` identifies this
# process (0 .. size-1) and `size` is the number of participating processes.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

def hybrid_run(a, B_ext, M):
    """Run the hybrid model for a single ``(a, B_ext, M)`` parameter set.

    The pipeline has four stages: a QuTiP state-preparation step,
    closed-form expressions in ``a`` and ``B_ext``, a Gaussian burst profile
    sampled on a fixed time grid, and a box-counting analysis of a PyPLUTO
    output file whose name is derived from the parameters.

    Parameters
    ----------
    a : float
        Scalar entering ``sqrt(1 - a**2)``; must lie in ``[-1, 1]``.
        (Presumably the dimensionless black-hole spin — TODO confirm.)
    B_ext : float
        Field strength, combined with the horizon radius as
        ``Q_eff = B_ext * r_H**2``.
    M : float
        Amplitude scale of the sampled burst profile.

    Returns
    -------
    tuple
        ``(fractal_dim, peak_T_wh, e_ratio)``: whatever ``pp.boxcount``
        returns for the jet density, the maximum of the sampled burst
        profile, and ``exp(-2*pi*Q_eff)``.

    Raises
    ------
    ValueError
        If ``a`` is outside ``[-1, 1]`` or is NaN, where ``sqrt(1 - a**2)``
        would otherwise silently turn every derived quantity into NaN.
    """
    # Validate before any expensive work; the chained comparison is False
    # for NaN as well, so NaN is rejected too.
    if not -1.0 <= a <= 1.0:
        raise ValueError(f"a must lie in [-1, 1], got {a!r}")

    # 1. Holographic inversion: apply Theta = exp(-i*pi*sigma_z) to H|0>.
    # All QuTiP names come from the module-level `from qutip import *`.
    # NOTE(review): psi_out is neither used nor returned — confirm intent.
    Theta = (-1j * np.pi * sigmaz()).expm()
    psi_out = Theta * (hadamard_transform(1) * basis(2, 0))

    # 2. Kerr-Newman dynamics: closed-form quantities in a and B_ext.
    root = np.sqrt(1 - a**2)
    r_H = 1 + root
    Q_eff = B_ext * r_H**2
    # NOTE(review): P_jet is computed but not returned — confirm intent.
    P_jet = (a**2 + Q_eff**2) * (1 - root)
    e_ratio = np.exp(-2 * np.pi * Q_eff)

    # 3. White hole burst: Gaussian centred at t0 with width tau, sampled at
    # 1000 points; the reported peak is the sampled maximum, not analytic.
    t0, tau = 0.5e-21, 1e-23
    t = np.linspace(0, 1e-21, 1000)
    T_wh = M * (1.6e-19) * np.exp(-(t - t0)**2 / tau**2)

    # 4. Turbulent jet: load this parameter set's PyPLUTO output and
    # box-count its density field.
    # NOTE(review): confirm `pp.analyze` / `pp.boxcount` exist in the
    # installed PyPLUTO; the file name embeds the raw float reprs.
    jet_data = pp.analyze(f"a{a}_B{B_ext}_M{M}.h5")
    fractal_dim = pp.boxcount(jet_data.rho)

    return fractal_dim, np.max(T_wh), e_ratio

# Parameter grids (distributed via MPI). Values of `a` are dealt round-robin
# across ranks (rank r takes a_grid[r], a_grid[r + size], ...); each rank
# then sweeps the full B x M grid for its share of `a` values.
a_grid = np.linspace(0.1, 0.99, 10)
B_grid = np.logspace(6, 9, 10)
M_grid = np.logspace(-12, -8, 10)

# Each row: [a, B_ext, M, fractal_dim, peak burst value, e_ratio].
results = []
try:
    for a_val in a_grid[rank::size]:
        for B in B_grid:
            for M in M_grid:
                fd, h_max, e_frac = hybrid_run(a_val, B, M)
                results.append([a_val, B, M, fd, h_max, e_frac])
except Exception:
    # A failure on one rank would otherwise leave every other rank blocked
    # forever in the collective gather below; report it and tear down the
    # whole job instead of hanging.
    import traceback
    traceback.print_exc()
    comm.Abort(1)

# Gather results: on rank 0 `all_results` is a list holding each rank's
# `results` list in rank order; on every other rank it is None.
all_results = comm.gather(results, root=0)