%% prepare workspace
% Put the helper functions from the 'code' folder (relative to the current
% folder) on the MATLAB search path.
addpath code

% Optional JCMsuite setup, kept commented out. Presumably only needed when
% new simulations have to be run -- confirm. To enable it, remove the
% block-comment markers and point jcmsuiteDir at the local installation.
%{
jcmsuiteDir = '/path/to/JCMsuite';
addpath([jcmsuiteDir '/ThirdPartySupport/Matlab']);

options = struct('Hostname','localhost','Multiplicity',7,'NThreads',1);
jcmwave_daemon_shutdown;
jcmwave_daemon_add_workstation(options);
%}

%% select data points
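% Here the sampling points are defined. Between two neighbouring branch
% points they are placed equidistantly (in angle) on a quarter circle and
% mapped back to the real line with the corresponding coordinate
% transformation consisting of two radicals. Below the first and above the
% last branch point they are equidistant in the coordinate transformed with
% respect to that single branch point. For the reference, we add
% equidistantly distributed sampling points.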

% Sampled interval [k_min, k_max] and the three branch points inside it.
k_min = 1; k_max = 1/0.29;
cuts = 1./[0.9 0.45 0.3];

% Reference samples: 257 equidistant points on [k_min, k_max], with the
% branch points themselves removed (reuses `cuts` instead of re-typing them).
k1 = linspace(k_min,k_max,257).';
k1 = k1(~ismembertol(k1,cuts,1e-5));

% Between neighbouring branch points: 129 points equidistant in angle on a
% circular arc in the coordinate cutsqrt(., cuts(i:i+1)), starting and
% ending 1e-4 away from the branch points. (a/b are scratch variables.)
a = cutsqrt(cuts(1)+1e-4,cuts(1:2)); b = cutsqrt(cuts(2)-1e-4,cuts(1:2));
c12 = abs(a)*exp(1i*linspace(angle(a),angle(b),129));
a = cutsqrt(cuts(2)+1e-4,cuts(2:3)); b = cutsqrt(cuts(3)-1e-4,cuts(2:3));
c23 = abs(a)*exp(1i*linspace(angle(a),angle(b),129));

% Branch-cut-adapted samples, mapped back to the real k axis. Segments
% (25 + 129 + 129 + 25 points): below cuts(1), arc c12, arc c23, above
% cuts(3). The outer segments are equidistant in the coordinate
% cutsqrt(., cuts(1)) resp. cutsqrt(., cuts(3)).
k2 = real([invcutsqrt(linspace(cutsqrt(k_min,cuts(1)),2e-2i,25),cuts(1)),...
    invcutsqrt(c12,cuts(1:2)), invcutsqrt(c23,cuts(2:3)),...
    invcutsqrt(linspace(2e-2,cutsqrt(k_max,cuts(3)),25),cuts(3))]).';

% Simulation parameters passed on to the data-collection helpers.
keys.P = 600; keys.n = 1.3; keys.w = 50; keys.theta = 30;
keys.finiteElementDegree = 5; keys.pml = [pwd '/JCMsuite/pml.txt'];

% join all the data points (duplicates within tolerance are merged)
k = uniquetol([k1;k2]); 

%% collect data
% Obtain f, r0, r and a for all sampling points k from collect_data; the
% .mat path is passed along (presumably a data cache -- confirm).
[f,r0,r,a] = collect_data(k,sprintf('%s/data/full.mat',pwd),keys);

% Sanity check: print the largest deviation of r+a from 1.
disp(max(abs(1-(r+a))));

% Common tolerance for all rational fits in this section.
tol_fit = 5e-7;

% Global branch-cut AAA fit on all samples (R is reused for plotting).
[R,pp,rr] = branch_cut_aaa(cuts,f(:,1),k,'tol',tol_fit);

% Local fits around the middle branch point: sel1 keeps k/cuts(2) > 0.75
% (fitted with cuts(2:3)), sel2 keeps k/cuts(2) < 1.2 (fitted with cuts(1:2)).
k_rel = k/cuts(2);
sel1 = k_rel > 0.75;
sel2 = k_rel < 1.2;
[~,pp1] = branch_cut_aaa(cuts(2:3),f(sel1,1),k(sel1),'tol',tol_fit);
[~,pp2] = branch_cut_aaa(cuts(1:2),f(sel2,1),k(sel2),'tol',tol_fit);
% From each local fit keep only the poles on "its" side of 0.8*cuts(2).
poles_hi = pp1{2}; pp1 = poles_hi(real(poles_hi/cuts(2)) > 0.8);
poles_lo = pp2{2}; pp2 = poles_lo(real(poles_lo/cuts(2)) < 0.8);

% Plain AAA on all samples, for comparison.
[~,ppp] = aaa(f(:,1),k,'tol',tol_fit);

% Keep the global-fit poles (and residues) whose real part lies strictly
% inside the sampled interval.
all_poles = pp{2}; all_res = rr{2};
in_range = real(all_poles) > k(1) & real(all_poles) < k(end);
pp = all_poles(in_range);
rr = all_res(in_range);

p_ref = arnoldi_resonance(pp,[pwd '/data/p_ref.mat'],keys);

% Order poles by real part and reorder the residues accordingly.
[pp,order] = sort(pp,'ComparisonMethod','real');
rr = rr(order);

% Evaluate the approximant on a fine grid spanning the sampled interval.
x = linspace(k(1),k(end),1000001).';
fx = R(x);

%% plot specular reflection
% Indices into k of the equidistant reference samples (k1) and of the
% branch-cut-adapted samples (k2).
ndx_uniform = find(ismembertol(k,k1)); 
ndx_cuts = find(ismembertol(k,k2));
fs = 10; text_args = {'Interpreter','latex','Fontsize',fs}; 

% Figure 22 (16 x 9 cm): ax2a (top left) shows the specular reflection,
% ax2b (bottom left) the poles in the complex k/k_0 plane (k_0 = cuts(2)).
f2 = figure(22); clf(22); f2.Units = 'centimeters';
f2.Position(3:4) = [16 9];
ax2a = axes(f2); ax2a.NextPlot = 'add'; ax2a.Box = 'on';
ax2a.Position = [0.09 0.375 0.5 0.61]; ax2a.TickLabelInterpreter = 'latex';
ax2a.FontSize = fs; ax2a.XTickLabel = []; 

ax2b = axes(f2); ax2b.NextPlot = 'add'; ax2b.Box = 'on';
ax2b.Position = [0.09 0.1 0.5 0.25]; ax2b.FontSize = fs;
ax2b.TickLabelInterpreter = 'latex';

% |R(x)|^2 of the global branch-cut AAA approximant on the fine grid, and
% the reference r0 at the adapted samples. The abscissa is k(ndx_cuts), not
% k2: it always has the same length as r0(ndx_cuts) and holds the values
% actually stored in k after uniquetol (r0 is presumably aligned with k, as
% returned by collect_data(k,...) -- confirm).
ax2a.YLim = [0 1.05]; 
ax2a.ColorOrder = parula(7); ax2a.ColorOrderIndex = 1;
plot(ax2a,x/cuts(2),abs(fx).^2,'LineWidth',1,'Color',[0.5 0.5 0.5 0.7]);
plot(ax2a,k(ndx_cuts)/cuts(2),r0(ndx_cuts),'k.','MarkerSize',4);

% Poles (real vs. imaginary part): local branch-cut fits pp1/pp2 and, with
% smaller markers, the plain AAA poles ppp.
ax2b.ColorOrderIndex = 2; h_pp1 = plot(ax2b,pp1/cuts(2),'x','MarkerSize',5);
ax2b.ColorOrderIndex = 2; h_pp2 = plot(ax2b,pp2/cuts(2),'x','MarkerSize',5);
ax2b.ColorOrderIndex = 1; plot(ax2b,ppp/cuts(2),'x','MarkerSize',3);
ax2a.XLim = [0.45 1.55]; ax2b.XLim = [0.45 1.55]; 
ax2b.YLim = [-0.045 0.005];
% Thicker pp1/pp2 markers, set via the handles rather than via the
% creation order of ax2b.Children.
h_pp1.LineWidth = 1; h_pp2.LineWidth = 1;

ylabel(ax2a,'Specular Reflection $R_0$',text_args{:})
xlabel(ax2b,'Re$(k)/k_0$',text_args{:})
ylabel(ax2b,'Im$(k)/k_0$',text_args{:})
ax2a.YLabel.Position(1) = ax2b.YLabel.Position(1);

% Legends use the LaTeX text options with font size 8 (text_args{end}
% is replaced by 8).
l = legend(ax2a,{'Approximation' 'Reference'},text_args{1:end-1},8);

l.Location = 'southeast'; l.Box = 'off'; l.ItemTokenSize(1) = 12;
l.Position(2) = l.Position(2)-0.01; drawnow;
% Invisible dummy line (NaN data) that only provides the grey 'Poles'
% legend entry.
ln = plot(ax2b,nan,nan,'x','Color',[0.5 0.5 0.5],'MarkerSize',4);
lg = legend(ln,'Poles',text_args{1:end-1},8,'location','southeast'); 
lg.ItemTokenSize(1) = 12; lg.Box = 'off'; lg.Position(1) = l.Position(1);

%% rational approximation
% Convergence study. For each AAA tolerance, the branch-cut-adapted samples
% are thinned with step 2, 4 and 8 and fitted with branch_cut_aaa (curves
% labelled \tilde{k}) and with plain aaa (curves labelled k). The number
% of poles and the maximum error over all samples k are recorded.
n = 3;

tol = 0.5*10.^(-4:-1:-6);

% Point counts of the four segments that make up k2 (outer segment, arc
% c12, arc c23, outer segment). They must match the counts used to build
% k2. The last segment runs to the end of ndx_cuts.
seg_len = [25 129 129 25];
seg_first = cumsum([1 seg_len(1:end-1)]);
seg_last = [cumsum(seg_len(1:end-1)) numel(ndx_cuts)];

% Right column of the figure: ax2c (pole counts) above ax2d (errors).
ax2c = axes(f2); ax2c.NextPlot = 'add'; 
ax2c.Position = [0.7 0.555 0.29 0.43]; ax2c.XTickLabel = [];
ax2d = axes(f2); ax2d.NextPlot = 'add'; 
ax2d.Position = [0.7 0.1 0.29 0.43];
ax2c.FontSize = fs; ax2d.FontSize = fs; 
ax2c.TickLabelInterpreter = 'latex'; ax2d.TickLabelInterpreter = 'latex';
ax2d.YScale = 'log'; ax2c.Box = 'on'; ax2d.Box = 'on'; 
c = ax2c.ColorOrder; ax2d.YLim = [1e-7 1e-1]; 
ls = {'-' '--' ':'}; args = {'LineStyle',[]};
% One horizontal reference line per tolerance, in that tolerance's style.
for it1 = 1:length(tol)
    yline(ax2d,tol(it1),'LineStyle',ls{it1},'Color',[0.3 0.3 0.3]);
end

max_poles = 0;   % largest pole count of any fit (used for ax2c.YLim)
for it1 = 1:length(tol)
    err = zeros(2,n); n_poles = zeros(2,n); n_points = zeros(1,n);
    for it2 = 1:n
        step = 2^(it2);
        % Every step-th point of each segment, restarting in each segment.
        seg_ndx = arrayfun(@(s,e) s:step:e,seg_first,seg_last,...
            'UniformOutput',false);
        sel1 = ndx_cuts([seg_ndx{:}]);
        % sel1 = ndx_uniform(1:step:end);
        [R,p] = branch_cut_aaa(cuts,f(sel1,1),k(sel1),'tol',tol(it1));
        n_poles(1,it2) = size(p{1},1); n_points(it2) = length(sel1);
        err(1,it2) = max(abs(R(k)-f(:,1)));
        % err(1,it2) = max(min(abs(p{2}-p_ref.'),[],1));
        [R,p] = aaa(f(sel1,1),k(sel1),'tol',tol(it1));
        n_poles(2,it2) = size(p,1); 
        err(2,it2) = max(abs(R(k)-f(:,1)));
        % err(2,it2) = max(min(abs(p-p_ref.'),[],1));
    end
    max_poles = max(max_poles,max(n_poles(:)));
    args{2} = ls{it1};
    plot(ax2c,n_points,n_poles(1,:),'.-','Color',c(2,:),args{:});
    plot(ax2d,n_points,err(1,:),'.-','Color',c(2,:),args{:});
    plot(ax2c,n_points,n_poles(2,:),'.-','Color',c(1,:),args{:});
    plot(ax2d,n_points,err(2,:),'.-','Color',c(1,:),args{:});
end

% Upper pole-count limit from the largest count of any fit, not only the
% last tolerance, so no curve can be clipped.
ax2c.YLim = [7.5 max_poles+1]; 
ax2d.XTick = n_points(end:-1:1); ax2c.XTick = n_points(end:-1:1);
ax2d.XLim = [n_points(end)-8 n_points(1)+8]; ax2c.XLim = ax2d.XLim;
xlabel(ax2d,'\# Support Points',text_args{:});
ylabel(ax2d,'Error',text_args{:});
ylabel(ax2c,'\# Poles',text_args{:});
ax2d.YLabel.Position(1) = ax2d.YLabel.Position(1)+2;
ax2c.YLabel.Position(1) = ax2d.YLabel.Position(1);
ax2d.XLabel.Units = 'centimeters'; ax2b.XLabel.Units = 'centimeters';
ax2b.XLabel.Position(2) = ax2b.XLabel.Position(2)+0.1;
ax2d.XLabel.Position(2) = ax2b.XLabel.Position(2);
ax2d.YTick = [1e-6 1e-4 1e-2]; ax2d.YMinorTick = 'off';

% Invisible dummy lines (NaN data) that only provide the legend entries.
l1 = plot(ax2c,nan,nan,'Color',c(1,:),'DisplayName','$k$'); 
l2 = plot(ax2c,nan,nan,'Color',c(2,:),'DisplayName','$\tilde{k}$'); 
l12 = legend([l1 l2],text_args{:},'Box','off','Location','SouthWest');
text_args{end} = 8; l12.Position(2) = l12.Position(2)-0.02;
% Label each tolerance line slightly above it (x = 128 support points),
% generating the LaTeX label "m\times 10^{e}" from tol itself.
for it1 = 1:length(tol)
    e = floor(log10(tol(it1))+1e-9);
    label = sprintf('$%g\\times 10^{%d}$',tol(it1)/10^e,e);
    text(ax2d,128,1.6*tol(it1),label,text_args{:})
end