% DEMO 2 for a Sylvester equation A*X + X*B + F*G' = 0, where A and B are
% nonsymmetric and come from FD discretizations of 3D elliptic operators
% with convection terms.
% This is essentially Example 2 from [Kuerschner, Inexact linear solves in the low-rank
% ADI iteration for large Sylvester equations, ETNA, 2024]

% Start from an empty workspace and a cleared current figure.
clear
clf

% The data file has to provide A, B, F and G; all four are used below.
load('demo2_data.mat')

% Problem dimensions: rows of A, rows of B, columns of F.
[n,~] = size(A);
[m,~] = size(B);
[~,r] = size(F);

% Norm of the initial residual, computed from the triangular factors of
% economy-size QR decompositions instead of forming the large product.
% Naming caveat: Rg is the R-factor of F and Rf is the R-factor of G.
[~,Rg] = qr(F,0);
[~,Rf] = qr(G,0);
res0 = norm(Rf*Rg',2);
%%
l0A = 30;   % number of ADI a-shifts
l0B = 30;   % number of ADI b-shifts
kpA = 10;   % number of Ritz values of A
kmA = 20;   % number of inverse Ritz values of A
kpB = 10;   % number of Ritz values of B
kmB = 20;   % number of inverse Ritz values of B

% ---- shift candidates for A -------------------------------------------
% Ritz values of A: arn_pl is called with an empty option set and the
% start vector F*ones(r,1).
arnopts = {};
[Hp,~] = arn_pl(A,[],kpA,F*ones(r,1),arnopts);
rwp = eig(Hp(1:kpA,1:kpA));

% Inverse Ritz values of A: arn_inv receives the inner-solver settings
% (BiCGstab, tolerance, iteration cap) and ILU factors of A.
% NOTE(review): according to the MATLAB ilu documentation, droptol is only
% used for type 'ilutp' or 'crout'; no type is set here, so the default
% applies -- confirm that this is intended.
precAset.droptol = 0.1;
arnopts.inner_tol = 1e-9;
arnopts.maxit_inner = 500;
arnopts.linsolver = 'bicgstab';
arnopts.Mtype = [];
[arnopts.M1,arnopts.M2] = ilu(A,precAset);
[Hm,~] = arn_inv(A,[],kmA,F*ones(r,1),arnopts);
rwm = 1./eig(Hm(1:kmA,1:kmA));

% Sort all candidates and keep the first l0A of them.
sA = sort(vertcat(rwp,rwm));
sA = sA(1:l0A);

% ---- shift candidates for B -------------------------------------------
% Ritz values of B. Barnopts starts as a copy of arnopts, so at this point
% it still carries the ILU factors of A.
Barnopts = arnopts;
[Hp,~] = arn_pl(B,[],kpB,G*ones(r,1),Barnopts);
rwp = eig(Hp(1:kpB,1:kpB));

% Inverse Ritz values of B: replace the preconditioner by ILU factors of B
% before running the inverse Arnoldi process.
[Barnopts.M1,Barnopts.M2] = ilu(B,precAset);
[Hm,~] = arn_inv(B,[],kmB,G*ones(r,1),Barnopts);
rwm = 1./eig(Hm(1:kmB,1:kmB));

% Sort all candidates and keep the first l0B of them.
sB = sort(vertcat(rwp,rwm));
sB = sB(1:l0B);

% Heuristic selection of the final shift sets from the candidates.
[sA,sB] = pseudominmax(sA,sB);
 %%
 % Global ADI settings: iteration cap and stopping tolerance handed to
 % lr_adi_sylv.
 maxit = 50;
 tol = 1e-8;

 % Options structure for lr_adi_sylv. Every field is assigned exactly once.
 opts.debug = 0;
 opts.symA = 0;
 opts.symB = 0;
 opts.savg = 1;
 opts.bal_q = 1;

 % Accuracy control of the inner linear solves. The base inner tolerance is
 % tied to the outer tolerance; itolmin/itolmax bound the inner tolerance.
 opts.inner_tol = tol/20;
 opts.itolmax = 1e-1;
 opts.itolmin = opts.inner_tol;
 opts.maxit_inner = 500;
 opts.intolstrat = 'fixed';
 opts.rgap_update = 0;
 opts.backlook = 0;
 opts.simple_relax = 1;

 % Iterative solvers used for the A- and B-systems.
 opts.linsolverA = 'bicgstab';
 opts.linsolverB = 'bicgstab';
 opts.MtypeA = [];
 opts.MtypeB = [];

 % Preconditioners: ILU factors of A (M1A/M2A) and of B' (M1B/M2B).
 % NOTE(review): according to the MATLAB ilu documentation, droptol is only
 % used for type 'ilutp' or 'crout'; no type is set here, so the default
 % applies -- confirm that this is intended.
 precAset.droptol = 0.1;
 [opts.M1A,opts.M2A] = ilu(A,precAset);
 [opts.M1B,opts.M2B] = ilu(B',precAset);

 %% run 1: fixed inner tolerances
 opts.intolstrat = 'fixed';

 tic
 [Z2,D2,Y2,res2,niter2,timings2,out2] = ...
     lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
 t2 = toc
 % Residual norm recomputed by syl_r_norm from the returned factors,
 % scaled by the initial residual norm.
 resadi_f_true = syl_r_norm(A,B,F,G,Z2,D2,Y2,m,[],[])/res0;

 %% run 2: relaxed inner tolerances
 opts.backlook = 1;
 opts.rgap_update = 1;
 opts.intolstrat = 'relax';

 tic
 [Z3,D3,Y3,res3,niter3,timings3,out3] = ...
     lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
 t3 = toc
 resadi_r1_true = syl_r_norm(A,B,F,G,Z3,D3,Y3,m,[],[])/res0;

 %% run 3: relaxed inner tolerances, B-mode
 opts.bal_q = 0;
 opts.backlook = 1;
 opts.rgap_update = 1;
 opts.intolstrat = 'bal_relax';

 tic
 [Z4,D4,Y4,res4,niter4,timings4,out4] = ...
     lr_adi_sylv(A,B,[],[],F,G,sA,sB,maxit,tol,opts);
 t4 = toc
 resadi_r2_true = syl_r_norm(A,B,F,G,Z4,D4,Y4,m,[],[])/res0;
%% Report inner iteration counts of the three runs
% Row 1 of out*.nrit is printed as the count for A, row 2 as the count for B.
% The value labelled "reldiff2fix" is the total of the respective run divided
% by the total of the fixed-tolerance run, in percent.
fprintf('fixed-tol: \t total: %d, \t in-it A: %d,\t in-it B: %d\t\n', ...
    sum(out2.nrit(:)), ...
    sum(out2.nrit(1,:)), ...
    sum(out2.nrit(2,:)))
fprintf('relax-tol: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \n', ...
    sum(out3.nrit(:)), ...
    sum(out3.nrit(1,:)), ...
    sum(out3.nrit(2,:)), ...
    sum(out3.nrit(:))/sum(out2.nrit(:))*100)
fprintf('relax-tol-B: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \t\n', ...
    sum(out4.nrit(:)), ...
    sum(out4.nrit(1,:)), ...
    sum(out4.nrit(2,:)), ...
    sum(out4.nrit(:))/sum(out2.nrit(:))*100)

 

%% Figure 1: first row of the residual histories of all three runs
figure(1)
semilogy(1:numel(res2(1,:)), res2(1,:), 'r-+', ...
         1:numel(res3(1,:)), res3(1,:), 'b-+', ...
         1:numel(res4(1,:)), res4(1,:), 'm-.')
xlabel('step')
ylabel('sylv res')
legend('fixed','relax','relax 2')

 % Figure 2: outer residual of the 'relax' run together with the inner
 % residuals (rows 1 and 2 of out*.inres) of the 'relax' and 'bal_relax' runs.
 figure(2)
 semilogy(1:numel(res3(1,:)), res3(1,:),       'b-+', ...
          1:numel(res3(1,:)), out3.inres(1,:), 'b:+', ...
          1:numel(res3(1,:)), out3.inres(2,:), 'b:v', ...
          1:numel(res4(1,:)), out4.inres(1,:), 'r:<', ...
          1:numel(res4(1,:)), out4.inres(2,:), 'r:.')
 xlabel('step')
 ylabel('sylv & inner res')
 legend('||R||','inresA\_relax','inresB\_relax','inresA\_relax-B','inresB\_relax-B')

% Figure 3: running total of inner iterations (summed over the rows of
% out*.nrit) against the outer step index, for all three runs.
figure(3)
plot(1:niter2, cumsum(sum(out2.nrit,1)), 'k-.', ...
     1:niter3, cumsum(sum(out3.nrit,1)), 'b-+', ...
     1:niter4, cumsum(sum(out4.nrit,1)), 'm-s')
xlabel('step')
ylabel('sum(inner iters)')
legend('fix','relax','relax-B')
title('cummulative sum of inner steps')


 % Figure 4: inner residuals and inner tolerances of the 'relax' run.
 % Row 1 of out3.inres / out3.itol is plotted in black, row 2 in blue; the
 % legend entries follow the plotting order (row 1 -> A, row 2 -> B).
 figure(4)
 semilogy(1:length(res3(1,:)), out3.inres(1,:), 'k--', ...
          1:length(res3(1,:)), out3.itol(1,:),  'k:', ...
          1:length(res3(1,:)), out3.inres(2,:), 'b-+', ...
          1:length(res3(1,:)), out3.itol(2,:),  'b:')
 legend('inres,A','itol,A','inres,B','itol,B')
 xlabel('step')
 ylabel('inner res and tol')
 title('basic relaxation')


 % Figure 5: inner residuals and inner tolerances of the 'bal_relax' run.
 % Row 1 of out4.inres / out4.itol is plotted in black, row 2 in blue; the
 % legend entries follow the plotting order (row 1 -> A, row 2 -> B).
 figure(5)
 semilogy(1:length(res4(1,:)), out4.inres(1,:), 'k--', ...
          1:length(res4(1,:)), out4.itol(1,:),  'k:', ...
          1:length(res4(1,:)), out4.inres(2,:), 'b-+', ...
          1:length(res4(1,:)), out4.itol(2,:),  'b:')
 legend('inres,A','itol,A','inres,B','itol,B')
 xlabel('step')
 ylabel('inner res and tol')
 title('B- relaxation')