# this SciPy code was used to produce Figure 1B-G in the paper entitled "A novel catapult mechanism for male spiders to avoid sexual cannibalism" that was published in Current Biology by Zhang et al.

from scipy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import math
from io import StringIO   # StringIO behaves like a file object


# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import


# Load the digitised trajectory for video 13: one row per frame. Column 0 is
# overwritten below with the frame time; columns 1-6 hold the (x, y) pairs of
# the three tracked points (cols 1-2, 3-4, 5-6), as plotted in panel (a).
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
smat = np.loadtxt(
    r"C:\Users\Yangjie\Documents\research\Shichang\Shichangcode\viedo13data.txt",
    dtype='float',
    comments='#',
)

# Number of frames; both names are used further down.
tN = ts = len(smat)

T = 1 / 1500.    # frame interval in seconds (later shown as t*1000 in ms)
L = 0.0218       # coordinate scale for video 13 - presumably digitiser units -> cm
                 # (the plot axes are labelled cm); TODO confirm

# Frame times in seconds; `t` aliases the same array.
tsl = T * np.arange(ts)
t = tsl

smat[:, 0] = tsl      # replace column 0 with the time axis
smat[:, 1:7] *= L     # scale all six coordinate columns in place

# Fit every coordinate column with one degree-`order` polynomial in time.
# With a 2-D y, np.polyfit fits each column independently; full=True also
# returns residuals, rank, singular values and rcond.
order = 5
fitoc3 = np.polyfit(t, smat[:, 1:7], order, full=True)
coef3 = fitoc3[0]        # shape (order + 1, 6), highest power first
coef3r = coef3[::-1]     # same coefficients, lowest power first

# Evaluate each fitted polynomial and its first two time derivatives at the
# frame times. Row i corresponds to data column i + 1 (x1, y1, x2, y2, x3, y3).
pstn = np.ones((6, tN))      # fitted positions
vlct = np.ones((6, tN))      # first time derivative (velocity)
aclr = np.ones((6, tN))      # second time derivative (acceleration)
spradii = np.ones((3, tN))   # filled by the spin computation below, one row per point
sprate = np.ones((3, tN))    # filled by the spin computation below, one row per point

for i in range(6):
    # Coefficients are highest-power-first (np.polyfit convention), which is
    # exactly what np.polyval / np.polyder expect - no flipping needed.
    # (Named `p_coef` rather than `pi` so the constant `pi` that
    # `from scipy import *` may provide is not clobbered at module level.)
    p_coef = coef3[:, i]
    pstn[i, :] = np.polyval(p_coef, t)
    vlct[i, :] = np.polyval(np.polyder(p_coef), t)
    aclr[i, :] = np.polyval(np.polyder(p_coef, 2), t)
    
# Spin radius and spin rate of each tracked point k, from its fitted velocity
# (rows 2k, 2k+1 of vlct) and acceleration (rows 2k, 2k+1 of aclr):
#   radius of curvature  r = |v|^3 / |v_x a_y - v_y a_x|
#   spin rate            s = |v_x a_y - v_y a_x| / |v|^2 / (2*pi)   [rev/s]
# The strided slices below handle all three points at once; results are
# written in place into the existing spradii / sprate arrays.
_vx, _vy = vlct[0::2], vlct[1::2]
_ax, _ay = aclr[0::2], aclr[1::2]
_cross = abs(_vx * _ay - _vy * _ax)
_speed2 = _vx ** 2 + _vy ** 2
spradii[:] = _speed2 ** 1.5 / _cross
sprate[:] = _cross / _speed2 / (2 * np.pi)
del _vx, _vy, _ax, _ay, _cross, _speed2

# Square figure (figaspect(4./4.) gives aspect ratio 1) holding a 3 x 2 grid of
# panels (a)-(f).
fig = plt.figure(figsize=plt.figaspect(4./4.))

# Panel (a): 3-D view of the three tracked points, x and y against time.
ax = fig.add_subplot(3, 2, 1, projection='3d')

# Seconds -> milliseconds. This rebinds `tsl` for every later panel; `t` and
# smat[:, 0] still hold the times in seconds.
tsl = tsl * 1000  # ms

# `Axes3D.w_xaxis` / `w_yaxis` / `w_zaxis` were deprecated aliases of
# `xaxis` / `yaxis` / `zaxis` (Matplotlib 3.1) and were removed in 3.8, where
# the old spelling raises AttributeError.
# NOTE(review): mplot3d may reapply its own grid style at draw time, so these
# widths might not be visible - confirm against the Matplotlib version in use.
ax.xaxis.gridlines.set_lw(1.0)
ax.yaxis.gridlines.set_lw(1.0)
ax.zaxis.gridlines.set_lw(1.0)

# Scaled raw data as square markers ...
ax.plot(smat[:, 1], tsl, smat[:, 2], 'ks', linewidth=2)
ax.plot(smat[:, 3], tsl, smat[:, 4], 'rs', linewidth=2)
ax.plot(smat[:, 5], tsl, smat[:, 6], 'gs', linewidth=2)

# ... and the fitted polynomial trajectories as solid lines.
ax.plot(pstn[0], tsl, pstn[1], 'k', label='1', linewidth=2)
ax.plot(pstn[2], tsl, pstn[3], 'r', label='2', linewidth=2)
ax.plot(pstn[4], tsl, pstn[5], 'g', label='3', linewidth=2)

# The leading newline together with linespacing presumably moves each label
# clear of the 3-D tick labels.
ax.set_xlabel('\n$x$[cm]', fontsize=15, linespacing=3.2)
ax.set_ylabel('\n$t$[ms]', fontsize=15, linespacing=3.2)
ax.set_zlabel('\n$y$[cm]', fontsize=15, linespacing=3.2)
ax.text2D(0.05, 0.85, "(a)", transform=ax.transAxes, fontsize=20)
plt.legend(loc='center left')


# Panel (b): fitted velocity vector (v_x, v_y) of each point over time.
fig.add_subplot(3, 2, 2)
_colors = ('k', 'r', 'g')

# Square markers on every frame except the first ...
for _k, _c in enumerate(_colors):
    plt.plot(vlct[2 * _k, 1:], vlct[2 * _k + 1, 1:], _c + 's', linewidth=2)

# ... because the first frame gets a star; only one star carries the legend entry.
plt.scatter(vlct[0, 0], vlct[1, 0], s=80, c="k", marker="*", label="initial points")
plt.scatter(vlct[2, 0], vlct[3, 0], s=80, c="r", marker="*")
plt.scatter(vlct[4, 0], vlct[5, 0], s=80, c="g", marker="*")

# Continuous curves through all frames.
for _k, _c in enumerate(_colors):
    plt.plot(vlct[2 * _k], vlct[2 * _k + 1], _c, linewidth=2)

# Arrow along the last segment of point 1's curve.
# NOTE(review): head sizes are in data units, so on this ~50 cm/s wide axis the
# head is tiny - confirm this is intended.
plt.arrow(vlct[0, -2], vlct[1, -2],
          vlct[0, -1] - vlct[0, -2], vlct[1, -1] - vlct[1, -2],
          head_width=0.05, head_length=0.1, fc='k', ec='k')

plt.xlabel('$v_x$[cm/s]', fontsize=15)
plt.ylabel('$v_y$[cm/s]', fontsize=15)
plt.legend(loc='lower right')
plt.xlim([-10, 40])

plt.text(-8, 83, r'(b)', fontsize=20)
plt.grid()
del _colors, _k, _c


# Panel (c): fitted acceleration vector (a_x, a_y) of each point over time.
fig.add_subplot(3, 2, 3)
# cm/s^2 -> m/s^2; rebinds aclr for the panels that follow.
aclr = aclr / 100
_colors = ('k', 'r', 'g')

# Square markers on every frame, then continuous curves through them.
for _k, _c in enumerate(_colors):
    plt.plot(aclr[2 * _k], aclr[2 * _k + 1], _c + 's', linewidth=2)
for _k, _c in enumerate(_colors):
    plt.plot(aclr[2 * _k], aclr[2 * _k + 1], _c, linewidth=2)

# Star on the first frame of each point; one legend entry for all three.
plt.scatter(aclr[0, 0], aclr[1, 0], s=80, c="k", marker="*", label="initial points")
plt.scatter(aclr[2, 0], aclr[3, 0], s=80, c="r", marker="*")
plt.scatter(aclr[4, 0], aclr[5, 0], s=80, c="g", marker="*")

plt.xlabel('$a_x$[m/s$^2$]', fontsize=15)
plt.ylabel('$a_y$[m/s$^2$]', fontsize=15)
plt.legend(loc='lower right')

plt.text(-360, 380, r'(c)', fontsize=20)
plt.grid()
del _colors, _k, _c


# Panel (d): speed |v| of each tracked point versus time (ms).
fig.add_subplot(3, 2, 4)

# The bare `sqrt` previously came from `from scipy import *`. SciPy deprecated
# its re-exported NumPy functions in 1.4 and has since removed them, so the
# name is undefined on current SciPy. Call NumPy directly; the argument is a
# sum of squares, so the result is real either way.
plt.plot(tsl, np.sqrt(vlct[0]**2 + vlct[1]**2), 'k', linewidth=2)
plt.plot(tsl, np.sqrt(vlct[2]**2 + vlct[3]**2), 'r', linewidth=2)
plt.plot(tsl, np.sqrt(vlct[4]**2 + vlct[5]**2), 'g', linewidth=2)

plt.xlabel('$t$[ms]', fontsize=15)
plt.ylabel('$v$[cm/s]', fontsize=15)

plt.text(0.5, 75, r'(d)', fontsize=20)
plt.grid()


# Panel (e): acceleration magnitude |a| (m/s^2) of each tracked point versus time.
fig.add_subplot(3, 2, 5)

# np.sqrt instead of the bare `sqrt` that `from scipy import *` used to
# provide; SciPy no longer re-exports NumPy functions.
plt.plot(tsl, np.sqrt(aclr[0]**2 + aclr[1]**2), 'k', linewidth=2)
plt.plot(tsl, np.sqrt(aclr[2]**2 + aclr[3]**2), 'r', linewidth=2)
plt.plot(tsl, np.sqrt(aclr[4]**2 + aclr[5]**2), 'g', linewidth=2)
plt.xlabel('$t$[ms]', fontsize=15)
plt.ylabel('$a$[m/s$^2$]', fontsize=15)
plt.grid()
plt.text(0.5, 500, r'(e)', fontsize=20)
# NOTE: `spradii` (radius of curvature) is computed earlier in the script but
# not plotted; an earlier version of this panel drew it against `tsl`.


# Panel (f): spin rate s(t) in revolutions per second for each tracked point,
# then lay out the whole figure and display it.
fig.add_subplot(3, 2, 6)
for _k, (_c, _lbl) in enumerate(zip('krg', '123')):
    plt.plot(tsl, sprate[_k, :], _c, label=_lbl, linewidth=2)
plt.xlabel('$t$[ms]', fontsize=15)
plt.ylabel('$s(t)$[rev/s]', fontsize=15)
plt.ylim([0, 600])
plt.text(0.5, 500, r'(f)', fontsize=20)
plt.grid()
del _k, _c, _lbl

plt.tight_layout(pad=0.15, w_pad=0.05, h_pad=0.0)
plt.show()

##########################

#print("max spin rate for node 1",sprate[0,:].argmax(), sprate[0,:].max())
#print("max spin rate for node 2",sprate[1,:].argmax(), sprate[1,:].max())
#print("max spin rate for node 3",sprate[2,:].argmax(), sprate[2,:].max())


#print ("projection angle for node 1: ", np.arctan( vlct[1,0]/vlct[0,0])*180/np.pi)
#print ("projection angle for node 2: ", np.arctan( vlct[3,0]/vlct[2,0])*180/np.pi)
#print ("projection angle for node 3: ", np.arctan( vlct[5,0]/vlct[4,0])*180/np.pi)

#projection angle:
#np.arctan( vlct[1,0]/vlct[0,0])*180/np.pi

#np.arctan( vlct[3,0]/vlct[2,0])*180/np.pi

#np.arctan( vlct[5,0]/vlct[4,0])*180/np.pi
#=63.7971 deg for video 13.

#sqv= sqrt(vlct[4]**2+vlct[5]**2);
#aqv= sqrt(aclr[4]**2+aclr[5]**2); 
#print ("max velocity: ", sqv.max())
#print ("max acceleration: ", aqv.max())