% DEMO 1 for the Sylvester equation AX + XB + FG' = 0, where A and B are
% symmetric and come from finite-difference (FD) discretizations of the 3D Laplace operator.
% This is essentially Example 1 from [Kuerschner, Inexact linear solves in the low-rank
% ADI iteration for large Sylvester equations, ETNA, 2024].

clear
clf
% NOTE(review): demo1_data.mat is expected to provide A, B, F, G -- confirm.
load('demo1_data.mat')
n=size(A,1);
m=size(B,1);
r=size(F,2);
% Reference norm ||F*G'||_2 used to scale all residuals. It is computed from
% the small economy-size QR triangular factors (F = Qf*Rf, G = Qg*Rg), so
% the large product F*G' is never formed.
[~,Rf]=qr(F,0);
[~,Rg]=qr(G,0);
res0=norm(Rf*Rg');
%%
% Shift generation via Sabino's method.
% eigs(...,'BE',...) returns eigenvalues from both ends of the spectrum; the
% extreme values of -A and -B are passed to adi_para_sylv, which is asked
% for 30 shifts.
eigsopts.tol=1e-12;
eigsopts.isreal=1;
sA=eigs(A,15,'BE',eigsopts);
sB=eigs(B,15,'BE',eigsopts);
ivA=-real(sA);   % identical to real(-sA)
ivB=-real(sB);
[sA,sB]=adi_para_sylv(min(ivA),max(ivA),min(ivB),max(ivB),30);
clear ivA ivB
% The shifts were generated from the intervals of -A and -B; flip the sign.
sA=-sA; sB=-sB;
%% outer iteration and common solver options
maxit=50;   % presumably the outer ADI step limit (passed to lr_adi_sylv)
tol=1e-8;   % presumably the outer residual tolerance (passed to lr_adi_sylv)

% inner (linear solve) tolerance settings
opts.inner_tol=tol/20;          % fixed inner tolerance, 20x tighter than tol
opts.itolmax=1e-1;              % presumably upper bound for relaxed inner tols
opts.itolmin=opts.inner_tol;    % presumably lower bound for relaxed inner tols
opts.maxit_inner=500;
opts.intolstrat='fixed';
opts.rgap_update = 0;
opts.simple_relax=1;
opts.backlook=0;
opts.debug=0;
opts.bal_q=1;
opts.savg=1;
% preconditioner factors: empty for the runs with the 'exact' solver
opts.M1A=[];
opts.M2A=[];
opts.M1B=[];
opts.M2B=[];
% A and B are symmetric
opts.symA=1;
opts.symB=1;
opts.MtypeA=[]; opts.MtypeB=[];
%% exact method
% Reference run: lr_adi_sylv uses the 'exact' linear solver for both A and B.
[opts.linsolverA,opts.linsolverB]=deal('exact');
tic
[Ze,De,Ye,res,niter,timings]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t1=toc
% relative residual norm of the computed low-rank solution (displayed)
rex=syl_r_norm(A,B,F,G,Ze,De,Ye,m,[],[])/res0
%%
% Settings for the iterative inner solves: MINRES for both coefficient
% matrices, preconditioned by incomplete Cholesky factors.
opts.inner_tol=tol/20;
opts.maxit_inner=500;
opts.rgap_update=0;
opts.backlook=0;
opts.debug=0;
[opts.symA,opts.symB]=deal(1);
[opts.linsolverA,opts.linsolverB]=deal('minres');
opts.intolstrat='fixed';
% NOTE(review): ichol only honors droptol when precAset.type is 'ict'; with
% the default type 'nofill' the value 0.1 below has no effect. Confirm which
% factorization is intended before changing it.
precAset.droptol=0.1;
% -A and -B are factored (ichol requires SPD input).
opts.M1A=ichol(-A,precAset);
opts.M1B=ichol(-B,precAset);
[opts.M2A,opts.M2B]=deal([]);
[opts.MtypeA,opts.MtypeB]=deal('twosided');
%% fixed inner tols
% Inexact ADI with the same inner tolerance (opts.inner_tol) for every solve.
opts.intolstrat='fixed';
tic
[Z2,D2,Y2,res2,niter2,timings2,out2]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t2=toc
% true relative residual; same 10-argument call as the other runs
resadi_f_true=syl_r_norm(A,B,F,G,Z2,D2,Y2,m,[],[])/res0;
%% relaxed inner tols, no backlook
% Basic relaxation strategy ('relax') with rgap_update on, backlooking off.
opts.backlook=0;
opts.rgap_update=1;
opts.intolstrat='relax';
tic
[Z3,D3,Y3,res3,niter3,timings3,out3]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t3=toc
resadi_r1_true=syl_r_norm(A,B,F,G,Z3,D3,Y3,m,[],[])/res0;   % true relative residual
%% relaxed inner tols + backlooking
% Basic relaxation strategy ('relax') with rgap_update on, backlooking on.
opts.backlook=1;
opts.rgap_update=1;
opts.intolstrat='relax';
tic
[Z3b,D3b,Y3b,res3b,niter3b,timings3b,out3b]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t3b=toc
resadi_r3b_true=syl_r_norm(A,B,F,G,Z3b,D3b,Y3b,m,[],[])/res0;   % true relative residual
%% relaxed, B-mode, no backlook
% 'bal_relax' strategy with bal_q=0, rgap_update on, backlooking off.
opts.bal_q=0;
opts.backlook=0;
opts.rgap_update=1;
opts.intolstrat='bal_relax';
tic
[Z4,D4,Y4,res4,niter4,timings4,out4]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t4=toc
resadi_r4_true=syl_r_norm(A,B,F,G,Z4,D4,Y4,m,[],[])/res0;   % true relative residual
%% relaxed, B-mode + backlooking
% 'bal_relax' strategy with bal_q=0, rgap_update on, backlooking on.
opts.intolstrat='bal_relax';
opts.rgap_update = 1;
opts.backlook=1;
opts.bal_q=0;
tic
[Z4b,D4b,Y4b,res4b,niter4b,timings4b,out4b]=lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
t4b=toc
% true relative residual of THIS run's factors (Z4b,D4b,Y4b)
resadi_r4b_true=syl_r_norm(A,B,F,G,Z4b,D4b,Y4b,m,[],[])/res0;
%%
% Inner iteration statistics per run: total, inner steps for A (row 1 of
% nrit), inner steps for B (row 2), and the total as a percentage of the
% fixed-tolerance run.
itTot=@(o) sum(sum(o.nrit,2));
itA=@(o) sum(o.nrit(1,:),2);
itB=@(o) sum(o.nrit(2,:),2);
pct=@(o) sum(sum(o.nrit))/sum(sum(out2.nrit))*100;
fprintf('fixed-tol: \t total: %d, \t in-it A: %d,\t in-it B: %d\t\n',...
    itTot(out2),itA(out2),itB(out2))
fprintf('relax-tol,no BL: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \n',...
    itTot(out3),itA(out3),itB(out3),pct(out3))
fprintf('relax-tol+ BL: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \n',...
    itTot(out3b),itA(out3b),itB(out3b),pct(out3b))
fprintf('relax-tol-B, no BL: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \t\n',...
    itTot(out4),itA(out4),itB(out4),pct(out4))
fprintf('relax-tol-B, + BL: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \t\n',...
    itTot(out4b),itA(out4b),itB(out4b),pct(out4b))
clear itTot itA itB pct
%%
% Outer residual histories: direct solves vs. fixed and relaxed inexact
% solves ('relax' = res3b run, 'relax 2' = res4b run, both with backlooking).
figure(1),
semilogy(1:size(res,2),res(1,:),'k-.',...
    1:size(res2,2),res2(1,:),'r-+',...
    1:size(res3b,2),res3b(1,:),'b-+',...
    1:size(res4b,2),res4b(1,:),'m-.')
xlabel('step')
ylabel('sylv res')
legend('direct','fixed','relax','relax 2')


figure(2),
% Outer residual of the relaxed run with backlooking, plus the inner
% residuals for A (row 1) and B (row 2) of the relaxed runs without (out3)
% and with (out3b) backlooking.
semilogy(1:size(res3b,2),res3b(1,:),'b-+',...
    1:size(res3,2),out3.inres(1,:),'b:+',...
    1:size(res3,2),out3.inres(2,:),'b:v',...
    1:size(res3b,2),out3b.inres(1,:),'m:*',...
    1:size(res3b,2),out3b.inres(2,:),'m:s')
xlabel('step')
ylabel('sylv & inner res')
legend('||R||','inresA\_relax','inresB\_relax','inresA\_relaxBL','inresB\_relaxBL')


figure(3),
% Cumulative number of inner iterations (summed over both rows of nrit,
% i.e. A and B) versus the outer step. Each curve uses its own run's
% iteration count for the x-axis.
plot(1:niter2,cumsum(sum(out2.nrit,1)),'k-.',...
    1:niter3,cumsum(sum(out3.nrit,1)),'b-+',...
    1:niter3b,cumsum(sum(out3b.nrit,1)),'b-v',...
    1:niter4,cumsum(sum(out4.nrit,1)),'m-s',...
    1:niter4b,cumsum(sum(out4b.nrit,1)),'m-.')
xlabel('step')
ylabel('sum(inner iters)')
legend('fix','relax noBL','relax BL','Brelax noBL','Brelax BL')
title('cumulative sum of inner steps')



figure(4),
% Inner residuals and inner tolerances for A (row 1) and B (row 2) in the
% basic relaxation run with backlooking.
semilogy(1:length(res3b(1,:)),out3b.inres(1,:),'k--',...
    1:length(res3b(1,:)),out3b.itol(1,:),'k:',...
    1:length(res3b(1,:)),out3b.inres(2,:),'b-+',...
    1:length(res3b(1,:)),out3b.itol(2,:),'b:')
legend('inres,A','itol,A','inres,B','itol,B')
xlabel('step')
ylabel('inner res and tol')
title('basic relaxation')

figure(5),
% Inner residuals and inner tolerances for A (row 1) and B (row 2) in the
% B-mode ('bal_relax') run with backlooking.
semilogy(1:length(res4b(1,:)),out4b.inres(1,:),'k--',...
    1:length(res4b(1,:)),out4b.itol(1,:),'k:',...
    1:length(res4b(1,:)),out4b.inres(2,:),'b-+',...
    1:length(res4b(1,:)),out4b.itol(2,:),'b:')
legend('inres,A','itol,A','inres,B','itol,B')
xlabel('step')
ylabel('inner res and tol')
title('B-mode relaxation')
