function [modes, reduced_state_matrix, offline, relErrors, sPODapproximation, PODModes] = shifted_pod_offline(snapshot_matrix, discretization, offline, beta, shift)
% shifted_pod_offline - determines the shifted POD modes for a given 
% snapshot_matrix and equation, based on the values in the structs 
% discretization and offline.
%
% inputs:
% - snapshot_matrix: cell array containing the snapshot matrix of the FOM
% solution for each parameter value; the rows of each matrix stack the
% physical variables, each one discretized with
% discretization(1).space.n_steps spatial points
% - discretization: struct array containing discretization parameters (one
% entry per parameter value)
% - offline: struct containing parameters for the mode decomposition
% - beta: vector containing the values of the Arrhenius coefficient
% - shift: cell array containing, for each parameter value, a matrix of 
% paths with one row per snapshot (n_steps+1 rows) and one column per 
% traveling wave; optional argument - if it is omitted or empty, the 
% shifts are determined from the temperature snapshots (the first 
% n_steps rows of each snapshot matrix) via determineTemperatureShift
%
% outputs:
% - modes: matrix containing the determined modes
% - reduced_state_matrix: cell array containing the coefficients and the 
% paths for each parameter value
% - offline: struct containing parameters for the mode decomposition
% - relErrors: vector containing the relative approximation error for each
% parameter value
% - sPODapproximation: cell array containing the shifted POD approximation 
% for each parameter value
% - PODModes: cell array containing the POD modes used before and after the
% switch (areas (i) and (ii)); only determined if offline.method equals 
% @subspaceDecompositionSPODswitch, otherwise an empty cell array is
% returned
%
% dependencies:
% - determineTemperatureShift (in LIB/SHIFT)
% - relative_L2_norm (in LIB/ERROR)
%
%--------------------------------------------------------------------------
% version 1.0 (July 27, 2021)
% authors:
% - Felix Black (TU Berlin), black@math.tu-berlin.de
% - Philipp Schulze (TU Berlin), pschulze@math.tu-berlin.de
% - Benjamin Unger (U Stuttgart), benjamin.unger@simtech.uni-stuttgart.de
%--------------------------------------------------------------------------


%% initialization

n_param_samples = length(snapshot_matrix); % number of parameter values

n_steps_space = size(snapshot_matrix{1}, 1); % number of steps in space times number of physical variables
dx = discretization(1).space.dx ; % spatial mesh width
nStepsPerVar = discretization(1).space.n_steps; % number of steps in space
% the snapshot rows must split evenly into blocks of nStepsPerVar rows,
% one block per physical variable
if(mod(n_steps_space, nStepsPerVar) ~= 0)
    error('The number of rows of the snapshot matrices (%d) is not a multiple of the number of spatial steps per variable (%d).', n_steps_space, nStepsPerVar)
end
n_vars = n_steps_space / nStepsPerVar ; % number of physical variables
n_modes = offline.n_modes; % number of transformed modes

% number of steps in time for each parameter value
n_timesteps = NaN(1,n_param_samples); 
for i=1:n_param_samples
    n_timesteps(i) = discretization(i).time.n_steps ;
end

totalNTimeSteps = sum(n_timesteps+1) ; % total number of snapshots

% column offsets of the individual parameter samples inside the combined
% (concatenated) snapshot, shift and coefficient matrices: the columns
% belonging to parameter sample i are timeOffsets(i)+1:timeOffsets(i+1)
timeOffsets = cumsum([0 n_timesteps+1]) ;

%% Construct shifts

fprintf('Starting shifted POD offline phase corresponding to the parameters\n');
fprintf('\t beta = %f\n',beta);
fprintf('\t ...');

if(nargin<5 || isempty(shift)) % the shift is not provided a-priori
    shift = cell(1, n_param_samples) ;
    for i=1:n_param_samples
        % determine shift based on the temperature snapshots for each
        % parameter value
        shift{i} = determineTemperatureShift(snapshot_matrix{i}(1:nStepsPerVar,:),dx,offline.shiftPostProcessing.offset);
    end
end

% concatenate all shifts for the different parameter samples (one row per
% snapshot, one column per traveling wave)
combinedShift = NaN(totalNTimeSteps,size(shift{1},2));
for i=1:n_param_samples
    combinedShift(timeOffsets(i)+1:timeOffsets(i+1),:) = shift{i};
end

%% mode calculation

subspace_decomposition = offline.method; % method used for the mode decomposition

% initialization of shifted POD approximation for each parameter value
sPODapproximation = cell(1,n_param_samples);

% n_modes has one column per traveling wave. If it has a single row, the
% same numbers of modes are used for every physical variable; otherwise it
% must have one row per physical variable.
if(size(n_modes,1)~=n_vars)
    if(size(n_modes,1)==1)
        n_modes = repmat(n_modes, n_vars, 1) ;
    else
        error('The number of rows of n_modes should be either one or it should coincide with the number of physical variables.')
    end
end

% total number of modes (counting modes for all traveling waves and for all
% physical variables)
totalNModes = sum(sum(n_modes)) ;

% initialization
sPODSystemTmp = cell(n_vars,1) ; % contains the results for the mode decomposition of each physical variable
qsTildeTemp = cell(n_vars, 1) ; % contains the shifted POD approximation for each physical variable
qsTilde = NaN(n_steps_space,totalNTimeSteps) ; % same as qsTildeTemp, but saved as one fat matrix instead of as a cell array of matrices

options = struct(); % struct containing options for the mode decomposition

if(strcmp(func2str(offline.method),'subspaceDecompositionSPODswitch'))
    nPODModesBeforeSwitch = offline.nPODModesBeforeSwitch ; % number of POD modes in area (i)
    nPODModesAfterSwitch = offline.nPODModesAfterSwitch ; % number of POD modes in area (ii)
    % time index between areas (i) and (ii); computed from the time grid
    % of the first parameter sample and used for all samples
    offline.switchTimeIndex = ceil(offline.switchTime/discretization(1).time.length*n_timesteps(1)) ;
    switchTimeIndex = offline.switchTimeIndex ;
    options.switchTimeIndex = switchTimeIndex ;
    options.extrapolationMethod = offline.extrapolationMethod ; % method used for extrapolating the coefficients in area (ii-a)
    options.extrapolationPolOrder = offline.extrapolationPolOrder ; % degree of the extrapolation polynomial
    useSwitchedROM = true ; % each snapshot matrix is split into the two areas (i) and (ii)
    % columns (in the combined matrices) corresponding to area (ii); the
    % area (ii) columns of sample i are stored at positions
    % timeOffsets2(i)+1:timeOffsets2(i+1) of timeIndicesOf2ndInterval
    timeIndicesOf2ndInterval = NaN(1, totalNTimeSteps-n_param_samples*switchTimeIndex) ;
    timeOffsets2 = cumsum([0 n_timesteps+1-switchTimeIndex]) ;
    for i=1:n_param_samples
        timeIndicesOf2ndInterval(timeOffsets2(i)+1:timeOffsets2(i+1)) = timeOffsets(i)+switchTimeIndex+1:timeOffsets(i+1) ;
    end
else
    useSwitchedROM = false ; % snapshot matrix is not split
end

% number of steps in time for each parameter value (needed by the mode
% decomposition; identical for all physical variables)
offline.nTimestepsPerParam = n_timesteps;

% Perform the mode decomposition for each physical variable separately
for i=1:n_vars
    varRows = (i-1)*nStepsPerVar+1:i*nStepsPerVar ; % rows of the current physical variable
    offline.n_modes = n_modes(i,:) ; % number of modes for current physical variable
    if(useSwitchedROM) % if offline.method equals 'subspaceDecompositionSPODswitch'
        options.nPODModesBeforeSwitch = nPODModesBeforeSwitch(i) ; % number of POD modes in area (i)
        options.nPODModesAfterSwitch = nPODModesAfterSwitch(i) ; % number of POD modes in area (ii)
        options.shiftedPODstart = offline.shiftedPODstart(i) ; % time point between areas (ii-a) and (ii-b)
    end
    % assemble combined snapshot matrix for the parameters
    snapshotMatrixVar = NaN(nStepsPerVar,totalNTimeSteps);
    for j=1:n_param_samples
        snapshotMatrixVar(:,timeOffsets(j)+1:timeOffsets(j+1)) = snapshot_matrix{j}(varRows,:);
    end

    % perform the mode decomposition
    [sPODSystemTmp{i}, qsTildeTemp{i}] = subspace_decomposition(discretization(1), offline, snapshotMatrixVar, combinedShift, options) ;
    % transfer values from cell array qsTildeTemp to matrix qsTilde
    qsTilde(varRows,:) = qsTildeTemp{i} ;
end

% compute relative Errors
relErrors = NaN(1,n_param_samples); % total relative error per parameter value
relErrorsPerVar = NaN(n_vars,n_param_samples); % relative error per parameter value and physical variable
for j=1:n_param_samples % loop over parameter samples
    sPODapproximation{j} = qsTilde(:,timeOffsets(j)+1:timeOffsets(j+1)); % current shifted POD approximation
    % compute relative L2 error for whole snapshot matrix
    relErrors(j) = relative_L2_norm(snapshot_matrix{j}, sPODapproximation{j}-snapshot_matrix{j});
    % compute relative L2 errors for each physical variable separately
    for i=1:n_vars
        varRows = (i-1)*nStepsPerVar+1:i*nStepsPerVar ;
        sPODapproximationVar = sPODapproximation{j}(varRows,:);
        snapshotMatrixVar = snapshot_matrix{j}(varRows,:);
        relErrorsPerVar(i,j) = relative_L2_norm(snapshotMatrixVar,sPODapproximationVar-snapshotMatrixVar);
    end
end

nFrames = sPODSystemTmp{1}.n_subspaces ; % number of traveling waves
% get the modes for each frame from sPODSystemTmp
modesPerFrame = cell(nFrames, 1) ;
for i=1:nFrames
    modesPerFrame{i} = zeros(n_steps_space, sum(n_modes(:,i))) ;
    for j=1:n_vars
        modesPerFrame{i}((j-1)*nStepsPerVar+1:j*nStepsPerVar,sum(n_modes(1:j-1,i))+1:sum(n_modes(1:j,i))) = sPODSystemTmp{j}.U{i} ;
    end
end

% collect the modes of all frames in one mode matrix
modes = zeros(n_steps_space, totalNModes);
frameOffsets = [0, cumsum(sum(n_modes,1))];
for i = 1:nFrames
    modes(:,frameOffsets(i)+1:frameOffsets(i+1)) = modesPerFrame{i};
end

% modify offline.n_modes to collect mode numbers of all physical variables
offline.n_modes = sum(n_modes,1) ; 

if(useSwitchedROM) % if offline.method equals 'subspaceDecompositionSPODswitch'
    % collect POD modes for areas (i) and (ii)
    PODModes = cell(2, 1) ;
    PODModes{1} = zeros(n_steps_space, sum(nPODModesBeforeSwitch)) ;
    PODModes{2} = zeros(n_steps_space, sum(nPODModesAfterSwitch)) ;
    PODHelper1 = cumsum([0 nPODModesBeforeSwitch]) ;
    PODHelper2 = cumsum([0 nPODModesAfterSwitch]) ;
    for j=1:n_vars 
        PODModes{1}((j-1)*nStepsPerVar+1:j*nStepsPerVar,PODHelper1(j)+1:PODHelper1(j+1)) = sPODSystemTmp{j}.UPODbeforeSwitch ;
        PODModes{2}((j-1)*nStepsPerVar+1:j*nStepsPerVar,PODHelper2(j)+1:PODHelper2(j+1)) = sPODSystemTmp{j}.UPOD ;
    end 
    % collect coefficients and paths for each parameter value and each area
    redStateMatrixTmp = cell(2, 1) ;
    redStateMatrixTmp{1} = NaN(sum(nPODModesBeforeSwitch), n_param_samples*switchTimeIndex) ;
    redStateMatrixTmp{2} = NaN(totalNModes+sum(nPODModesAfterSwitch)+nFrames, totalNTimeSteps-n_param_samples*switchTimeIndex) ;
    for i=1:nFrames
        for j=1:n_vars
            offset = sum(sum(n_modes(:,1:i-1)))+sum(n_modes(1:j-1,i)) ;
            alphaIndices = sum(n_modes(j,1:i-1))+1:sum(n_modes(j,1:i)) ;
            % coefficients in area (ii) for the transformed modes for each 
            % frame and physical variable
            redStateMatrixTmp{2}(offset+(1:n_modes(j,i)),:) = sPODSystemTmp{j}.alpha(alphaIndices,timeIndicesOf2ndInterval) ;
        end
    end
    % paths in area (ii)
    redStateMatrixTmp{2}(end-nFrames+1:end,:) = combinedShift(timeIndicesOf2ndInterval,:)' ;
    for j=1:n_vars
        % coefficients in areas (i) and (ii) for the POD modes
        redStateMatrixTmp{1}(PODHelper1(j)+1:PODHelper1(j+1),:) = sPODSystemTmp{j}.PODcoefficientsBeforeSwitch ;
        redStateMatrixTmp{2}(totalNModes+(PODHelper2(j)+1:PODHelper2(j+1)),:) = sPODSystemTmp{j}.PODcoefficients ;
    end
    % transfer entries from cell array redStateMatrixTmp to nested cell
    % array reduced_state_matrix
    reduced_state_matrix = cell(n_param_samples, 1) ;
    for j=1:n_param_samples
        reduced_state_matrix{j} = cell(2,1) ;
        reduced_state_matrix{j}{1} = redStateMatrixTmp{1}(:,(j-1)*switchTimeIndex+(1:switchTimeIndex)) ;
        reduced_state_matrix{j}{2} = redStateMatrixTmp{2}(:,timeOffsets2(j)+1:timeOffsets2(j+1)) ;
    end
else
    % no switch: there are no separate POD modes; return an empty cell so
    % that callers requesting all six outputs do not hit an
    % "output argument not assigned" error
    PODModes = {} ;
    % collect coefficients and paths for each parameter value
    redStateMatrixTmp = NaN(totalNModes+nFrames, totalNTimeSteps) ;
    for i=1:nFrames
        for j=1:n_vars
            offset = sum(sum(n_modes(:,1:i-1)))+sum(n_modes(1:j-1,i)) ;
            % coefficients for the transformed modes for each frame and 
            % physical variable
            redStateMatrixTmp(offset+(1:n_modes(j,i)),:) = sPODSystemTmp{j}.alpha(sum(n_modes(j,1:i-1))+1:sum(n_modes(j,1:i)),:) ;
        end
    end
    redStateMatrixTmp(end-nFrames+1:end,:) = combinedShift' ; % paths
    % transfer entries from matrix redStateMatrixTmp to cell array 
    % reduced_state_matrix
    reduced_state_matrix = cell(n_param_samples, 1) ;
    for j=1:n_param_samples
        reduced_state_matrix{j} = cell(1) ;
        reduced_state_matrix{j}{1} = redStateMatrixTmp(:,timeOffsets(j)+1:timeOffsets(j+1)) ;
    end
end

fprintf('... done!\n\n');

% print the relative errors
for j=1:n_param_samples
    fprintf('Parameter beta = %f\n -- Total relative offline error measured in L2 norm:\t\t\t%e \n', beta(j), relErrors(j));

    if(n_vars>1)  
        for i=1:n_vars
            fprintf(' -- Relative offline error of physical variable #%d measured in L2 norm:\t%e \n', i, relErrorsPerVar(i,j));
        end
    end
end