"""Clusters and Shapes (GeoPlane(GeoLine)) to illustrate them.

last change May 1, 2024, PY 3.11.7 tested, analysis 10.00/10
Ingo Rehberg, University of Bayreuth
"""
import random
from itertools import permutations
import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy import stats
import scipy.cluster.hierarchy as sch


class GeoLine:
    """Sampled straight line in 3d-space along which fields are examined."""

    def __init__(self, **kw):
        # Defaults; every entry can be overridden by a keyword argument.
        self.line_d = {'length': 20., 'route': np.array([0., 0., 1.]),
                       'origin': np.array([0, 0, 0])}
        self.new_line(**kw)

    def length(self):
        """Return the current length of the line."""
        return self.line_d['length']

    def origin(self):
        """Return the current origin of the line."""
        return self.line_d['origin']

    def route(self):
        """Return the current (not normalised) direction of the line."""
        return self.line_d['route']

    def new_line(self, **kw):
        """Store new line parameters and recompute the sample points.

        Keyword arguments are merged into ``line_d``.  The line is sampled
        at parameter 0 followed by 1000 geometrically spaced values between
        1e-5 and the line length.
        """
        self.line_d.update(kw)
        heading = np.array(self.route())
        self.direction = heading / np.sqrt(np.dot(heading, heading))
        start = self.line_d['origin']
        steps = np.concatenate(
            (np.array([0.]),
             np.geomspace(1.e-5, self.line_d['length'], 1000)))
        # sample coordinates: point on the line minus the origin vector
        self.p_x, self.p_y, self.p_z = (
            steps*self.direction[k] - start[k] for k in range(3))
        # norm of the sample coordinates
        self.p_s = np.sqrt(self.p_x**2 + self.p_y**2 + self.p_z**2)
        # norm after adding the origin back, i.e. distance along the line
        self.cdist = np.sqrt((self.p_x + start[0])**2 +
                             (self.p_y + start[1])**2 +
                             (self.p_z + start[2])**2)


class GeoPlane(GeoLine):
    """Square grid in a plane perpendicular to the line of GeoLine.

    The plane parameters live in ``geo_d`` next to the line parameters:
    ``d_p`` (offset of the plane along the line direction), ``width``
    (the grid spans -width ... +width) and ``st_sz`` (grid step).
    """

    def __init__(self, **kw):
        super().__init__()
        plane_d = {'d_p': 0.5, 'width': 1.6, 'st_sz': 0.01}
        self.geo_d = {**self.line_d, **plane_d}
        ric = self.geo_d['route']
        self.direction = ric/np.sqrt(np.dot(ric, ric))
        self.new_geo(**kw)

    def d_p(self):
        """Return the distance of the plane."""
        return self.geo_d['d_p']

    def width(self):
        """Return the half width of the square plane (from -width to +width)."""
        return self.geo_d['width']

    def st_sz(self):
        """Return the step size of the plane grid."""
        return self.geo_d['st_sz']

    def iii(self):
        """Return the number of grid points along one direction."""
        return int(2.*self.width()/self.st_sz()+1.1)

    def new_geo(self, **kw):
        """Generate a plane perpendicular to "direction", at offset d_p.

        Keyword arguments are merged into ``geo_d`` and forwarded to
        ``new_line``.  Afterwards the following attributes are set:

        d_cross_y -- unit vector perpendicular to ``direction``
        perp      -- ``direction x d_cross_y``, the second in-plane axis
        x_pl, y_pl -- meshgrid of the in-plane coordinates
        surface   -- array of shape (iii, iii, 3) with the 3d position of
                     every grid point, indexed as [j, i] for the point
                     ``f_i*d_cross_y + f_j*perp + d_p*direction``
        """
        self.geo_d.update(kw)
        self.new_line(**self.geo_d)
        d_x, d_y, d_z = self.direction
        # Fix: for a direction along +-y both d_x and d_z vanish and the
        # former fallback divided 0 by 0.  The x-y construction is valid
        # there, so it is used whenever the x-z one would degenerate.
        if d_x != 0 or d_z == 0:  # used by explore_dipole_clusters
            self.d_cross_y = (np.array([-d_y, d_x, 0.])
                              / np.sqrt(d_y*d_y + d_x*d_x))
        else:
            self.d_cross_y = (np.array([d_z, 0., -d_x])
                              / np.sqrt(d_z*d_z + d_x*d_x))

        self.perp = np.cross(self.direction, self.d_cross_y)
        wid = self.width()
        step = self.st_sz()
        self.x_pl, self.y_pl = np.meshgrid(
            np.arange(-wid, wid+step/2., step),
            np.arange(-wid, wid+step/2., step))
        # Same values as the former triple loop, built by broadcasting:
        # axis 0 runs along perp (index j), axis 1 along d_cross_y (index i).
        offsets = -wid + np.arange(self.iii())*step
        self.surface = (
            offsets[np.newaxis, :, np.newaxis]*self.d_cross_y
            + offsets[:, np.newaxis, np.newaxis]*self.perp
            + self.direction*self.d_p())


class Cluster:
    """Ensemble of magnetic point dipoles plus field and energy helpers.

    Instance attributes:
        r_vec   -- (n, 3) array of dipole positions
        p_vec   -- (n, 3) array of dipole moments
        det     -- (n*n, 3) table of separation unit vectors scaled by
                   sqrt(3); row ``prtnr*n + i`` holds r_vec[i]-r_vec[prtnr]
        dp3t    -- (n*n,) table of inverse cubed pair distances, same layout
        b_ext   -- constant external field added to every computed field
        lengths -- magnitude each moment is rescaled to
        ipa     -- index tables ``[i_p, p_p, d_p, index]`` (create_indices)
    """

    def __init__(self, number_of_dipoles=7):
        n_dip = number_of_dipoles
        self.r_vec = np.zeros((n_dip, 3), dtype='double')  # positions
        self.p_vec = np.zeros((n_dip, 3), dtype='double')  # dipole moment
        self.det = np.zeros((n_dip*n_dip, 3), dtype='double')  # internal
        self.dp3t = np.zeros(n_dip*n_dip, dtype='double')  # internal
        self.b_ext = np.zeros(3, dtype='double')  # the external field
        self.lengths = np.ones(n_dip)  # use this array to rescale p_vec
        self.create_indices()

    def n_dip(self):
        """Return the number of dipoles in the cluster."""
        return len(self.r_vec)

    def scale_p(self):
        """Rescale every moment in p_vec to its entry in lengths."""
        factor = self.lengths/np.sqrt(np.sum(self.p_vec**2, axis=1))
        self.p_vec *= factor[:, np.newaxis]

    def set_length(self, i, length):
        """Set the magnitude of dipole i to length, keeping its direction."""
        self.lengths[i] = length
        self.p_vec[i] *= self.lengths[i]/np.sqrt(np.dot(self.p_vec[i],
                                                        self.p_vec[i]))

    def length(self, i):
        """Return the current magnitude of the moment of dipole i."""
        return np.sqrt(np.sum(self.p_vec[i, :]**2))

    def create_indices(self):
        """Set up the index tables used by the vectorised routines.

        ``ipa[0]``/``ipa[1]`` list both members (i > prtnr) of every
        interacting pair, ``ipa[2]`` the matching row in det/dp3t and
        ``ipa[3]`` the row offsets ``k*n`` for k = 0 ... n-1.
        """
        n_dip = self.n_dip()
        index = np.arange(n_dip, dtype='int')*n_dip
        # strict lower triangle in row-major order: (1,0), (2,0), (2,1), ...
        i_p, p_p = np.tril_indices(n_dip, -1)
        d_p = p_p*n_dip + i_p
        self.ipa = [i_p, p_p, d_p, index]  # reduce instance attrib.

    def refresh_tables(self):
        """Rebuild det and dp3t from the current positions.

        Call after any change of r_vec.  Self-interaction rows are never
        written and keep the zeros they were initialised with.
        """
        n_dip = self.n_dip()
        root3 = np.sqrt(3.)
        for i in range(n_dip):
            for prtnr in range(n_dip):  # both orders of a pair are stored
                if prtnr == i:
                    continue
                pnt = prtnr*n_dip + i
                d_h = self.r_vec[i] - self.r_vec[prtnr]
                dist = np.sqrt(np.dot(d_h, d_h))
                self.det[pnt] = d_h * root3 / dist
                self.dp3t[pnt] = 1. / pow(dist, 3)

    def b_loc_fast(self, i):
        """Return the field vector at dipole i from all others plus b_ext."""
        my_ind = self.ipa[3] + i  # table rows of the pairs (k, i), all k
        det = self.det[my_ind]
        a_h = (det[:, 0] * self.p_vec[:, 0] +
               det[:, 1] * self.p_vec[:, 1] +
               det[:, 2] * self.p_vec[:, 2])
        # det carries sqrt(3), hence det*(det.p) equals 3 e (e.p)
        b_h = det*a_h[:, np.newaxis]
        b_h -= self.p_vec
        # dp3t is zero for k == i, which removes the self term
        b_h *= self.dp3t[my_ind][:, np.newaxis]
        return np.sum(b_h, axis=0)+self.b_ext

    def adjust_dir_loc(self, i):
        """Align dipole i with its local field, magnitude lengths[i]."""
        c_h = self.b_loc_fast(i)
        c_h *= self.lengths[i] / \
            np.sqrt(c_h[0]*c_h[0]+c_h[1]*c_h[1]+c_h[2]*c_h[2])
        self.p_vec[i] = c_h

    def w_total_fast(self):
        """Return the pair interaction energy summed and divided by n_dip."""
        i_p, p_p, d_p = self.ipa[0], self.ipa[1], self.ipa[2]
        p_a = self.p_vec[i_p]
        p_b = self.p_vec[p_p]
        det = self.det[d_p]
        dot_ab = (p_a[:, 0]*p_b[:, 0] +
                  p_a[:, 1]*p_b[:, 1] +
                  p_a[:, 2]*p_b[:, 2])
        dot_a = (p_a[:, 0]*det[:, 0] +
                 p_a[:, 1]*det[:, 1] +
                 p_a[:, 2]*det[:, 2])
        dot_b = (p_b[:, 0]*det[:, 0] +
                 p_b[:, 1]*det[:, 1] +
                 p_b[:, 2]*det[:, 2])
        return np.sum((dot_ab - dot_a*dot_b) * self.dp3t[d_p])/self.n_dip()

    def potential(self, x_p, y_p, z_p, excl=None):
        """Calculate the potential at positions x_p, y_p, z_p.

        x_p, y_p, z_p -- coordinates of equal shape (scalars are accepted)
        excl          -- index (or indices) of dipoles to leave out
        Returns an array with the shape of x_p holding the sum of
        p.r/|r|^3 over the dipoles plus the term position.b_ext.
        """
        # axis=0 removes whole rows; without it np.delete flattens the array
        rvec = (np.delete(self.r_vec, excl, axis=0)
                if excl is not None else self.r_vec)
        pvec = (np.delete(self.p_vec, excl, axis=0)
                if excl is not None else self.p_vec)
        old_shape = np.shape(x_p)  # assuming all 3 shapes are equal
        h_3_n = np.transpose(
            np.array([np.ravel(x_p), np.ravel(y_p), np.ravel(z_p)]))
        r_3_n = -rvec[:, np.newaxis] + h_3_n  # (dipole, point, xyz)
        r_2 = np.sum(r_3_n*r_3_n, axis=2)
        nen = np.clip(np.power(r_2, 1.5), 1.e-30, None)  # avoid zeros
        scpr = pvec[:, np.newaxis]*r_3_n
        pot = np.sum(scpr/nen[:, :, np.newaxis],
                     axis=(0, 2)) + np.sum(h_3_n*self.b_ext, axis=1)
        return pot.reshape(old_shape)

    def b_flux(self, x_p, y_p, z_p, excl=None):
        """Calculate the 3-d field at broadcastable positions x, y, z.

        excl -- index (or indices) of dipoles to leave out
        Returns the tuple (b_x, b_y, b_z) including b_ext.  Where the
        squared distance to a dipole is below 0.25 that dipole contributes
        its moment itself; elsewhere it contributes the dipole field
        divided by 16.
        """
        # axis=0 removes whole rows; without it np.delete flattens the array
        rvec = (np.delete(self.r_vec, excl, axis=0)
                if excl is not None else self.r_vec)
        pvec = (np.delete(self.p_vec, excl, axis=0)
                if excl is not None else self.p_vec)
        # accumulator with the common broadcast shape of the three inputs
        lbx = np.zeros((3,) + np.broadcast(x_p, y_p, z_p).shape)
        for i, r_v in enumerate(rvec):
            r_x = x_p-r_v[0]
            r_y = y_p-r_v[1]
            r_z = z_p-r_v[2]
            r_2 = r_x*r_x+r_y*r_y+r_z*r_z
            scpr = pvec[i, 0]*r_x+pvec[i, 1]*r_y+pvec[i, 2]*r_z
            inside = r_2 < 0.25
            # scale the field to B_inside
            lbx[0] += np.where(inside, pvec[i, 0],
                               (3.*scpr*r_x - r_2*pvec[i, 0])/pow(r_2, 2.5)/16)
            lbx[1] += np.where(inside, pvec[i, 1],
                               (3.*scpr*r_y - r_2*pvec[i, 1])/pow(r_2, 2.5)/16)
            lbx[2] += np.where(inside, pvec[i, 2],
                               (3.*scpr*r_z - r_2*pvec[i, 2])/pow(r_2, 2.5)/16)
        return lbx[0]+self.b_ext[0], lbx[1]+self.b_ext[1], lbx[2]+self.b_ext[2]

    def b_flux_2(self, x_p, y_p, z_p):
        """Calculate the squared field magnitude at position x, y, z."""
        lbx, lby, lbz = self.b_flux(x_p, y_p, z_p)
        return lbx*lbx + lby*lby + lbz*lbz

    def radial_flux(self, x_p, y_p, z_p):
        """Calculate the r-component of the field at position x, y, z."""
        lbx, lby, lbz = self.b_flux(x_p, y_p, z_p)
        return (lbx*x_p+lby*y_p+lbz*z_p)/np.sqrt(x_p*x_p+y_p*y_p+z_p*z_p)

    def relaxation(self):
        """Turn all dipoles, one after the other, into the local field."""
        for i in range(self.n_dip()):
            xyz = self.b_loc_fast(i)  # external field is included here
            xyz *= self.lengths[i]/np.sqrt(np.dot(xyz, xyz))  # l = lengths[i]
            self.p_vec[i] = xyz

    def ind_x(self, x_p, y_p, z_p, direction):
        """Calculate the field component along direction at x, y, z."""
        lbx, lby, lbz = self.b_flux(x_p, y_p, z_p)
        return lbx*direction[0]+lby*direction[1]+lbz*direction[2]

    def create_rnd(self):
        """Set orientations randomly; magnitudes are taken from lengths."""
        for i in range(self.n_dip()):
            self.p_vec[i] = (random.random()*2-1,
                             random.random()*2-1,
                             random.random()*2-1)
            self.p_vec[i] *= self.lengths[i]/np.sqrt(np.dot(self.p_vec[i],
                                                            self.p_vec[i]))

    def radius(self):
        """Get cluster radius as maximal distance from origin, +diam./2."""
        return np.sqrt(np.sum(self.r_vec**2, axis=1).max())+0.5

    def extension(self):
        """Get cluster extension along x, y, z +diameter."""
        return self.r_vec.max(axis=0) - self.r_vec.min(axis=0) + 1

    def m_tor(self):
        """Get the toroidal moment of the cluster, sum of r x p."""
        return np.sum(np.cross(self.r_vec, self.p_vec), axis=0)


def mindist(c_n):
    """Return the smallest distance between any two dipoles of c_n.

    c_n -- object offering an (n, 3) position array ``r_vec``
    Returns ``inf`` when there are fewer than two positions.  The former
    start value 1954 acted as a silent upper bound on the result.
    """
    pts = np.asarray(c_n.r_vec, dtype=float)
    if len(pts) < 2:
        return np.inf
    diff = pts[:, np.newaxis, :] - pts[np.newaxis, :, :]
    dist_sq = np.sum(diff**2, axis=2)
    # strict upper triangle: every pair once, no self distances
    upper = np.triu_indices(len(pts), k=1)
    return np.sqrt(dist_sq[upper].min())


def pm8(i):
    """Return a sign that flips with every block of eight: (-1)**(i//8)."""
    return (-1) ** (i // 8)


def pm4(i):
    """Return a sign that flips with every block of four: (-1)**(i//4)."""
    return (-1) ** (i // 4)


def pm2(i):
    """Return a sign that flips with every block of two: (-1)**(i//2)."""
    return (-1) ** (i // 2)


def pm1(i):
    """Return +1 for even i and -1 for odd i: (-1)**i."""
    return (-1) ** i


def perm_e(i):
    """Return the three cyclic (even) permutations of the triple i."""
    return [[i[(pos + shift) % 3] for pos in range(3)] for shift in range(3)]


def par(lst):
    """Get some parity of a list."""
    return sum(1 for (x, px) in enumerate(lst)
               for (y, py) in enumerate(lst)
               if x < y and px > py
               ) % 2 == 0


def perm_e_norpt(liste):
    """Return all permutations of liste that share its parity (used for 4d).

    Parity is decided by ``par``; every result is returned as a list.
    """
    wanted = par(liste)
    same_parity = []
    for candidate in permutations(liste):
        if par(candidate) == wanted:
            same_parity.append(list(candidate))
    return same_parity


class Tetrahedron(Cluster):
    """Four dipoles on the vertices of a tetrahedron (Platonic solid)."""

    def __init__(self):
        super().__init__(number_of_dipoles=4)
        self.set_tetrahedron()

    def set_tetrahedron(self):
        """Place the vertices, randomise the moments, refresh the tables."""
        corners = [[1, 1, 1]] + perm_e([1, -1, -1])
        self.r_vec[:] = np.array(corners)
        self.r_vec /= np.sqrt(8.)  # raw edge length is sqrt(8)
        self.create_rnd()
        self.refresh_tables()


class Octahedron(Cluster):
    """Six dipoles on the vertices of an octahedron (Platonic solid)."""

    def __init__(self):
        super().__init__(number_of_dipoles=6)
        self.set_octahedron()

    def set_octahedron(self):
        """Place the rotated vertices, randomise moments, refresh tables."""
        axes = ((1, 0, 0), (-1, 0, 0),
                (0, 1, 0), (0, -1, 0),
                (0, 0, -1), (0, 0, +1))
        for row, vertex in enumerate(axes):
            self.r_vec[row] = vertex
        self.r_vec /= np.sqrt(2.)
        # rotate the whole solid by fixed Euler angles about z, then y
        turn = R.from_euler('zy', [np.arcsin(np.sqrt(1/2.)),
                                   np.arctan(np.sqrt(2.))], degrees=False)
        self.r_vec = turn.apply(self.r_vec)

        self.create_rnd()
        self.refresh_tables()


class Cube(Cluster):
    """Hexahedron and larger cuboids on a simple cubic unit grid."""

    def __init__(self, tau_deg=0, edge=(2, 2, 2)):
        super().__init__(number_of_dipoles=edge[0]*edge[1]*edge[2])
        self.edge = list(edge)
        if tau_deg is None:
            self.tau = None
        else:
            self.tau = np.pi/180. * tau_deg
        self.set_cube(tau_deg)

    def set_cube(self, tau_deg):
        """Place the dipoles on a centred grid with the stored edge counts.

        For the 2x2x2 cube and a given tau_deg the moments follow
        ``set_ang``; in every other case they are random.
        """
        n_x, n_y, n_z = self.edge[0], self.edge[1], self.edge[2]
        # ndindex walks the grid with the last index running fastest
        for row, site in enumerate(np.ndindex(n_x, n_y, n_z)):
            self.r_vec[row] = site
        self.r_vec -= np.array([(n_x-1)/2, (n_y-1)/2, (n_z-1)/2])
        if tau_deg is None or self.edge != [2, 2, 2]:
            self.create_rnd()
        else:
            self.set_ang(tau_deg)
        self.refresh_tables()

    def set_ang(self, tau_deg):
        """Ground state: Schönke, Schneider, Rehberg, PRB 91, 020410 (2015).

        tau_deg -- angle in degrees; stored in radians as ``tau``.
        Cuboids other than 2x2x2 get random moments instead.
        """
        self.tau = np.pi/180. * tau_deg  # in radian
        if self.edge != [2, 2, 2]:
            self.create_rnd()
            return
        pos_x = self.r_vec[:, 0]
        pos_y = self.r_vec[:, 1]
        pos_z = self.r_vec[:, 2]
        self.p_vec[:, 0] = -np.sin(self.tau - 4./3.*np.pi) * pos_y*pos_z
        self.p_vec[:, 1] = -np.sin(self.tau - 2./3.*np.pi) * pos_x*pos_z
        self.p_vec[:, 2] = -np.sin(self.tau) * pos_x*pos_y
        self.p_vec *= np.sqrt(2./3.)*4


class Icosahedron(Cluster):
    """Twelve dipoles on the vertices of an icosahedron (Platonic solid)."""

    def __init__(self):
        super().__init__(number_of_dipoles=12)
        self.set_icosahedron()

    def set_icosahedron(self):
        """Place the vertices, randomise the moments, refresh the tables."""
        golden = (1+np.sqrt(5.))/2  # golden ratio
        vertices = []
        for i in range(4):  # four sign combinations, three rotations each
            vertices.extend(perm_e([0, pm2(i), golden*pm1(i)]))
        self.r_vec[:] = vertices
        self.r_vec /= 2.
        self.create_rnd()
        self.refresh_tables()


class Dodecahedron(Cluster):
    """Twenty dipoles on the vertices of a dodecahedron (Platonic solid)."""

    def __init__(self):
        super().__init__(number_of_dipoles=20)
        self.set_dodecahedron()

    def set_dodecahedron(self):
        """Place the vertices, randomise the moments, refresh the tables."""
        golden = (1+np.sqrt(5.))/2  # golden ratio
        # eight cube corners (+-1, +-1, +-1)
        self.r_vec[:8] = [[pm4(i), pm2(i), pm1(i)] for i in range(8)]
        # twelve vertices from cyclic rotations of (0, +-1/phi, +-phi)
        for i in range(4):
            first = 8 + 3*i
            self.r_vec[first:first+3] = perm_e(
                [0., pm2(i)/golden, pm1(i)*golden])
        self.r_vec *= golden/2
        self.create_rnd()
        self.refresh_tables()


class Tube(Cluster):
    """Stack of rings of dipoles, the rings being stacked along z."""

    def __init__(self, tri=False, n_rz=(6, 5), dst=0, mom=1):
        ring_len, n_rings = n_rz[0], n_rz[1]
        super().__init__(number_of_dipoles=ring_len*n_rings)
        self.l_r = ring_len  # dipoles per ring
        self.n_s = n_rings  # rings in the stack
        self.m_v = [0, n_rings-1]  # rings whose moment adjust_m changes
        self.tri = tri  # True: every second ring is turned by half a step
        self.v_dist = 1  # ring spacing, recomputed in set_stack for tri
        self.dst = dst  # distance between two half stacks
        self.mom = mom  # just for convenience
        self.set_stack()

    def set_ring(self, sta, delta, radius, offset):
        """Place the dipoles of ring number sta.

        delta  -- angular step between neighbours in the ring
        radius -- ring radius
        offset -- extra shift of the ring along z
        """
        half_turn = self.tri and sta % 2 == 1
        for r_p in range(self.l_r):  # ring position
            ang = r_p*delta
            if half_turn:
                ang += delta/2
            self.r_vec[r_p + sta*self.l_r] = (
                np.cos(ang)*radius,
                np.sin(ang)*radius,
                (sta-(self.n_s-1)/2)*self.v_dist + offset)

    def set_stack(self):
        """Place all rings, then set random moments and refresh tables."""
        radius = 0.5/np.sin(np.pi/self.l_r)  # radius of a ring of d=1 spheres
        delta = 2*np.pi/self.l_r
        if self.tri:
            # in-plane shift between neighbours of adjacent rings; the ring
            # spacing follows from a unit distance between them
            dxh = (np.cos(0)-np.cos(delta/2))*radius
            dyh = (np.sin(0)-np.sin(delta/2))*radius
            self.v_dist = np.sqrt(1-dxh*dxh-dyh*dyh)  # for set_ring
        for sta in range(self.n_s):  # stack number
            self.set_ring(sta, delta, radius, 0)
        self.adjust_dst(self.dst)  # somewhat redundant, but more flexible
        self.create_rnd()
        self.adjust_m(self.mom)
        self.refresh_tables()

    def adjust_m(self, alpha):
        """Set the moment magnitude of the rings listed in m_v to alpha."""
        self.mom = alpha
        for ring in self.m_v:
            first = ring*self.l_r
            for i in range(first, first+self.l_r):
                self.set_length(i, alpha)  # keep dir, change l

    def adjust_dst(self, distance):
        """Shift lower and upper half stack apart by distance in total."""
        self.dst = distance  # The distance between 2 half stacks
        radius = 0.5/np.sin(np.pi/self.l_r)  # radius of a ring of d=1 spheres
        delta = 2*np.pi/self.l_r
        lower = range(self.n_s//2)  # ok for even
        upper = range((self.n_s+1)//2, self.n_s)  # odd:middle stays fixed
        for ring in lower:
            self.set_ring(ring, delta, radius, -self.dst/2)
        for ring in upper:
            self.set_ring(ring, delta, radius, self.dst/2)
        self.refresh_tables()

    def set_radial(self):
        """Set tangential moments whose sense alternates from ring to ring."""
        radius = 0.5/np.sin(np.pi/self.l_r)  # radius of a ring of d=1 spheres
        for sta in range(self.n_s):  # stack number
            sign = -1 if sta % 2 else 1
            first = sta*self.l_r
            for dpn in range(first, first+self.l_r):  # dipol number
                self.p_vec[dpn] = (-sign*self.r_vec[dpn, 1]/radius,
                                   sign*self.r_vec[dpn, 0]/radius, 0)

    def pot_s(self, x_p, y_p, z_p, i):
        """Calculate the potential of dipole i at position x, y, z."""
        r_x = x_p-self.r_vec[i, 0]
        r_y = y_p-self.r_vec[i, 1]
        r_z = z_p-self.r_vec[i, 2]
        r_2 = r_x**2 + r_y**2 + r_z**2
        nen = np.clip(r_2**1.5, 1.e-30, None)  # avoid division by zero
        scpr = (self.p_vec[i, 0]*r_x + self.p_vec[i, 1]*r_y
                + self.p_vec[i, 2]*r_z)
        return scpr/nen

    def potential(self, x_p, y_p, z_p, excl=None):
        """Calculate potential at positions x, y, z. Special tube function.

        The dipoles of a ring are added in pairs ``r_p`` and
        ``r_p + l_r//2``; for an odd ring length the last dipole is added
        on its own.  ``excl`` is accepted but not used here.
        """
        half = self.l_r//2
        odd = bool(self.l_r % 2)
        pot = 0.
        for sta in range(self.n_s):  # stack number
            first = sta*self.l_r
            pot_ring = 0
            for dpn in range(first, first+half):
                pot_ring += (self.pot_s(x_p, y_p, z_p, dpn) +
                             self.pot_s(x_p, y_p, z_p, dpn+half))
            if odd:
                pot_ring += self.pot_s(x_p, y_p, z_p, first+self.l_r-1)
            pot += pot_ring
        return pot+x_p*self.b_ext[0]+y_p*self.b_ext[1]+z_p*self.b_ext[2]


class Rings(Cluster):
    """Concentric rings of dipoles in the plane z = 0.

    l_r -- number of dipoles in each ring
    rad -- ring radii in arbitrary units; stored in ``rad`` rescaled so
           that the first entry equals 0.5/sin(pi/l_r[0])
    """

    def __init__(self, l_r=(24, 30, 36), rad=(49.6, 63.6, 77.6)):
        super().__init__(number_of_dipoles=np.sum(np.array(l_r)))
        self.l_r = l_r  # length of each ring
        # float dtype is required: the in-place rescaling below raises a
        # casting error when rad is given as integers
        self.rad = np.array(rad, dtype=float)
        self.rad *= 0.5/np.sin(np.pi/l_r[0]) / rad[0]
        self.set_stack()

    def set_ring(self, num, dpn, delta, radius):
        """Place num dipoles on a ring, starting at dipole number dpn.

        delta  -- angular step between neighbours
        radius -- ring radius
        """
        for r_p in range(num):  # ring position
            ang = r_p*delta
            self.r_vec[r_p+dpn] = (np.cos(ang)*radius,
                                   np.sin(ang)*radius,
                                   0)

    def set_stack(self):
        """Place all rings, then set random moments and refresh tables."""
        # NOTE(review): the positions use the touching-sphere radius of each
        # ring and not self.rad, whereas set_radial divides by self.rad --
        # confirm which radius is intended.
        dpn = 0
        for num in self.l_r:  # ring by ring
            radius = 0.5/np.sin(np.pi/num)  # radius of a ring of d=1 spheres
            delta = 2*np.pi/num
            self.set_ring(num, dpn, delta, radius)
            dpn += num
        self.create_rnd()
        self.refresh_tables()

    def set_radial(self):
        """Set tangential moments: position turned by 90 deg, over rad."""
        occ = 0
        for i, radius in enumerate(self.rad):  # ring number
            for r_p in range(self.l_r[i]):  # ring position
                dpn = r_p + occ  # dipol number
                self.p_vec[dpn] = (-self.r_vec[dpn, 1]/radius,
                                   self.r_vec[dpn, 0]/radius, 0)
            occ += self.l_r[i]

    def pot_s(self, x_p, y_p, z_p, i):
        """Calculate the potential of dipole i at position x, y, z."""
        r_x = x_p-self.r_vec[i, 0]
        r_y = y_p-self.r_vec[i, 1]
        r_z = z_p-self.r_vec[i, 2]
        r_2 = pow(r_x, 2)+pow(r_y, 2)+pow(r_z, 2)
        nen = np.clip(pow(r_2, 1.5), 1.e-30, None)  # avoid division by zero
        scpr = self.p_vec[i, 0]*r_x + \
            self.p_vec[i, 1]*r_y + self.p_vec[i, 2]*r_z
        return scpr/nen


class HexagonFilled(Cluster):
    """Stack of filled hexagons: per layer one centre plus a ring of six."""

    def __init__(self, n_s):
        super().__init__(number_of_dipoles=7*n_s)
        self.n_s = n_s  # number of layers in the stack
        self.relax_start = 0  # first dipole touched by relaxation
        self.hex_2d = False  # True: relaxation drops the z-component

        self.set_stack()

    def set_stack(self):
        """Place all layers, then set random moments and refresh tables."""
        radius = 0.5/np.sin(np.pi/6)  # radius of a 6-ring of d=1 spheres
        for sta in range(self.n_s):  # layer number
            height = sta-(self.n_s-1)/2
            centre = sta*7
            self.r_vec[centre] = (0, 0, height)
            for r_p in range(1, 7):  # ring position
                ang = r_p*2*np.pi/6
                self.r_vec[centre + r_p] = (np.cos(ang)*radius,
                                            np.sin(ang)*radius,
                                            height)
        self.create_rnd()
        self.refresh_tables()

    def adjust_m(self, alpha):
        """Set the moment magnitude of every layer's centre dipole to alpha."""
        for layer in range(self.n_s):
            self.set_length(layer*7, alpha)  # keep dir, change l

    def adjust_center(self, alpha):
        """Set the moment magnitude of dipole 0 to alpha."""
        moment = self.p_vec[0]
        self.p_vec[0] *= alpha / np.sqrt(np.dot(moment, moment))
        self.lengths[0] = alpha

    def set_center_height(self, height):
        """Move dipole 0 to the given z-position and refresh the tables."""
        self.r_vec[0, 2] = height
        self.refresh_tables()

    def relaxation(self):
        """Turn dipoles into the local B-field, overwrites general func.

        Only dipoles from ``relax_start`` on are turned; each keeps its
        current magnitude.  With ``hex_2d`` the z-component of the field
        is dropped first.
        """
        for i in range(self.relax_start, self.n_dip()):
            field = self.b_loc_fast(i)  # external field is included here
            if self.hex_2d:
                field[2] = 0.
            field *= self.length(i)/np.sqrt(np.dot(field, field))
            self.p_vec[i] = field


class HCP(Cluster):
    """Hexagonal close packed cluster of 13 dipoles."""

    def __init__(self):
        super().__init__(number_of_dipoles=13)
        self.set_hcp()

    def set_hcp(self):
        """Place the lattice sites lying inside a sphere of radius 1.1.

        Two interleaved lattices are scanned: the base lattice and a copy
        shifted by a fixed offset.  Moments are then randomised and the
        tables refreshed.
        """
        radius = 1.1  # radius of the sphere around the center
        reach_sq = radius*radius
        shift = ((1+0.5)/3., (np.sqrt(3.)/2.)/3., np.sqrt(2./3.))
        filled = 0
        for i in range(-1, 2):
            for j in range(-1, 2):
                for k in range(-1, 2):
                    base = (i+j*0.5, j*np.sqrt(3.)/2., k*np.sqrt(8./3.))
                    moved = (base[0] + shift[0],
                             base[1] + shift[1],
                             base[2] + shift[2])
                    for site in (base, moved):
                        if np.dot(site, site) < reach_sq:
                            self.r_vec[filled] = site
                            filled += 1
        self.create_rnd()
        self.refresh_tables()


class FCC(Cluster):
    """Face centered cubic cluster cut out of an i x j x k lattice block."""

    def __init__(self, r_b=(-0.1, 2.1), ijk=(5, 5, 1)):
        # both are needed by the counting pass before the arrays exist
        self.r_b = list(r_b)
        self.ijk = list(ijk)
        n_sites = self.set_fcc(locate=False)
        super().__init__(number_of_dipoles=n_sites)
        self.set_fcc(locate=True)

    def pos(self, i, j, k):
        """Return the coordinates of lattice site (i, j, k)."""
        x_c = i+j*0.5+k*0.5
        y_c = j*np.sqrt(3.)/2.+k*np.sqrt(1./12)
        z_c = k*np.sqrt(2./3.)
        return (x_c, y_c, z_c)

    def set_fcc(self, locate=True):
        """Count the accepted lattice sites; store them if locate is True.

        A site is accepted when its squared distance from the origin is
        below ``r_b[1]`` squared and above ``r_b[0]``.  With locate the
        moments are randomised and the tables refreshed as well.
        Returns the number of accepted sites.
        """
        outer_sq = self.r_b[1]*self.r_b[1]
        n_i, n_j, n_k = self.ijk[0], self.ijk[1], self.ijk[2]
        found = 0
        for i in range(n_i):
            for j in range(n_j):
                for k in range(n_k):
                    site = self.pos(i-(n_i-1)/2,
                                    j-(n_j-1)/2,
                                    k-(n_k-1)/2)
                    dist_sq = np.dot(site, site)
                    # NOTE(review): the lower bound r_b[0] is compared
                    # unsquared -- confirm that this is intended.
                    if not (dist_sq < outer_sq and dist_sq > self.r_b[0]):
                        continue
                    if locate:
                        self.r_vec[found] = site
                    found += 1
        if locate:
            self.create_rnd()
            self.refresh_tables()
        return found


class Rco(Cluster):
    """Rhombicuboctahedron with 24 dipoles, Herrnhuter Stern."""

    def __init__(self):
        super().__init__(number_of_dipoles=24)
        self.set_rco()

    def set_rco(self):
        """Place the vertices, randomise the moments, refresh the tables."""
        long_side = 1. + np.sqrt(2.)  # from wikipedia
        vertices = []
        for i in range(8):  # eight sign combinations, three rotations each
            vertices.extend(perm_e([pm4(i), pm2(i), pm1(i)*long_side]))
        self.r_vec[:] = vertices
        self.r_vec /= 2
        self.create_rnd()
        self.refresh_tables()


class C60(Cluster):
    """C_60 structure (Buckyball) with 60 dipoles."""

    def __init__(self):
        super().__init__(number_of_dipoles=60)
        self.set_c60()

    def set_c60(self):
        """Place the vertices of the truncated icosahedron.

        Each of the 20 index values contributes the three cyclic rotations
        of one coordinate triple.  Moments are then randomised and the
        tables refreshed.
        """
        golden = (1. + np.sqrt(5.)) / 2  # golden mean
        for i in range(20):
            if i < 4:
                triple = [0, pm2(i), 3*golden*pm1(i)]
            elif i < 12:
                triple = [pm4(i), (2+golden)*pm2(i), 2*golden*pm1(i)]
            else:
                triple = [golden*pm4(i), 2*pm2(i), (2*golden+1)*pm1(i)]
            self.r_vec[3*i:3*i+3] = perm_e(triple)
        self.r_vec /= 2
        self.create_rnd()
        self.refresh_tables()


def create_tetraplex(prm):
    """Return the (120, 4) vertex coordinates of a 4-dim tetraplex.

    prm -- three magnitudes; rows 8-23 use prm[1] only, rows 24-119 use
           all three together with a zero.
    """
    pos = np.zeros((120, 4), dtype='double')  # positions
    # rows 0-7: +-1 on a single axis, the axis moving from last to first
    for i in range(8):
        pos[i, 3 - i//2] = pm1(i)
    # rows 8-23: all sign combinations of (prm[1], prm[1], prm[1], prm[1])
    for i in range(8, 24):
        pos[i] = [prm[1]*sign for sign in (pm8(i), pm4(i), pm2(i), pm1(i))]
    # rows 24-119: 12 same-parity permutations for each of 8 sign choices
    for i in range(8):
        first = 24 + i*12
        pos[first:first+12] = perm_e_norpt(
            [prm[0]*pm4(i), prm[1]*pm2(i), prm[2]*pm1(i), 0])
    return pos


def largest_cluster(rvec):
    """Return the rows of rvec forming the largest nearest-neighbour group.

    Points are linked when they are within the smallest linkage distance
    (plus 0.0001); the most populated group is returned, the lowest label
    winning a tie.
    """
    tree = sch.linkage(rvec)
    threshold = tree[:, 2].min() + 0.0001
    labels = sch.fcluster(tree, threshold, criterion='distance')
    names, sizes = np.unique(labels, return_counts=True)
    biggest = names[np.argmax(sizes)]  # first maximum = lowest label
    return rvec[labels == biggest]


class Tetraplex(Cluster):
    """Cluster built from a 3d-subset of the 4-dimensional tetraplex.

    without -- index of the 4d coordinate that is dropped
    cut     -- 0 keeps all vertices; otherwise, if smaller than the number
               of distinct values of the dropped coordinate, only the
               vertices whose dropped coordinate equals that value are kept
    """

    def __init__(self, without=0, cut=0):
        golden = (1. + np.sqrt(5.)) / 2  # golden mean
        vertices = create_tetraplex([golden/2, 0.5, 0.5/golden])
        dropped = vertices[:, without]
        levels = np.unique(dropped)  # distinct values, ascending
        kept_axes = [axis for axis in (0, 1, 2, 3) if axis != without]
        rvec = vertices[:, kept_axes]  # 3d-subset of the coordinates
        if cut != 0 and cut < len(levels):
            rvec = rvec[dropped == levels[cut]]  # taking a cut
        rvec = largest_cluster(np.unique(rvec, axis=0))
        super().__init__(number_of_dipoles=len(rvec))
        self.cut = cut
        self.r_vec = rvec
        self.set_tetraplex()

    def set_tetraplex(self):
        """Scale to unit minimal distance, randomise moments, refresh."""
        self.r_vec /= mindist(self)
        self.create_rnd()
        self.refresh_tables()


if __name__ == '__main__':
    from matplotlib import pyplot as plt

    def plot_contact(axi, c_n):
        """Draw a grey line between every pair of touching vertices.

        Two vertices count as touching when their squared distance is
        below 1.02.
        """
        sites = c_n.r_vec
        for i, first in enumerate(sites):
            for second in sites[i+1:]:
                if np.sum((first-second)**2) < 1.02:
                    pair = np.array([first, second])
                    axi.plot(pair[:, 0], pair[:, 1], pair[:, 2],
                             '-', c='grey')

    # small demonstration: three concentric rings with tangential moments
    clu = Rings()
    clu.set_radial()
    print('number of vertices:', len(clu.r_vec))
    fig = plt.figure(figsize=(10, 6))
    axis = plt.axes([0, 0., 1, 1], projection='3d')
    x_s, y_s, z_s = clu.r_vec[:, 0], clu.r_vec[:, 1], clu.r_vec[:, 2]
    axis.scatter(x_s, y_s, z_s, s=50)
    plot_contact(axis, clu)
    plt.show()
