% DEMO 4 for a generalized Sylvester equation A*X*C + M*X*B + F*G' = 0,
% where A, M and B, C are symmetric (M is called E in the code below). The
% matrices come from a FEM model of the heat transfer of a machine tool
% [Sauerzapf/Naumann/Vettermann/Saak, https://zenodo.org/records/10017861, 2023]

% This demo is essentially Example 4 from [Kuerschner, Inexact linear solves in the low-rank
% ADI iteration for large Sylvester equations, ETNA, 2024]

clear
clf
% First get the dataset from https://zenodo.org/records/10017861.
% Both files are loaded into structs (no variables appear implicitly in the
% workspace). The pencil (A,E) is taken from SA1; the pencil (B,C) is the
% (A,E) pair stored in SA2.
sa1=load('simplifiedMachineToolFineSA1.mat');
A=sa1.A;
E=sa1.E;
kk=load('simplifiedMachineToolFineSA2.mat');
B=kk.A;
C=kk.E;
n=size(A,1);
m=size(B,1);
% random rank-r right-hand side F*G' (F is n-by-r, G is m-by-r)
r=2;
F=randn(n,r);
G=randn(m,r);
% ||F*G'||_2 = ||RF*RG'||_2 for the economy QR factors, so the n-by-m
% product is never formed
[~,RF]=qr(F,0);
[~,RG]=qr(G,0);
res0=norm(RF*RG');
% rescale both factors to 2-norm sqrt(res0)
F=F/norm(F)*sqrt(res0);
G=G/norm(G)*sqrt(res0);
% The rescaling changes ||F*G'||, so recompute the reference norm from the
% factors actually passed to the solver; all relative residuals below are
% divided by this value.
[~,RF]=qr(F,0);
[~,RG]=qr(G,0);
res0=norm(RF*RG')
%% ADI shift parameters
l0A=30; % number of ADI a-shifts
l0B=30; % number of ADI b-shifts
kpA=10; % number of Ritz values of A
kmA=20; % number of inverse Ritz values of A
kpB=10; % number of Ritz values of B
kmB=20; % number of inverse Ritz values of B

% generate ADI shifts
% Ritz values of A w.r.t. E using (inexact) Arnoldi: the inner solves use
% MINRES, two-sided preconditioned with an incomplete Cholesky factor of E.
arnopts.linsolver='minres';
arnopts.inner_tol=1e-9;
arnopts.maxit_inner=500;
arnopts.Mtype='twosided';
precAset.droptol=0.1;
precAset.michol='off';
arnopts.M1=ichol(E,precAset);
arnopts.M2=arnopts.M1';
[Hp,~] = arn_pl(A,E,kpA,F*ones(r,1),arnopts);
rwp = eig(Hp(1:kpA,1:kpA));
% inverse Ritz values of A (iterative solves with A); the preconditioner is
% ichol(-A), which presumes -A is SPD, i.e. A negative definite -- TODO confirm
arnopts.M1=ichol(-A,precAset);
arnopts.M2=arnopts.M1';
[Hm,~] = arn_inv(A,E,kmA,F*ones(r,1),arnopts);
rwm = 1./eig(Hm(1:kmA,1:kmA));
% Keep (at most) l0A of the sorted candidates. The min() guard avoids an
% index error when l0A exceeds the kpA+kmA available values.
sA=sort([rwp;rwm]);
sA=sA(1:min(l0A,numel(sA)));
% Ritz values of B w.r.t. C (using iterative solves). This reuses the A-side
% Arnoldi solver settings (MINRES, two-sided preconditioning, same inner
% tolerance and incomplete Cholesky options) and only swaps the
% preconditioner.
Barnopts=arnopts;
Barnopts.M1=ichol(C,precAset);
Barnopts.M2=Barnopts.M1';
[Hp,~] = arn_pl(B,C,kpB,G*ones(r,1),Barnopts);
rwp = eig(Hp(1:kpB,1:kpB));
% inverse Ritz values of B (using iterative solves with B, preconditioned by ichol(-B))
Barnopts.M1=ichol(-B,precAset);
Barnopts.M2=Barnopts.M1';
[Hm,~] = arn_inv(B,C,kmB,G*ones(r,1),Barnopts);
rwm = 1./eig(Hm(1:kmB,1:kmB));
% keep (at most) l0B of the sorted candidates; guard against l0B > kpB+kmB
sB=sort([rwp;rwm]);
sB=sB(1:min(l0B,numel(sB)));
% heuristic selection of the final shift pairs
[sA,sB]=pseudominmax(sA,sB);
%% global ADI settings, shared by all runs below
% iteration limit and tolerance, passed to lr_adi_sylv as maxit and tol
maxit=100;
tol=1e-8;

% inner (linear solve) tolerance settings
opts.inner_tol=tol/20;
opts.itolmax=1e-1;
opts.itolmin=opts.inner_tol;
opts.maxit_inner=500;
opts.intolstrat='fixed';

% strategy switches, later toggled by the relaxed runs
opts.rgap_update=0;
opts.backlook=0;
opts.savg=1;
opts.bal_q=1;
opts.debug=0;

% 'exact' linear solves on both sides, so no preconditioners are given
opts.linsolverA='exact';
opts.linsolverB='exact';
opts.M1A=[];
opts.M2A=[];
opts.M1B=[];
opts.M2B=[];
opts.MtypeA=[];
opts.MtypeB=[];

% flag both coefficient pencils as symmetric
opts.symA=1;
opts.symB=1;

 %% direct linear solves (reference run)
tic
[Ze,De,Ye,res,niter,timings]=lr_adi_sylv(A,B,E,C,F,G,sA,sB,maxit,tol,opts);
t1=toc
% final residual norm, relative to res0
rex = syl_r_norm(A,B,F,G,Ze,De,Ye,m,E,C)/res0
%% fixed inner tols: switch both sides to MINRES
% tighter inner tolerance than in the direct run
opts.inner_tol=tol/100;
opts.itolmax=1e-1;
opts.itolmin=opts.inner_tol;
opts.maxit_inner=500;
opts.rgap_update=0;
opts.backlook=0;
opts.debug=0;
opts.bal_q=1;
opts.linsolverA='minres';
opts.linsolverB='minres';
opts.MtypeA=[];
opts.MtypeB=[];

% initial incomplete Cholesky preconditioners, built from the first shifts
precAset.droptol=0.1;
opts.M1A=ichol(-A-sB(1)*E,precAset);
opts.M2A=opts.M1A';
opts.M1B=ichol(-B'-sA(1)*C',precAset);
opts.M2B=opts.M1B';

% settings for preconditioner updating: ichol with drop tolerance 0.1 on both sides
ilopt.droptol=0.1;
opts.updprec.poptsA=ilopt;
opts.updprec.precfunA=@(x,y) ichol(x,y);
opts.updprec.poptsB=ilopt;
opts.updprec.precfunB=@(x,y) ichol(x,y);
 %% run with fixed inner tolerances
opts.intolstrat='fixed';
tic
[Z2,D2,Y2,res2,niter2,timings2,out2]=lr_adi_sylv(A,B,E,C,F,G,sA,sB,maxit,tol,opts);
t2=toc
% final residual norm, relative to res0
res2_fix=syl_r_norm(A,B,F,G,Z2,D2,Y2,m,E,C,[],[])/res0
 %% relaxed inner tols ('relax' strategy)
% these switches stay active for the following B-mode run as well
opts.intolstrat='relax';
opts.rgap_update=1;
opts.backlook=1;
opts.savg=0.1;
tic
[Z3,D3,Y3,res3,niter3,timings3,out3]=lr_adi_sylv(A,B,E,C,F,G,sA,sB,maxit,tol,opts);
t3=toc
% final residual norm, relative to res0
res3_ex=syl_r_norm(A,B,F,G,Z3,D3,Y3,m,E,C,[],[])/res0
%% relaxed inner tols, B-mode ('bal_relax' strategy with bal_q=0)
opts.intolstrat='bal_relax';
opts.rgap_update=1;
opts.backlook=1;
opts.bal_q=0;
tic
[Z4,D4,Y4,res4,niter4,timings4,out4]=lr_adi_sylv(A,B,E,C,F,G,sA,sB,maxit,tol,opts);
t4=toc
% final residual norm, relative to res0
res4_ex=syl_r_norm(A,B,F,G,Z4,D4,Y4,m,E,C,[],[])/res0
%% summary of inner iteration counts
% Row 1 of out*.nrit counts the A-side inner iterations, row 2 the B-side.
% For the relaxed runs, the last value is the total as a percentage of the
% fixed-tolerance total.
fprintf('fixed-tol: \t total: %d, \t in-it A: %d,\t in-it B: %d\t\n',...
    sum(out2.nrit(:)),sum(out2.nrit(1,:)),sum(out2.nrit(2,:)))
fprintf('relax-tol: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \n',...
    sum(out3.nrit(:)),sum(out3.nrit(1,:)),sum(out3.nrit(2,:)),...
    sum(out3.nrit(:))/sum(out2.nrit(:))*100)
fprintf('relax-tol-B: \t total: %d, \t in-it A: %d,\t in-it B: %d, reldiff2fix: %4.2f%% \t\n',...
    sum(out4.nrit(:)),sum(out4.nrit(1,:)),sum(out4.nrit(2,:)),...
    sum(out4.nrit(:))/sum(out2.nrit(:))*100)
%% Sylvester residual histories (first row of each res*) of all four runs
figure(1),
semilogy(1:size(res,2),res(1,:),'k-.',...
    1:size(res2,2),res2(1,:),'r-+',...
    1:size(res3,2),res3(1,:),'b-+',...
    1:size(res4,2),res4(1,:),'m-.')
xlabel('step')
ylabel('sylv res')
legend('direct','fixed','relax','relax B')

% Sylvester residual of the 'relax' run together with the A-/B-side inner
% residuals (out*.inres rows 1 and 2) of both relaxed runs
figure(2),
semilogy(1:size(res3,2),res3(1,:),'b-+',...
    1:size(res3,2),out3.inres(1,:),'b:+',...
    1:size(res3,2),out3.inres(2,:),'b:v',...
    1:size(res4,2),out4.inres(1,:),'r:<',...
    1:size(res4,2),out4.inres(2,:),'r:.')
xlabel('step')
ylabel('sylv & inner res')
legend('||R||','inresA\_relax','inresB\_relax','inresA\_relax-B','inresB\_relax-B')

figure(3),
% Cumulative number of inner iterations (both sides) per outer step. Each
% curve is truncated to the niter* outer steps reported by the solver. This
% gives the same plot when nrit has exactly niter* columns and tolerates
% extra (e.g. preallocated) columns.
cum2=cumsum(sum(out2.nrit,1));
cum3=cumsum(sum(out3.nrit,1));
cum4=cumsum(sum(out4.nrit,1));
plot(1:niter2,cum2(1:niter2),'k-.',...
    1:niter3,cum3(1:niter3),'b-+',...
    1:niter4,cum4(1:niter4),'m-s')
xlabel('step')
ylabel('sum(inner iters)')
legend('fix','relax','relax-B')
title('cumulative sum of inner steps')