%%%%%% stempo_pdfp_wavelet_2d_example.m %%%%%%
%
% Example code for reconstructing the STEMPO phantom data using l^1 wavelet
% regularization of a spatio-temporal 3D (2D + time) object. The functional
% is minimized using the PDFP (primal-dual fixed point) algorithm, and the
% regularization parameter is tuned automatically so that the solution
% reaches a prescribed sparsity level (an a priori estimate of the sparsity
% of the ground truth).
%
%%%%%%
%
% We wish to solve the minimization task
%
% argmin_{f >= 0} 1/2*|| Af - m ||_2^2 + alpha*|| Wf ||_1,
%
% where
%
%     |A_1      |      |f_1|         |m_1|
% A = |   ...   |, f = |...| and m = |...| .
%     |      A_T|      |f_T|         |m_T|
%
% Alpha is a positive regularization parameter and W is the 3D wavelet
% transform.
%
%%%%%%
%
% References:
% [Wavelet domain sparsity and PDFP]
% Purisha, Z., Rimpeläinen, J., Bubba, T., & Siltanen, S. (2017).
% "Controlled wavelet domain sparsity for x-ray tomography."
% Measurement Science and Technology, 29(1), 014002.
%
% [Spatio-temporal regularization (using shearlets and wavelets)]
% Bubba, T. A., Heikkilä, T., Huotari, S., Salmon, Y., & Siltanen, S. (2020).
% "Sparse dynamic tomography: a shearlet-based approach for iodine perfusion
% in plant stems."
% Inverse Problems, 36(9), 094002.
% 
%%%%%%
% 
% Requirements:
% ASTRA Toolbox
% https://www.astra-toolbox.com/
% Recommended v1.9 or higher
%
% Spot - A Linear-Operator Toolbox
% https://www.cs.ubc.ca/labs/scl/spot/
% Recommended v1.2
%
% HelTomo Toolbox
% https://github.com/Diagonalizable/HelTomo
% v2.0
%
% Wavelet Toolbox
% https://mathworks.com/products/wavelet.html
%
%%%%%%
%
% Created 9.9.2022 - Last edited 13.9.2022
% Tommi Heikkilä
% University of Helsinki

% Clear workspace
clear all
close all

%% Load data

% Sequence of 8x45 projection scans with 8 degree angle interval
binning = 8;
load(sprintf('stempo_seq8x45_2d_b%d.mat',binning))

%% Choose parameters

N = 2240 / binning; % Spatial resolution of 2d slices

% We wish to split the 8 full rotations (45 projections each) into 16
% half-turns (time steps). For simplicity we duplicate some projections to
% get 23 projections per time step. Note that allowing more overlap between
% the projections of consecutive time steps gives better time-resolution.
% Using more projections per time step decreases limited angle artifacts
% but increases motion artifacts.
Nangles = 23;
T = 16;
% Each column of the 46x8 array holds one full rotation (two half-turns,
% the angle 176 deg appearing in both); reshaping gives one column per time step
angleArray = [(0:8:176)';(176:8:359)'] + 360*(0:7);
angleArray = reshape(angleArray, [Nangles, T]);

% Visualization
figure;
plot(reshape(1:Nangles*T,[Nangles,T]), mod(angleArray, 360), '.', 'MarkerSize', 10)
title('Projection sampling')
ylabel('Angles (deg)')
xlabel('Time step')
ylim([0 360]);
xlim([0 Nangles*T])
% One tick at each time-step boundary
xticks(0:Nangles:Nangles*T)
% The leading spaces shift each label to the right of its tick so that it
% sits roughly between the boundaries of its time step. arrayfun/sprintf is
% used instead of the undocumented sprintfc.
tickLabels = arrayfun(@(t) sprintf('      %d', t), 1:T, 'UniformOutput', false);
xticklabels([tickLabels, {' '}])
hold on
for boundary = Nangles:Nangles:Nangles*T; xline(boundary,':'); end
hold off

% Reorganize the data in a similar manner to match the projection angles
mInd = [(1:23)'; (23:45)'] + 45*(0:7); % Some projections get duplicated
m = permute(reshape(CtData.sinogram(mInd(:),:),[Nangles,T,N]),[1,3,2]);

%% Forward operator

% Build a block diagonal forward operator
% Note: it would suffice to just create two unique operators but this more
% general approach allows for a wider choice of projection samplings
opCell = cell(1,T);
for t = 1:T
    % Change the projection angles stored in CtData
    CtData.parameters.angles = angleArray(:,t);
    % Create and store the operator in a cell array
    opCell{t} = create_ct_operator_2d_fan_astra(CtData, N, N);
end
% cell{:} gives the content of a cell array as comma separated list
A = blkdiag(opCell{:});

% Normalize data and operator so that ||A|| is (approximately) 1
Anorm = normest(A);
A = A/Anorm;
m = m(:)/Anorm;

%% Wavelet transform

% Wavelet decomposition level and type
level = 3;
wname = 'haar';

% Helper function
vec = @(x) x(:);

% One way to define the forward and adjoint operations
W.fwd = @(x) wavedec3(reshape(x,[N, N, T]),level,wname);
W.adj = @(w) vec(waverec3(w));

%% Run PDFP algorithm
% Set parameters
param.lambda = 0.99; % PDFP requires 0 < lambda <= 1/(largest eigenvalue of W*W'); that bound is 1 for an orthogonal wavelet transform
param.gamma = 1.99; % < 2/L, where L = ||A||^2 is the Lipschitz constant of the gradient of 1/2*|| Ax - m ||_2^2 (L ~ 1 after the normalization above)
param.maxIter = 800;
param.normTol = 1e-4;
param.sparTol = 1e-2;

param.desiredSparsity = 0.3; % Ratio of "large" coefficients desired on the solution
param.psi = 1; % Initial alpha weight
param.omega = 1; % Tuning speed
param.kappa = 1e-7; % Threshold for "large" coefficients

% Iterate
[f, iter, info] = PDFPalgorithm(A, m, W, param);
fFinal = reshape(f, [N, N, T]);

%% Look at outcome

figure;
montage(fFinal,'DisplayRange', [])

%% Helper functions
function [f, iter, info] = PDFPalgorithm(A, m, W, param)
    % PDFPALGORITHM Minimize 1/2*||A*f - m||_2^2 + alpha*||W*f||_1 over
    % f >= 0 with the primal-dual fixed point (PDFP) iteration, while
    % adaptively tuning alpha so that the ratio of "large" wavelet
    % coefficients (|c| >= kappa) approaches param.desiredSparsity.
    %
    % Inputs:
    %   A     - forward operator supporting A*x, A'*y and size(A,2)
    %   m     - measurement vector (compatible with A*f)
    %   W     - struct of function handles: W.fwd maps a vector to a
    %           wavedec3-style structure (fields dec, sizeINI), W.adj maps
    %           such a structure back to a vector
    %   param - struct with fields lambda, gamma, maxIter, normTol,
    %           sparTol, desiredSparsity, psi, omega, kappa and the
    %           optional logical field plotFlag (default true) that
    %           enables the live progress figures 100 and 101
    %
    % Outputs:
    %   f    - final iterate as a column vector
    %   iter - number of iterations performed
    %   info - struct with per-iteration histories: relChange (relative
    %          change of f), dataFit (||A*f - m|| / ||m|| before the
    %          update), l1Norm (l1 norm of W*f after the update), alphas,
    %          sparsity, functionalValues (0.5*dataFit.^2 + alphas.*l1Norm,
    %          a proxy of the functional since dataFit is relative),
    %          converged (stopping criterion met) and timeTot (seconds)

    % Main PDFP algorithm for reaching the solution iteratively
    fprintf('Begin PDFP algorithm... \n')
    tic;
    %% Unload PDFP parameters
    lambda = param.lambda;
    gamma = param.gamma;
    maxIter = param.maxIter;
    normTol = param.normTol;
    sparTol = param.sparTol;

    desiredSparsity = param.desiredSparsity;
    psi = param.psi;
    omega = param.omega;
    kappa = param.kappa;

    % Live plotting is enabled unless param.plotFlag is given as false
    plotFlag = ~isfield(param, 'plotFlag') || param.plotFlag;

    % Initialize
    fLen = size(A, 2);
    f = zeros(fLen,1);
    w = W.fwd(f);     % Dual variable, stored as a wavelet structure
    fSz = w.sizeINI;  % Size of the non-vectorized object (for plotting)

    % Total number of wavelet coefficients: the element counts of every
    % coefficient array (approximation and all detail subbands) summed up
    wLen = sum(cellfun(@numel, w.dec));

    % relChange(k+1) stores the change made at iteration k; relChange(1)
    % stays NaN so the loop condition can be checked before iteration 1
    relChange = nan(1,maxIter+1);
    dataFit = nan(1,maxIter);
    l1Norm = nan(1,maxIter);
    alphas = nan(1,maxIter);
    sparsity = nan(1,maxIter);
    e = 1; % Sentinel controller error, only used to enter the loop

    alpha = psi*1e-5;
    beta = omega*alpha; % Tuning step length
    iter = 1;

    %% Iterate

    while (iter <= maxIter) && ((relChange(iter) > normTol) || (abs(e) > sparTol))
        fOld = f;
        dif = A*f - m;
        BP = A'*dif; % Gradient of the data-fit term at f

        dataFit(iter) = norm(dif) / norm(m);

        % PDFP steps. The gradient step f - gamma*BP is shared by both
        % projections onto the nonnegative orthant, so compute it once.
        fGrad = f - gamma*BP;
        d = max(0, fGrad - lambda*W.adj(w));
        Wd = W.fwd(d);
        w.dec = cellfun(@plus, w.dec, Wd.dec, 'UniformOutput', false);
        w = IdMinusSoftThreshold(w, alpha*gamma/lambda);
        f = max(0, fGrad - lambda*W.adj(w));

        Wf = W.fwd(f); % Wavelet transform the current iterate
        spar = currentSparsity(Wf, wLen, kappa); % Compute current sparsity

        sparsity(iter) = spar;
        l1Norm(iter) = sum(cellfun(@(x) norm(x(:),1), Wf.dec));
        relChange(iter+1) = norm(f - fOld) / norm(fOld);

        eOld = e; % Old difference in sparsity
        e = spar - desiredSparsity; % Update

        % Damp beta when the controller error e changes sign. Skipped on
        % the first iteration: eOld is then the sentinel 1 rather than a
        % real error, |e - eOld| could exceed 1 and flip the sign of beta,
        % which would drive alpha in the wrong direction.
        if (iter > 1) && (sign(e) ~= sign(eOld))
            beta = beta*(1-abs(e-eOld));
        end

        % Update alpha
        alpha = max(0, alpha + beta*e);
        alphas(iter) = alpha;

        if mod(iter,10) == 0
            fprintf('Iteration number %d reached \n', iter);
            fprintf('Relative change: %.5f, sparsity: %.3f \n', relChange(iter+1), spar);
        end
        if plotFlag && (mod(iter,20) == 0)
            figure(100)
            montage(reshape(f,fSz), 'DisplayRange', []);
            % Set the title after montage so it is not replaced by it
            title(sprintf('Reconstruction at iter: %d', iter));

            figure(101)
            % Figure-level title; an axes title would be lost to subplot
            sgtitle(sprintf('Iteration: %d', iter));
            subplot(4,1,1)
            plot(sparsity(1:iter))
            ylabel('Sparsity')
            yline(desiredSparsity, 'r');

            subplot(4,1,2)
            semilogy(alphas(1:iter))
            ylabel('alpha')

            subplot(4,1,3)
            semilogy(relChange(2:iter+1))
            ylabel({'Relative'; 'change'})
            yline(normTol, 'r');

            subplot(4,1,4)
            semilogy(0.5*dataFit(1:iter).^2 + alphas(1:iter).*l1Norm(1:iter))
            ylabel({'Functional'; 'value'})
            xlabel('Iteration')
        end

        iter = iter+1;
    end
    iter = iter - 1;
    timeTot = toc;

    fprintf('Total computational time: %.1f s, approximately %.2f s per iteration \n', timeTot, timeTot / max(iter, 1));

    % Decide the stopping reason from the actual criterion rather than
    % from the iteration count alone (both can hold at iter == maxIter)
    converged = (iter > 0) && ~((relChange(iter+1) > normTol) || (abs(e) > sparTol));
    if converged
        fprintf('Stopping criterion reached after %d iterations! \n', iter);
    else
        fprintf('Maximum iteration count reached! Iteration stopped \n');
    end

    % Cut stored arrays to the number of performed iterations
    relChange = relChange(2:iter+1);
    dataFit = dataFit(1:iter);
    l1Norm = l1Norm(1:iter);
    alphas = alphas(1:iter);
    sparsity = sparsity(1:iter);

    info.relChange = relChange;
    info.dataFit = dataFit;
    info.l1Norm = l1Norm;
    info.alphas = alphas;
    info.sparsity = sparsity;
    info.functionalValues = 0.5*dataFit.^2 + alphas.*l1Norm;
    info.converged = converged;
    info.timeTot = timeTot;
end

function w = IdMinusSoftThreshold(w, alpha)
    %%% Apply (I - S_alpha) to every coefficient array of the wavedec3
    %%% structure w, where S_alpha is the soft thresholding operator.
    %%% For a coefficient c: c - S_alpha(c) = alpha*sign(c) if |c| > alpha
    %%% and c otherwise, i.e. (for alpha >= 0) c is clipped to the
    %%% interval [-alpha, alpha].

    for k = 1:length(w.dec)
        % Clip from above, then from below, directly in the cell contents
        w.dec{k}(w.dec{k} > alpha) = alpha;
        w.dec{k}(w.dec{k} < -alpha) = -alpha;
    end
end

function cs = currentSparsity(w, wLen, kappa)
    %%% Return the ratio of "big" wavelet coefficients (|c| >= kappa) to
    %%% the total number of coefficients wLen. The coefficient arrays of
    %%% the wavedec3 structure w are stored in the cell array w.dec.

    % Count the big coefficients array by array
    bigCount = 0;
    for k = 1:numel(w.dec)
        bigCount = bigCount + nnz(abs(w.dec{k}) >= kappa);
    end

    % Current sparsity is the ratio of "big" coefficients to all coefficients
    cs = bigCount / wLen;
end