%% prepare workspace
% Put the helper folder on the MATLAB path. NOTE(review): presumably 'code'
% provides cutsqrt, invcutsqrt and aaa, which are used below -- confirm.
addpath('code');

%% select data points
% A sampling scheme is used that places more support points in the vicinity
% of the branch cut. The points are chosen in the canonical (transformed)
% domain: k1 runs from cutsqrt(0.9*c1,c1) to 0.02i, k2 from 0.02 to
% cutsqrt(1.1*c1,c1), 25 points each.
% c1 : constant passed to cutsqrt/invcutsqrt throughout the script
% N  : number of grid points per axis used by the surface plots below
c1 = 1/0.9; N = 601;
k1 = linspace(cutsqrt(0.9*c1,c1),2e-2i,25);
k2 = linspace(2e-2,cutsqrt(1.1*c1,c1),25);
% The data file is required. load returns a struct; its field 'f' replaces it.
% Column 1 holds the samples at [k1 k2]. NOTE(review): assumes exactly 50
% rows in that order -- confirm against the code that produced the file.
f = load('data/w_200.mat'); f = f.f;
% AAA rational approximation of the samples. R is called as a function of
% the canonical variable later on. zj are the support points chosen by AAA
% (presumably a subset of [k1 k2]; they are matched with ismember below).
[R,~,~,~,zj] = aaa(f(:,1),[k1 k2],'tol',1e-6);

%% plot approximation in canonical domain
% In the canonical (i.e., transformed) domain the function is
% single-valued and therefore well suited for rational approximation.

% NOTE(review): fs, text_args, k1m and k1M appear unused in this file; they
% are kept so the workspace stays unchanged.
fs = 10; text_args = {'Interpreter','latex','Fontsize',fs};
k1m = abs(cutsqrt(0.9*c1,c1));
k1M = abs(cutsqrt(1.1*c1,c1));
% Boundary of the first quadrant: down the imaginary axis from .75i to 0,
% then along the real axis from 0 to .75 (101 points per leg).
yh1 = [linspace(.75,0,101) zeros(1,101)];
xh1 = [zeros(1,101) linspace(0,.75,101)];
% Edge/lighting options for the 3-D surfaces.
args = {'EdgeColor','none','FaceLighting','gouraud','DiffuseStrength',...
    0.8,'SpecularStrength',0.5};
f1 = figure(1); clf(1); ax1 = axes(f1); ax1.NextPlot = 'add';
f1.Units = 'centimeters'; f1.Position(3:4) = [15 15]; ax1.Box = 'on';
view(ax1,3); ax1.XLim = [-0.5 1]; ax1.Visible = 'off';
ax1.ZLim = [0 1]; ax1.YLim = [-0.5 1]; ax1.DataAspectRatio = [1 1 2];
ax1.TickLabelInterpreter = 'latex'; ax1.Projection = 'perspective';
% |R|^2 on the first quadrant [0,.75]x[0,.75] (N x N grid), coloured by
% log|R|^2. F and C are also used by the physical-domain cell. They are
% computed once here; R is not re-evaluated on the same grid.
[X,Y] = meshgrid(linspace(0,0.75,N),linspace(0,.75,N));
F = R(X+1i*Y); F = reshape(F,size(X));
F = F.*conj(F); C = log(F);
ax1.CLim = [-3 3]; ax1.Colormap = jet(512);
s1 = surf(ax1,X,Y,F,C,args{:});
% log|R|^2 as a flat colour map on z = 0 over [-.5,.75]^2. The half-plane
% below the branch cut x + y = 0 is drawn at half opacity.
[Xx,Yy] = meshgrid(linspace(-0.5,.75,N),linspace(-0.5,.75,N));
Ff = R(Xx+1i*Yy); Ff = reshape(Ff,size(Xx));
Ff = Ff.*conj(Ff); Cc = log(Ff);
s2 = pcolor(ax1,Xx,Yy,Cc); s2.EdgeColor = 'none'; s2.SpecularStrength = 0;
s2.AlphaDataMapping = 'none'; s2.FaceAlpha = 'flat';
s2.AlphaData = ones(size(s2.CData)); s2.AlphaData(Xx+Yy<0) = 0.5;
% The light and the boundary curve are attached to ax1 explicitly rather
% than to whatever axes happen to be current.
l = light('Parent',ax1); l.Style = 'local'; l.Position = [1 0.02 2];
plot3(ax1,xh1,yh1,abs(R(xh1+1i*yh1)).^2,'k','LineWidth',2)
ax1.Position = [0.13 0.12 0.68 0.86]; ax1.View = [-25 20];
% Coordinate grid on z = 0 (blue: Im = const, red: Re = const). For the
% lines through the origin, the positive half-axes are drawn thicker.
kh = linspace(-.5,.75,6);
for it = 1:length(kh)
    l1 = plot3(ax1,[-0.5 .75],[kh(it) kh(it)],[0 0],'b');
    l2 = plot3(ax1,[kh(it) kh(it)],[-0.5 .75],[0 0],'r');
    if kh(it)==0
        plot3(ax1,[0 .75],[kh(it) kh(it)],[0 0],'b','LineWidth',2);
        plot3(ax1,[kh(it) kh(it)],[0 .75],[0 0],'r','LineWidth',2);
    end
end
% Branch cut x + y = 0 in the canonical domain.
plot3(ax1,[-0.5 0.5],[0.5 -0.5],[0 0],'k','LineWidth',2)
ax1.Position([1 3]) = [0.01 .98];

%% plot approximation in physical domain
% When mapping back to the physical domain, a branch cut has to be
% introduced. Here it is chosen as the image of the diagonal line
% x + y = 0 through the origin of the canonical domain.

% Edge/lighting options for the 3-D surface. They repeat the values of the
% canonical-domain cell so this cell can be read on its own.
args = {'EdgeColor','none','FaceLighting','gouraud','DiffuseStrength',...
    0.8,'SpecularStrength',0.5};
% Image of the first-quadrant grid (X,Y from the canonical-domain cell) in
% the physical domain, scaled by 1/c1.
XY = invcutsqrt(X+1i*Y,c1)/c1;
f2 = figure(2); clf(2); ax2 = axes(f2); ax2.NextPlot = 'add';
f2.Units = 'centimeters'; f2.Position(3:4) = [15 15]; ax2.Box = 'on';
view(ax2,3); grid(ax2,'on'); ax2.DataAspectRatio = [1 1 6];
ax2.ZLim = [0 1]; ax2.Projection = 'perspective';
ax2.TickLabelInterpreter = 'latex'; ax2.Visible = 'off';
ax2.CLim = [-3 3]; ax2.Colormap = jet(512);
% |R|^2 surface over the mapped first quadrant. The parent is given
% explicitly, so the plot does not depend on which axes are current.
surf(ax2,real(XY),imag(XY),F,C,args{:}); CC = Cc;
% Flat colour map on z = 0 over the image of the canonical half-plane
% x + y >= 0. The other half-plane is blanked with NaN.
CC(Xx+Yy<0) = nan; XY = invcutsqrt(Xx+1i*Yy,c1)/c1;
s3 = pcolor(ax2,real(XY),imag(XY),CC);
s3.SpecularStrength = 0; s3.FaceAlpha = 1; s3.EdgeColor = 'none';
l = light('Parent',ax2); l.Style = 'local'; l.Position = [1 0.02 2];
ax2.View = [-25 15]; ax2.Position([1 3]) = [-.25 1.5];
% |R|^2 along the segment between the images of .75i and .75, drawn on the
% real axis (y = 0). xh_ = cutsqrt(xh*c1,c1) returns to the canonical
% variable at which R is evaluated.
xh = linspace(invcutsqrt(.75i,c1),invcutsqrt(.75,c1),1001)/c1;
xh_ = cutsqrt(xh*c1,c1);
plot3(ax2,real(xh),zeros(size(xh)),abs(R(xh_)).^2,'k','LineWidth',2)
% Images of the canonical grid lines, clipped to the half-plane x + y >= 0.
for it = 1:length(kh)
    x = linspace(max(-0.5,-kh(it)),.75,101);
    XY = invcutsqrt(x+1i*kh(it),c1)/c1;
    l1 = plot3(ax2,real(XY),imag(XY),zeros(size(XY)),'b');
    XY = invcutsqrt(1i*x+kh(it),c1)/c1;
    l2 = plot3(ax2,real(XY),imag(XY),zeros(size(XY)),'r');
    if kh(it)==0, l1.LineWidth = 2; l2.LineWidth = 2; end
end
% Image of the branch-cut segment from -0.5+0.5i to 0.
XY = invcutsqrt(linspace(-0.5,0,101) + 1i*linspace(.5,0,101),c1)/c1;
plot3(ax2,real(XY),imag(XY),zeros(size(XY)),'k','LineWidth',3)


%% plot different branch cuts in canonical domain
% the choice of the branch cut is not unique

% Variant 1: branch cut along 3x + y = 0. The half-plane 3x + y < 0 is
% shown at half opacity; the cut itself is the black segment.
f3 = figure(3); clf(3); ax3 = axes(f3);
set(ax3,'NextPlot','add');
set(f3,'Units','centimeters');
set(f3,'Position',[f3.Position(1:2) 5 5]);
set(ax3,'Visible','off','CLim',[-3 3],'Colormap',jet(512));
s2 = pcolor(ax3,Xx,Yy,Cc);
set(s2,'EdgeColor','none','AlphaDataMapping','none','FaceAlpha','flat', ...
    'AlphaData',1 - 0.5*(3*Xx+Yy<0),'SpecularStrength',0);
set(ax3,'YLim',[-0.5 0.75],'DataAspectRatio',[1 1 1],'XLim',[-0.5 0.75]);
for it = 1:length(kh)
    l1 = plot3(ax3,[-0.5 .75],[kh(it) kh(it)],[0 0],'b');
    l2 = plot3(ax3,[kh(it) kh(it)],[-0.5 .75],[0 0],'r');
    if kh(it)==0
        set([l1 l2],'LineWidth',1.5);
    end
end
plot3(ax3,[-0.25 1/6],[.75 -0.5],[0 0],'k','LineWidth',1.5)

% Variant 2: branch cut along x + 3y = 0, drawn the same way.
f33 = figure(33); clf(33); ax33 = axes(f33);
set(ax33,'NextPlot','add');
set(f33,'Units','centimeters');
set(f33,'Position',[f33.Position(1:2) 5 5]);
set(ax33,'Visible','off','CLim',[-3 3],'Colormap',jet(512));
s2 = pcolor(ax33,Xx,Yy,Cc);
set(s2,'EdgeColor','none','AlphaDataMapping','none','FaceAlpha','flat', ...
    'AlphaData',1 - 0.5*(Xx+3*Yy<0),'SpecularStrength',0);
set(ax33,'YLim',[-0.5 0.75],'XLim',[-0.5 0.75],'DataAspectRatio',[1 1 1]);
for it = 1:length(kh)
    l1 = plot3(ax33,[-0.5 .75],[kh(it) kh(it)],[0 0],'b');
    l2 = plot3(ax33,[kh(it) kh(it)],[-0.5 .75],[0 0],'r');
    if kh(it)==0
        set([l1 l2],'LineWidth',1.5);
    end
end
plot3(ax33,[.75 -0.5],[-0.25 1/6],[0 0],'k','LineWidth',1.5)

%% plot different branch cuts in physical domain
% With a different branch cut, previously hidden poles become visible.

f4 = figure(4); clf(4); ax4 = axes(f4); ax4.NextPlot = 'add';
ax4.Visible = 'off'; ax4.CLim = [-3 3]; ax4.Colormap = jet(512);
f4.Units = 'centimeters'; f4.Position(3:4) = [12 8];
% Lower sheet (z = 0): canonical points with 3x + y >= 0, mapped to the
% physical domain. The rest is blanked with NaN.
CC = Cc; CC(3*Xx+Yy<0) = nan; XY = invcutsqrt(Xx+1i*Yy,c1)/c1;
s2 = pcolor(ax4,real(XY),imag(XY),CC); s2.EdgeColor = 'none';
s2.SpecularStrength = 0; s2.FaceAlpha = 1; ax4.ZLim = [0 1];
ax4.DataAspectRatio = [1 1 6]; view(ax4,3); ax4.View = [-25 15];
ax4.Projection = 'perspective';
l = light('Parent',ax4); l.Style = 'infinite'; l.Position = [0 0 1];
% Grid lines clipped to the half-plane 3x + y >= 0 (lower bound -0.5 is the
% edge of the plotted square; lines starting at or beyond .75 are skipped):
%   horizontal line y = kh(it) keeps x >= -kh(it)/3
%   vertical   line x = kh(it) keeps y >= -3*kh(it)
% The LineWidth emphasis is applied to the handle that was just created,
% so a skipped branch can never touch a stale handle.
for it = 1:length(kh)
    xm = max(-0.5,-kh(it)/3);
    if xm<0.75
        x = linspace(xm,.75,101);
        XY = invcutsqrt(x+1i*kh(it),c1)/c1;
        l1 = plot3(ax4,real(XY),imag(XY),zeros(size(XY)),'b');
        if kh(it)==0, l1.LineWidth = 1.5; end
    end
    xm = max(-0.5,-3*kh(it));
    if xm<0.75
        x = linspace(xm,.75,101);
        XY = invcutsqrt(1i*x+kh(it),c1)/c1;
        l2 = plot3(ax4,real(XY),imag(XY),zeros(size(XY)),'r');
        if kh(it)==0, l2.LineWidth = 1.5; end
    end
end
% Image of the cut 3x + y = 0, from -0.25+0.75i to 0.
XY = invcutsqrt(linspace(-0.25,0,101) + 1i*linspace(.75,0,101),c1)/c1;
plot3(ax4,real(XY),imag(XY),zeros(size(XY)),'k','LineWidth',1.6)

% Upper sheet (z = 1): canonical points with x + 3y >= 0. The surface is
% created at z = 0 and lifted to z = 1 at the end of this cell.
CC = Cc; CC(Xx+3*Yy<0) = nan; XY = invcutsqrt(Xx+1i*Yy,c1)/c1;
s2 = pcolor(ax4,real(XY),imag(XY),CC); s2.EdgeColor = 'none';
s2.SpecularStrength = 0; s2.FaceAlpha = 1;
% Grid lines clipped to the half-plane x + 3y >= 0:
%   horizontal line y = kh(it) keeps x >= -3*kh(it)
%   vertical   line x = kh(it) keeps y >= -kh(it)/3
% All lines end at .75, the edge of the plotted square.
for it = 1:length(kh)
    xm = max(-0.5,-3*kh(it));
    if xm<0.75
        x = linspace(xm,.75,101);
        XY = invcutsqrt(x+1i*kh(it),c1)/c1;
        l1 = plot3(ax4,real(XY),imag(XY),ones(size(XY)),'b');
        if kh(it)==0, l1.LineWidth = 1.5; end
    end
    xm = max(-0.5,-kh(it)/3);
    if xm<0.75
        x = linspace(xm,.75,101);
        XY = invcutsqrt(1i*x+kh(it),c1)/c1;
        l2 = plot3(ax4,real(XY),imag(XY),ones(size(XY)),'r');
        if kh(it)==0, l2.LineWidth = 1.5; end
    end
end
% Image of the cut x + 3y = 0, from 0 to 0.75-0.25i.
XY = invcutsqrt(linspace(0,.75,101) + 1i*linspace(0,-0.25,101),c1)/c1;
plot3(ax4,real(XY),imag(XY),ones(size(XY)),'k','LineWidth',1.6)
s2.ZData = s2.ZData+1;

ax4.Position = [-0.6 -0.12 2 1.2];

%% plot sampling
% the sampling visualised in the different domains

% Sample points in the canonical domain (kc) and their images under
% invcutsqrt (kp, not scaled by 1/c1 here). ff = |f|^2 of the sampled data.
% zs is sized from kc itself rather than assuming k1 and k2 have equal length.
kc = [k1 k2]; kp = invcutsqrt(kc,c1);
ff = f(:,1).*conj(f(:,1)); zs = zeros(size(kc));
f6 = figure(6); clf(6); ax6 = axes(f6); ax6.DataAspectRatio = [1 1 5];
ax6.NextPlot = 'add'; ax6.Visible = 'off'; ax6.View = [-25 15];
f7 = figure(7); clf(7); ax7 = axes(f7);  ax7.DataAspectRatio = [1 100 20];
ax7.NextPlot = 'add'; ax7.Visible = 'off'; ax7.View = [-25 15];
% Coloured axis segments. In ax7 the point c1 separates the red part
% [0.98, c1] from the blue part [c1, 1.24].
plot3(ax6,[0 0.6],[0 0],[0 0],'b-','LineWidth',1.5)
plot3(ax6,[0 0],[0 0.54],[0 0],'r-','LineWidth',1.5)
plot3(ax7,[0.98 c1],[0 0],[0 0],'r-','LineWidth',1.5)
plot3(ax7,[c1 1.24],[0 0],[0 0],'b-','LineWidth',1.5)
% Samples used as AAA support points are highlighted in orange. The 8-bit
% RGB triplet is normalised by 255 so the channels map onto [0,1].
sel = ismember(kc,zj); col = [240,143,28]/255;
% One stem per sample in each domain. The first physical-domain stem is
% extended slightly below the base plane.
for it = 1:length(kc)
    x = real(kc(it)); y = imag(kc(it)); z = ff(it);
    l1 = plot3(ax6,[x x],[y y],[0 z],'k-');
    l2 = plot3(ax7,real(kp([it it])),[0 0],[0 z],'k');
    if sel(it),l1.Color = col; l1.LineWidth = 1; end
    if it==1, l2.ZData(1) = -0.1; end
end
% |R|^2 along the boundary of the first quadrant of the canonical domain.
yh1 = [linspace(0.5,0,101) zeros(1,101)];
xh1 = [zeros(1,101) linspace(0,0.55,101)];
plot3(ax6,xh1,yh1,abs(R(xh1+1i*yh1)).^2,'k','LineWidth',1)
plot3(ax6,real(kc(sel)),imag(kc(sel)),ff(sel),'.','MarkerSize',14,'Color',col);
plot3(ax6,real(kc(~sel)),imag(kc(~sel)),ff(~sel),'k.','MarkerSize',14);
plot3(ax7,real(kp),zs,ff,'k.','MarkerSize',9);
% Small ticks below the base plane at c1 (physical) and at the origin
% (canonical).
plot3(ax7,[c1 c1],[0 0],[0 -.1],'k');
plot3(ax6,[0 0],[0 0],[0 -.1],'k');
