%%%%%% stempo_LplusS_2d_example.m %%%%%%
%
% Example code for reconstructing the STEMPO phantom data using Low-rank +
% sparse decomposition (L+S) scheme. The aim is to separate the slowly
% changing low-rank background L and the more rapidly changing dynamic
% component S which should be made of only few nonzero terms, i.e. is
% sparse.
%
%%%%%%
%
% We wish to solve the minimization task:
%
% argmin_{A(L+S) = m} || L ||_* + mu*|| WS ||_1,                (1)
%
% where
% A is the forward operator, m is the sinogram and the reconstruction is
% split into the low-rank and sparse terms L and S respectively. L is
% constrained using the nuclear norm || ||_* which is the sum of singular
% values of the matrix (also denoted L) whose columns consist of the
% different time steps: L = [L_1, ..., L_T].
%
% To further increase the sparsity of the dynamic component, we consider
% the wavelet transform WS_t of every time step of the dynamic component. 
%
% mu is a positive regularization parameter and W is the 2D wavelet
% transform.
%
% In practice, however, we solve the regularized version of eq.(1) instead:
%
% argmin_{L, S} 1/2*|| A(L+S) - m ||_2^2 + mu_L || L ||_* + mu_S || WS ||_1,
%
% where we have separate regularization parameters mu_L and mu_S.
% 
%%%%%%
%
% References:
% [Low-rank + sparse decomposition of dynamic MRI (with the associated 
% forward operator and Fourier domain sparsity instead of wavelets and CT)]
% Otazo, R., Candes, E., & Sodickson, D. K. (2015)
% "Low‐rank plus sparse matrix decomposition for accelerated dynamic MRI
% with separation of background and dynamic components."
% Magnetic resonance in medicine, 73(3), 1125-1136.
% doi: 10.1002/mrm.25240.
% Original code: https://cai2r.net/resources/ls-reconstruction-matlab-code/
% 
% [RPCA applied to 4D( 3D + time) tomography]
% Gao, H., Cai, J. F., Shen, Z., & Zhao, H. (2011).
% "Robust principal component analysis-based four-dimensional computed
% tomography."
% Physics in Medicine & Biology, 56(11), 3181.
% 
%%%%%%
% 
% Requirements:
% ASTRA Toolbox
% https://www.astra-toolbox.com/
% Recommended v1.9 or higher
%
% Spot - A Linear-Operator Toolbox
% https://www.cs.ubc.ca/labs/scl/spot/
% Recommended v1.2
%
% HelTomo Toolbox
% https://github.com/Diagonalizable/HelTomo
% v2.0
%
% Wavelet Toolbox
% https://mathworks.com/products/wavelet.html
%
%%%%%%
%
% Created 14.9.2022 - Last edited 23.9.2022
% Tommi Heikkilä
% University of Helsinki

% Start from an empty workspace and with no open figure windows
clear('all')
close('all')

%% Load data

% The data file holds a sequence of 8x45 projection scans with an 8 degree
% angle interval; the binning factor selects which file is loaded.
binning = 16;
dataFile = sprintf('stempo_seq8x45_2d_b%d.mat',binning);
load(dataFile)

%% Choose parameters

N = 2240 / binning; % Spatial resolution of 2d slices

% We wish to split the 8 full rotations (45 projections each) into as many
% time steps as possible (since that greatly limits the SVD and number of
% singular values). We can take 24 projections per time step and advance
% only by 4 projections for the next time step such that two consecutive
% time steps share 20 projections:
% p   :  1   2   3   4   5  ... 24   25   26   27   28   29 ... 360
% t=1 :  X   X   X   X   X  ...  X
% t=2 :                  X  ...  X    X    X    X    X
% etc.
Nangles = 24;   % Number of projections used for one time step
angShift = 4;   % Number of projections the window advances per time step
angleStep = 8;  % Angle interval between consecutive projections (degrees)

% Number of complete time steps. The floor drops a trailing incomplete
% window, so T stays an integer even if Nangles or angShift are changed to
% values that do not divide the number of projections evenly.
T = floor((CtData.parameters.numberImages - Nangles + angShift) / angShift);
if T < 1
    error('Cannot form a time step: %d projections requested but only %d are available.', ...
        Nangles, CtData.parameters.numberImages);
end

% Projection angles are stored in columns (one column per time step)
angleArray = angleStep*(0:1:Nangles-1)' + angleStep*angShift*(0:1:T-1);

% Reorganize the data in a similar manner to match the projection angles:
% column t of mInd lists the sinogram rows belonging to time step t
mInd = (1:Nangles)' + angShift*(0:1:T-1);
m = permute(reshape(CtData.sinogram(mInd(:),:),[Nangles,T,N]),[1,3,2]);
% Permuting the array guarantees the time steps stay in order once m is
% dropped into a single column vector

% Visualize data
figure(1)
for t = 1:T
    imagesc(m(:,:,t)')
    title(sprintf('Data m_{%d}', t))
    colormap gray
    axis off
    drawnow
    pause(0.02)
end

%% Forward operator

% One operator per time step is created with
% create_ct_operator_2d_fan_astra and the operators are then combined into
% a single block diagonal operator acting on all time steps at once.
opCell = cell(1,T);
for t = 1:numel(opCell)
    % The operator is built from the angles stored in CtData, so overwrite
    % them with the angles of the current time step first
    CtData.parameters.angles = angleArray(:,t);
    opCell{t} = create_ct_operator_2d_fan_astra(CtData, N, N);
end
% Expanding the cell array passes every operator as a separate argument
A = blkdiag(opCell{:});

% Scale both the stacked data vector and the operator by the estimated
% operator norm
Anorm = normest(A);
m = m(:)/Anorm;
A = A/Anorm;

%% Wavelet transform

% Wavelet type and decomposition level
param.wname = 'db2';
param.wlevel = 3;
% Decompose an empty N x N image once to obtain the bookkeeping matrix
% which is needed later for the reconstruction from the coefficients
[~, wSz] = wavedec2(zeros(N),param.wlevel,param.wname);
param.wSz = wSz;

%% Run L+S algorithm
% Regularization parameters of the low-rank and sparse components
param.muL = 0.01;
param.muS = 0.001;
% Stopping rules: iteration limit and relative change tolerance
param.maxIter = 500;
param.tol = 5e-4;
% Size of the reconstruction: N x N pixels, T time steps
param.sz = [N,N,T];

% Iterate
[L, S, iter, info] = LplusSalgorithm(A, m, param);
% Complete reconstruction is the sum of both components
LplusS = L + S;

%% Look at outcome

% Show every time step of the combined reconstruction in one montage
figure(2);
montage(LplusS,'DisplayRange', [])

% Animate the reconstruction: L, S and L+S side by side for each time step
figure('Position',[200 400 1100 500])
panelLabels = 'L                    S                  L+S';
for t = 1:T
    imagesc(cat(2, L(:,:,t), S(:,:,t), LplusS(:,:,t)))
    colormap gray
    axis image
    axis off
    title({sprintf('Time step: %d', t); panelLabels})
    drawnow
    pause(0.02)
end

%% Helper functions
function [L, S, iter, info] = LplusSalgorithm(A, m, param)
    % Main L + S algorithm for reaching the solution iteratively.
    %
    % Every iteration consists of three steps:
    %   1) low-rank update: soft threshold the singular values of M - S,
    %   2) sparse update: soft threshold the wavelet coefficients of
    %      M - L_{k-1},
    %   3) data consistency: subtract the backprojected residual
    %      A'*(A*(L+S) - m) from L + S to obtain the next M.
    % The iteration stops when the relative change of M drops below the
    % tolerance or when the maximum number of iterations is reached.
    %
    % Inputs:
    %   A      forward operator, must support A*x and A'*y
    %   m      measurement data (used as a column vector)
    %   param  struct with the fields
    %            wlevel   wavelet decomposition level
    %            wname    wavelet name
    %            wSz      bookkeeping matrix given to waverec2
    %            muL      singular value threshold, relative to the
    %                     largest singular value
    %            muS      threshold of the wavelet coefficients
    %            maxIter  maximum number of iterations
    %            tol      tolerance for the relative change of M
    %            sz       size of the reconstruction, [N,N,T]
    %            plotFlag (optional, default true) show intermediate
    %                     reconstructions every 20 iterations
    %
    % Outputs:
    %   L, S   low-rank and sparse components, arrays of size sz
    %   iter   number of iterations performed
    %   info   struct with per-iteration values (relChange, dataFit,
    %          nuclear, l1Norm, functionalValues) and the total time timeTot
    fprintf('Begin L + S algorithm... \n')
    tic;
    %% Unload L+S parameters
    level = param.wlevel;
    wname = param.wname;
    wSz = param.wSz;
    muL = param.muL;
    muS = param.muS;
    maxIter = param.maxIter;
    tol = param.tol;
    sz = param.sz;
    N = sz(1);
    T = sz(3);

    % Plotting is on by default but can be switched off by the caller
    if isfield(param,'plotFlag')
        plotFlag = param.plotFlag;
    else
        plotFlag = true;
    end
    
    m = m(:); % Drop to column vector

    % Backproject
    M = reshape(A'*m,[N*N,T]);
    Lpre = M;

    % Initial values; L is returned as is if no iterations are performed
    L = Lpre;
    S = zeros(N*N,T);
    iter = 0;
    converged = false;
    
    % Store useful values during iteration
    relChange = nan(1,maxIter);
    dataFit = nan(1,maxIter);
    nuclear = nan(1,maxIter);
    l1Norm = nan(1,maxIter);

    %% Iterate
    while iter < maxIter
        iter = iter + 1;
        
        % low-rank update: the threshold is relative to the largest
        % singular value
        M0 = M;
        [Ut,St,Vt] = svd(M-S,0);
        St = diag(SoftThresh(diag(St),St(1)*muL));
        L = Ut*St*Vt';

        % soft threshold M - Lpre on wavelet domain
        WS = SoftThresh(Wfwd(reshape(M - Lpre,sz), level, wname), muS);
        S = reshape(Wadj(WS, wSz, wname),[N*N,T]);

        % data consistency
        LplusS = L + S;
        dif = A*LplusS(:) - m;
        M = LplusS - reshape(A'*dif,[N*N,T]);

        % L_{k-1} for the next iteration
        Lpre = L;
        
        relChange(iter) = norm(M(:)-M0(:))/norm(M0(:));
        dataFit(iter) = norm(dif);
        nuclear(iter) = sum(diag(St));
        l1Norm(iter) = norm(WS(:),1);
        
        if mod(iter,10) == 0
            fprintf('Iteration number %d reached \n', iter);
            fprintf('Relative change: %.5f \n', relChange(iter));
            % NOTE(review): the squared residual is used here without the
            % factor 1/2 which appears in the functional stated in the
            % file header.
            fprintf('Cost function: %.2f \n', dataFit(iter)^2 + muL*nuclear(iter) + muS*l1Norm(iter));
        end
        
        if (mod(iter,20) == 0) && plotFlag
            figure(100)
            montage([reshape(L,sz), reshape(S,sz)], 'DisplayRange', []);
            % The title is set after montage since drawing the images
            % replaces the content of the axes
            title(sprintf('Reconstruction at iter: %d', iter));
            drawnow
        end
        
        % Stop once relative change is below tolerance
        if relChange(iter) < tol
            fprintf('Stopping criterion reached after %d iterations \n',iter);
            converged = true;
            break
        end
    end
    L = reshape(L,sz);
    S = reshape(S,sz);
    
    timeTot = toc;
    
    % max(iter,1) avoids division by zero if no iterations were performed
    fprintf('Total computational time: %.1f s, approximately %.2f s per iteration \n', timeTot, timeTot / max(iter,1));
    
    % Report the iteration limit only if the tolerance was never reached
    if ~converged
        fprintf('Maximum iteration count reached! Iteration stopped \n');
    end

    % Cut stored arrays to correct length
    relChange = relChange(1:iter);
    dataFit = dataFit(1:iter);
    nuclear = nuclear(1:iter);
    l1Norm = l1Norm(1:iter);

    info.relChange = relChange;
    info.dataFit = dataFit;
    info.nuclear = nuclear;
    info.l1Norm = l1Norm;
    info.functionalValues = dataFit.^2 + muL*nuclear + muS*l1Norm;
    info.timeTot = timeTot;
end

function w = Wfwd(x, level, wname)
    %%% Apply the 2D wavelet decomposition wavedec2 to every layer x(:,:,t)
    %%% of x. For a single layer the coefficient vector is returned as is;
    %%% otherwise row t of w holds the coefficients of layer t.
    
    nLayers = size(x,3);
    coeffs = wavedec2(x(:,:,1),level,wname);
    
    if nLayers ~= 1
        % The first layer determines the number of coefficients per layer
        w = zeros(nLayers, length(coeffs));
        w(1,:) = coeffs;
        for k = 2:nLayers
            w(k,:) = wavedec2(x(:,:,k),level,wname);
        end
    else
        % Only one layer: no stacking needed
        w = coeffs;
    end
end

function X = Wadj(w, wSz, wname)
    %%% Apply the 2D wavelet reconstruction waverec2 to every row of w.
    %%% For a single row the reconstructed image is returned as is;
    %%% otherwise layer X(:,:,t) is reconstructed from row t of w.
    
    nLayers = size(w,1);
    layer = waverec2(w(1,:),wSz,wname);
    
    if nLayers ~= 1
        % The first reconstruction determines the size of every layer
        X = zeros([size(layer), nLayers]);
        X(:,:,1) = layer;
        for k = 2:nLayers
            X(:,:,k) = waverec2(w(k,:),wSz,wname);
        end
    else
        % Only one row of coefficients: no stacking needed
        X = layer;
    end
end

function y = SoftThresh(x,mu)
    %%% Soft thresholding function
    %%% Shrinks every element of x towards zero by mu and sets elements
    %%% with magnitude at most mu to zero:
    %%%     y = sign(x) .* max(|x| - mu, 0).
    %%% x is the array to threshold and mu the threshold (expected to be
    %%% nonnegative; a scalar or an array compatible in size with x).
    %%% A warning is given for complex x since real input is expected.
    if ~isreal(x)
        warning('We expect to only threshold real variables!')
    end
    % Unlike masking with (x > mu) and (x < -mu), this form does not
    % multiply infinite values by zero, so +-Inf stay +-Inf and do not
    % turn into NaN
    y = sign(x).*max(abs(x) - mu, 0);
end