"""Shared parameters and routines for the second figure of dipole_clusters.

Version 1.2.1, May 17, 2024, tested with Python 3.11.7, code analysis 10.00/10
"""

import sys
import copy
import numpy as np
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons, Button
from mpl_toolkits.axes_grid1 import make_axes_locatable
try:
    import cluster_animation_plots as cap
except ImportError:
    import dipole_clusters_modules.cluster_animation_plots as cap

try:
    from dipole_clusters_modules import dipole_cluster as dct
    from dipole_clusters_modules.plot_tools import Tl, Gt, bgrs
except ImportError:
    sys.path.append(r'D:\Py\general')
    import dipole_cluster as dct
    from plot_tools import Tl, Gt, bgrs


class Mg:
    """My globals for cube_cube_plots.

    Shared state of the second page of the animation: figure, axes,
    widgets, the two plane geometries and the two clusters.  All attributes
    are class-level and are read and overwritten by the functions of this
    module instead of being passed around.
    """

    n_c1 = None  # first cluster (set to cap.Mg.clu_act in plt_7_axes)
    n_c2 = None  # rotated copy of the first cluster (set in plt_7_axes)
    clu_dist = 20  # separation of the two clusters in the plane plot
    # plane geometry attached to cluster 1 (d_p = +1)
    g1 = dct.GeoPlane(width=2., st_sz=0.01, route=np.array([1., 0., 0.]),
                      d_p=1., origin=np.array([0.0, 0., 0.0]), length=200)
    # plane geometry attached to cluster 2 (d_p = -1)
    g2 = dct.GeoPlane(width=2., st_sz=0.01, route=np.array([1., 0., 0.]),
                      d_p=-1., origin=np.array([0.0, 0., 0.0]), length=200)
    route_adj = [None]*3  # 3 sliders for the x, y, z components of the route
    dp_adj = None  # slider to change the distance between the planes
    trp_adj = None  # slider to change the transparency
    loglog = True  # parameter for ax[1]: log-log (True) or linear U(r)
    loglog_toggle = None  # radio button to change loglog
    switch_plt = None  # button to switch to the 1. page of the animation
    fig = None  # the second page of the animation showing the interaction
    ax = [None]*7  # 7 subplots for the second page.
    t_xyz = [0, 0, 0]  # three turning angles (degrees) for R
    turn_adj = [None]*3  # 3 sliders for turning the second cluster

    @staticmethod
    def listing():
        """Print dir(Mg); kept for formal reasons only."""
        print(dir(Mg))

    @staticmethod
    def number_of_elements():
        """Print how many names dir(Mg) reports; for formal reasons only."""
        print('MG has ', len(dir(Mg)), 'elements')


def plot_layout_pair():
    """Build the second page: 7 axes plus widgets for the pair interaction.

    Creates the figure Mg.fig with three 3d axes (Mg.ax[0], Mg.ax[2],
    Mg.ax[3]) and four 2d axes (Mg.ax[1], Mg.ax[4], Mg.ax[5], Mg.ax[6]).
    Adds sliders for the three turning angles, the route components
    (x, y, z), the plane distance d_p and the transparency, plus a Log/Lin
    radio button and a back button.  Wires all widget callbacks and finally
    draws the content via plt_7_axes().
    """
    # 16:9 figure holding all 7 subplots and the widgets
    Mg.fig = plt.figure('Interaction of two clusters',
                        figsize=(10*16/9, 10))
    # NOTE(review): rc settings applied after plt.figure() may not affect
    # properties already fixed at figure creation -- confirm order intended.
    Gt.set_my_rc()
    q_x = 0.25  # width of 3 plots
    q_y = q_x*16/9  # height of that plots, but the ratio is set automatically
    # lower row, 3d axes: ax[0] original cluster (bild1),
    # ax[2] rotated cluster (bild3), ax[3] planes (bild4)
    Mg.ax[0] = plt.axes([0.01, 0.15, q_x, q_y], projection='3d')
    Mg.ax[2] = plt.axes([0.67, 0.15, q_x, q_y], projection='3d')
    Mg.ax[3] = plt.axes([0.34, 0.15, q_x, q_y], projection='3d')

    # upper row, 2d axes: ax[1] U(r) (bild2), ax[4] beta/gamma map (bild5),
    # ax[5]/ax[6] orientation of original/rotated cluster (bild6/bild7)
    Mg.ax[1] = plt.axes([0.21, 0.65, 0.29, 0.34])
    Mg.ax[4] = plt.axes([0.54, 0.65, 0.33, 0.33])
    Mg.ax[5] = plt.axes([0.025, 0.63, 0.12, 0.15*16/9*0.5])
    Mg.ax[6] = plt.axes([0.875, 0.63, 0.12, 0.15*16/9*0.5])

    col = 'lightgoldenrodyellow'  # face colour of radio button and button
    l_2 = 0.35  # left anchor for positions (d_p and transparency sliders)
    l_1 = 0.025  # left starting point (route sliders)
    l_3 = 0.69  # left anchor of the turning-angle sliders
    h_1 = 0.02  # bottom of the lowest slider row
    h_2 = h_1 + 0.03  # second slider row
    s_l = 0.26  # slider length
    s_h = 0.02  # slider height
    # three turning-angle sliders (0..360), stacked upwards
    for i, lab in enumerate([r'$\alpha_{\parallel ,\mathrm{a}}$',
                             r'$\beta_{\perp ,\mathrm{h}}$',
                             r'$\gamma_{\perp ,\mathrm{u}}$']):
        Mg.turn_adj[i] = Slider(plt.axes([l_3, h_1+i*0.03, s_l, s_h]), lab,
                                0, 360, valfmt='%1.0f', valinit=Mg.t_xyz[i])
    # three route-component sliders (0..5), initialised from Mg.g1.route()
    for i, lab in enumerate([' x ', ' y ', ' z ']):
        Mg.route_adj[i] = Slider(plt.axes([l_1, h_1+i*0.03, s_l, s_h]),
                                 lab, 0, 5, valfmt='%1.1f',
                                 valinit=Mg.g1.route()[i])

    Mg.dp_adj = Slider(plt.axes([l_2, h_1,  s_l, s_h]),
                       r"$d_\mathrm{p}$", 0, 20, valfmt='%1.1f',
                       valinit=Mg.g1.d_p())
    Mg.trp_adj = Slider(plt.axes([l_2, h_2, s_l, s_h]),
                        "$\\alpha$", 0, 1, valfmt='%1.1f', valinit=0.5)

    # vertical Log/Lin selector next to ax[1]; index 0 ('Log') if loglog
    Mg.loglog_toggle = RadioButtons(plt.axes([0.50, 0.65, 0.026, 0.34],
                                             fc=col),
                                    (' L\n o\n g', ' L\n i\n n'),
                                    active=int(not Mg.loglog))
    # back-arrow button in the top left corner
    Mg.switch_plt = Button(plt.axes([0.005, 0.965, 0.03, 0.03]),
                           r'$\leftarrow$', color=col, hovercolor='0.975')
    # wire the callbacks
    Mg.switch_plt.on_clicked(newplt)
    for slider in Mg.turn_adj:
        slider.on_changed(turn_update)
    for slider in Mg.route_adj:
        slider.on_changed(route_update)
    Mg.dp_adj.on_changed(dp_update)
    Mg.trp_adj.on_changed(trp_update)
    Mg.loglog_toggle.on_clicked(toggle_log)
    plt_7_axes()


def newplt(_):
    """Button callback: flush and bring the first animation page forward.

    The button event argument is ignored.  The figure shown is cap.Pl.fig,
    read only after the flush.
    """
    Gt.msg('newplt started')
    cap.Pl.flush_figs()
    first_page = cap.Pl.fig
    plt.figure(first_page)
    plt.show()
    Gt.msg('newplt finished')


def turn_update(_):
    """Slider callback: copy the three turning angles to Mg.t_xyz, redraw.

    Mg.t_xyz is updated in place, so every holder of that list sees the
    new angles.
    """
    Gt.msg('turn_update started')
    Mg.t_xyz[:] = [adjuster.val for adjuster in Mg.turn_adj]
    plt_7_axes()
    Gt.msg('turn_update finished')


def trp_update(_):
    """Slider callback: redraw only the plane subplot (ax[3]).

    The transparency itself is read from Mg.trp_adj while the planes are
    drawn, so only bild4 needs to run again.
    """
    Gt.msg('trp_update started')
    plane_axes = Mg.ax[3]
    bild4(plane_axes, Mg.n_c1, Mg.n_c2)
    cap.Pl.flush_figs()
    Gt.msg('trp_update finished')


def route_update(_):
    """Slider callback: pass the (x, y, z) route sliders to Mg.g1, redraw."""
    Gt.msg('route_update started')
    components = [adjuster.val for adjuster in Mg.route_adj]
    Mg.g1.new_geo(route=components)
    plt_7_axes()
    Gt.msg('route_update finished')


def dp_update(_):
    """Slider callback: give plane Mg.g1 the new distance d_p, redraw."""
    Gt.msg('dp_update started')
    new_dp = Mg.dp_adj.val
    Mg.g1.new_geo(d_p=new_dp)
    plt_7_axes()
    Gt.msg('dp_update finished')


def toggle_log(label):
    """Radio-button callback: switch U(r) in ax[1] between log-log and lin.

    ``label`` is the text of the clicked entry; only the 'Log' entry
    selects the log-log presentation.
    """
    Gt.msg('toggle_log started')
    log_entry = ' L\n o\n g'
    Mg.loglog = (label == log_entry)
    bild2(Mg.ax[1], Mg.g1, Mg.n_c1, Mg.n_c2)
    cap.Pl.flush_figs()
    Gt.msg('toggle_log finished')


def turn_3ax(clu, t_xyz):
    """Return a deep copy of ``clu`` turned by three consecutive rotations.

    The angles t_xyz[0..2] are in degrees.  They are applied one after the
    other about Mg.g2.direction, Mg.g2.d_cross_y and Mg.g2.perp.  Both the
    positions (r_vec) and the moments (p_vec) are rotated; ``clu`` itself
    stays untouched.
    """
    pivots = (Mg.g2.direction, Mg.g2.d_cross_y, Mg.g2.perp)
    rotations = [R.from_rotvec([t_xyz[k]*pivot], degrees=True)
                 for k, pivot in enumerate(pivots)]
    turned = copy.deepcopy(clu)
    # rotate step by step (original note: an alternative did not work)
    for rot in rotations:
        turned.r_vec = rot.apply(turned.r_vec)
        turned.p_vec = rot.apply(turned.p_vec)
    turned.refresh_tables()  # can that be avoided?
    return turned


def plt_7_axes():
    """Recompute both clusters and redraw all seven subplots.

    Mg.n_c1 becomes the active cluster of the animation module, and Mg.n_c2
    a copy of it turned by Mg.t_xyz.  The subplots are drawn in the order
    ax[0] ... ax[6], and the figures are then flushed.
    """
    active = cap.Mg.clu_act
    Mg.n_c1 = active
    Mg.n_c2 = turn_3ax(active, Mg.t_xyz)
    axs = Mg.ax
    bild1(axs[0], Mg.g1)
    bild2(axs[1], Mg.g1, Mg.n_c1, Mg.n_c2)
    bild3(axs[2], Mg.g1, Mg.n_c2)
    bild4(axs[3], Mg.n_c1, Mg.n_c2)
    for idx, painter in ((4, bild5), (5, bild6), (6, bild7)):
        painter(axs[idx])
    cap.Pl.flush_figs()


def plot_plane(axe, ens, loc, trp):
    """Draw one colour-coded plane at x = ``loc`` into the 3d axes ``axe``.

    Parameters
    ----------
    axe : 3d axes to draw into.
    ens : 2d array of values on the plane grid.  It is rescaled to [0, 1]
        for the colour map bgrs; the caller's array is not modified.
        Presumably it has the shape of Mg.g1.x_pl / Mg.g1.y_pl -- confirm.
    loc : position of the plane along the x axis of the plot.
    trp : alpha value written into every facet colour.
    """
    # work on a float copy instead of rescaling the caller's array in place
    ens = np.asarray(ens, dtype=float)
    ens = ens - ens.min()
    span = ens.max()
    if span > 0:
        ens = ens / span
    # else: constant field -> keep all zeros; dividing would give 0/0 = NaN
    plane_colors = bgrs(ens)  # RGBA colours from the colour map
    plane_colors[:, :, 3] = trp  # transparency
    h_x = np.ones(ens.shape)*loc
    axe.plot_surface(h_x, Mg.g1.x_pl, Mg.g1.y_pl, facecolors=plane_colors,
                     linewidth=0, shade=False)


def maxwell(cube1, cube2):
    """Return a stress-tensor component on the plane midway between clusters.

    Copies of Mg.g1 and Mg.g2 are moved to d_p = +clu_dist/2 and
    -clu_dist/2.  On each, the field of the respective cluster is taken
    along direction, perp and d_cross_y via ind_x, and the contributions of
    both clusters are summed.  The result is f_along**2 - f_perp**2 -
    f_cross**2, with the factor 1/(2 mu_0) omitted.
    """
    half = Mg.clu_dist/2
    mid1 = copy.deepcopy(Mg.g1)
    mid1.new_geo(d_p=half)  # the middle plane
    mid2 = copy.deepcopy(Mg.g2)
    mid2.new_geo(d_p=-half)  # the middle plane

    def field_parts(cube, plane):
        """Field of ``cube`` on ``plane`` along its three unit vectors."""
        pts = plane.surface
        return [cube.ind_x(pts[:, :, 0], pts[:, :, 1], pts[:, :, 2], unit)
                for unit in (plane.direction, plane.perp, plane.d_cross_y)]

    f_along, f_perp, f_cross = (
        one + two for one, two in zip(field_parts(cube1, mid1),
                                      field_parts(cube2, mid2)))
    # factor 1/mue_0/2 ignored
    return f_along*f_along - f_perp*f_perp - f_cross*f_cross


def plt_3_planes(axe, cube1, cube2):
    """Draw the potential planes of both clusters and the stress midplane.

    Plane 1 is Mg.g1 at x = d_p.  Plane 2 mirrors it (same route and width,
    opposite d_p) around the second cluster at x = clu_dist.  The opaque
    middle plane at x = clu_dist/2 shows the output of maxwell().
    """
    geo1, geo2 = Mg.g1, Mg.g2
    geo1.new_geo(width=2)  # just in case ....
    surf1 = geo1.surface
    ens1 = cube1.potential(surf1[:, :, 0], surf1[:, :, 1], surf1[:, :, 2])
    geo2.new_geo(route=geo1.route(), d_p=-geo1.d_p(), width=geo1.width())
    surf2 = geo2.surface
    ens2 = cube2.potential(surf2[:, :, 0], surf2[:, :, 1], surf2[:, :, 2])
    ens3 = maxwell(cube1, cube2)
    alpha = Mg.trp_adj.val
    plot_plane(axe, ens1, geo1.d_p(), alpha)
    plot_plane(axe, ens2, Mg.clu_dist + geo2.d_p(), alpha)
    plot_plane(axe, ens3, Mg.clu_dist/2, 1)  # not transparent


def bild1(axi, geo):
    """Left 3d subplot: the unturned cluster plus the direction of ``geo``."""
    cluster = Mg.n_c1
    cap.Pl.plot_clu(axi, cluster)
    Tl.plot_direction(axi, geo)


def bild2(axi, geo, cube1, cube2):
    """Plot the pair potential U(r) on a log-log or linear scale.

    cube2 is displaced from cube1 along geo.direction by 201 geometrically
    spaced distances between 2 and 20 radii of cube1.  In log-log mode
    (Mg.loglog) |U| is drawn together with a dashed power law U ~ r**slope.
    The exponent is estimated from the points dist[i] and dist[-1], with
    i = -len/10.  The curve is red if U > 0 at the largest distance,
    otherwise blue.
    """
    n_pts = 201
    axi.cla()
    xlab = ('distance $r$ between the two cluster centers along '
            f'({geo.direction[0]:3.1f}, '
            f'{geo.direction[1]:3.1f}, '
            f'{geo.direction[2]:3.1f})')
    axi.set_xlabel(xlab)
    dist = np.geomspace(2*cube1.radius(),
                        20*cube1.radius(),
                        n_pts)  # any factor 2 to be considered?
    # row k is the displacement dist[k]*direction of the second cluster
    shifts = np.outer(dist, geo.direction)
    pot = interaction(cube1, cube2, shifts)
    col = 'red' if pot[-1] > 0 else 'blue'
    if Mg.loglog:
        i = -int(len(pot)/10)  # somewhat arbitrary
        # Fit the exponent of |U| (that is what is plotted): if U changes
        # sign between dist[i] and dist[-1], the log of the signed ratio
        # would be NaN.
        slope = np.log(abs(pot[-1] / pot[i]))
        slope /= np.log(dist[-1]/dist[i])

        axi.set_ylabel('potential Energy $|U(r)|$')
        axi.loglog(dist, np.abs(pot), color=col, label=r'$U(r)$', alpha=0.5)
        hstr = r'$U\propto r^{' + f'{slope:2.0f}' + '}$'
        start = i*10  # -200 for 201 points: the guide spans almost all of r
        axi.loglog(dist[start: -1],
                   abs(pot[-1])*pow(dist[start: -1]/dist[-1], slope),
                   color="black", linestyle='dashed',
                   label=hstr)
        axi.legend()
    else:
        axi.set_ylabel('potential Energy $U(r)$')
        # the curve is U(r); the former label "B_x" was a copy-paste leftover
        axi.plot(dist, pot, color=col, label=r'$U(r)$')


def bild3(axi, geo, mycube):
    """Right 3d subplot: the turned cluster plus the target arrow of geo."""
    draw_cluster = cap.Pl.plot_clu
    draw_cluster(axi, mycube)
    plot_target(axi, geo)


def bild4(axi, cube1, cube2):
    """3d subplot with the two potential planes and the stress midplane.

    Clears the axes and sets ticks, limits and labels.  The drawing itself
    is delegated to plt_3_planes.
    """
    axi.cla()
    for axis, spacing in ((axi.xaxis, 5), (axi.yaxis, 1), (axi.zaxis, 1)):
        axis.set_major_locator(plt.MultipleLocator(spacing))
    edge = 2.1
    axi.set_xlim3d(-1, Mg.clu_dist+1)
    axi.set_ylim3d(-edge, edge)
    axi.set_zlim3d(-edge, edge)
    setters = (axi.set_xlabel, axi.set_ylabel, axi.set_zlabel)
    texts = (r'$\mathbf{r}_\mathrm{along}$',
             r'$\mathbf{r}_{\perp ,\mathrm{hor}}$',
             r'$\mathbf{r}_{\perp ,\mathrm{up}}$')
    for setter, text in zip(setters, texts):
        setter(text)
    plt_3_planes(axi, cube1, cube2)


def bild5(axi):
    """Map the pair interaction over the turning angles beta and gamma.

    The first turning angle is kept at Mg.t_xyz[0].  beta (x axis) and
    gamma (y axis) run from 0 to 360 degrees in steps of 10.  The map is
    shown as a bilinear image with gray contour lines and a colour bar
    whose values are written as text.  The current (beta, gamma) =
    (Mg.t_xyz[1], Mg.t_xyz[2]) is marked by a cyan dot.
    """
    # append_axes adds a new axes to the figure on every call; remove the
    # colour-bar axes of the previous call so they do not pile up.
    old_cax = getattr(axi, 'bild5_cax', None)
    if old_cax is not None and old_cax in axi.figure.axes:
        old_cax.remove()
    divider = make_axes_locatable(axi)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    axi.bild5_cax = cax
    axi.cla()
    cax.cla()
    axi.xaxis.set_major_locator(plt.MultipleLocator(60))
    axi.yaxis.set_major_locator(plt.MultipleLocator(60))
    axi.set_xlim(0, 360)
    axi.set_ylim(0, 360)
    axi.set_xlabel(r'rotation angle $\beta$')
    axi.set_ylabel(r'rotation angle $\gamma$')
    n_ang = 37  # 0, 10, ..., 360 degrees
    angles = np.linspace(0, 360, n_ang)
    h_5 = np.zeros((n_ang, n_ang), dtype='double')
    for i in range(n_ang):  # corresponds to beta (columns, x axis)
        for j in range(n_ang):  # corresponds to gamma (rows, y axis)
            # two_clu_interaction returns a length-1 array: store the
            # scalar (array -> scalar conversion is deprecated in NumPy)
            h_5[j, i] = two_clu_interaction(Mg.t_xyz[0], i*10., j*10.).item()
    axi.contour(angles, angles, h_5, cmap=plt.get_cmap('gray'))
    surf = axi.imshow(h_5, interpolation='bilinear', origin='lower',
                      cmap=bgrs, extent=(0, 360, 0, 360))
    h_min, h_max = h_5.min(), h_5.max()
    cbar = Mg.fig.colorbar(surf, cax=cax, ticks=[h_min, 0, h_max])
    cbar.ax.set_yticklabels(['', '', ''])
    # Value labels are placed in data coordinates of axi (0..360), whose
    # height matches the colour bar.  The colour scale runs linearly from
    # h_min (bottom) to h_max (top).
    axi.text(360, 360, f'       {h_max:.2g}', fontsize=10, va='center')
    if h_min < 0 < h_max:
        # put "0" where zero really sits on the bar, not always at mid height
        axi.text(360, 360*(-h_min)/(h_max - h_min), '        0',
                 fontsize=10, va='center')
    axi.text(360,   0, f'       {h_min:.2g}', fontsize=10, va='center')
    axi.plot(Mg.t_xyz[1], Mg.t_xyz[2], 'o', color='cyan')


def bild6(axi):
    """Show the orientation (theta/phi plot) of the unturned cluster."""
    cluster = Mg.n_c1
    Tl.plot_theta_phi(axi, cluster)


def bild7(axi):
    """Show the orientation (theta/phi plot) of the turned cluster."""
    cluster = Mg.n_c2
    Tl.plot_theta_phi(axi, cluster)


def plot_target(axi, geo):
    """Draw a cyan arrow pointing along geo.direction.

    The arrow starts at -1.8*geo.direction and is drawn with length 1.8,
    an arrow-head ratio of 0.2 and without normalisation.
    """
    unit = geo.direction
    tail = -1.8*unit
    axi.quiver(*(tail[k] for k in range(3)), *(unit[k] for k in range(3)),
               length=1.8, arrow_length_ratio=0.2, normalize=False, color='c')


def interaction(cu1, cu2, dist):
    """Calculate the dipole-dipole interaction energy of two clusters.

    For every displacement ``dst`` in ``dist`` the energy is

        sum_ij [(p_i . p_j) r**2 - 3 (p_i . r)(p_j . r)] / r**5,

    with r = r_i - r_j - dst, i over the dipoles of cu1 and j over those
    of cu2 (constant prefactor omitted).

    Parameters
    ----------
    cu1, cu2 : clusters providing per-dipole ``r_vec`` and ``p_vec`` (3
        components each) and ``n_dip()``; only the first ``n_dip()``
        dipoles are used.
    dist : sequence of displacements, e.g. a (K, 3) array.

    Returns
    -------
    numpy.ndarray of shape (len(dist),), one energy per displacement.
    """
    n_1, n_2 = cu1.n_dip(), cu2.n_dip()
    r_1 = np.asarray(cu1.r_vec, dtype=float).reshape(-1, 3)[:n_1]
    p_1 = np.asarray(cu1.p_vec, dtype=float).reshape(-1, 3)[:n_1]
    r_2 = np.asarray(cu2.r_vec, dtype=float).reshape(-1, 3)[:n_2]
    p_2 = np.asarray(cu2.p_vec, dtype=float).reshape(-1, 3)[:n_2]
    # pairwise offsets r_i - r_j, shape (n_1, n_2, 3), independent of dst
    rel = r_1[:, np.newaxis, :] - r_2[np.newaxis, :, :]
    p_dot_p = p_1 @ p_2.T  # (n_1, n_2)
    summe = np.zeros(len(dist))
    # loop only over displacements; all dipole pairs are handled at once,
    # which keeps memory at O(n_1*n_2) instead of O(K*n_1*n_2)
    for k, dst in enumerate(dist):
        sep = rel - np.asarray(dst, dtype=float)
        r_sq = np.einsum('ijc,ijc->ij', sep, sep)
        p1_r = np.einsum('ic,ijc->ij', p_1, sep)
        p2_r = np.einsum('jc,ijc->ij', p_2, sep)
        summe[k] = np.sum((p_dot_p*r_sq - 3*p1_r*p2_r) / r_sq**2.5)
    return summe


def two_clu_interaction(alpha, beta, gamma, distance=10.):
    """Examine the interaction with a turned copy of the active cluster.

    The copy of cap.Mg.clu_act is turned by (alpha, beta, gamma) degrees
    via turn_3ax.  It is displaced by ``distance`` (default 10) along
    Mg.g1.direction.

    Returns
    -------
    numpy.ndarray of length 1 -- the result of interaction() for that
    single displacement.
    """
    cu1 = cap.Mg.clu_act
    cu2 = turn_3ax(cap.Mg.clu_act, [alpha, beta, gamma])
    return interaction(cu1, cu2, [distance*Mg.g1.direction])


if __name__ == '__main__':
    Mg.number_of_elements()
    cap.Mg.select_start_cluster(cluster='Cube')
    # five identical displacements (1, 2, 3) as a (5, 3) integer array
    offsets = np.tile([1, 2, 3], (5, 1))
    print(offsets)
    print('new:', interaction(cap.Mg.clu_act, cap.Mg.clu_act, offsets))
