function [shiftedPODSystem, snapshotMatrixApprox] = subspaceDecompositionSPODswitch(discretization, offline, snapshotMatrix, shift, options)
% subspaceDecompositionSPODswitch - computes a mode decomposition using
% shifted POD based on a simple division of the computational domain at
% its midpoint
%
% For every parameter sample, the time axis is split at
% options.switchTimeIndex into two areas:
%   (i)  time indices 1:switchTimeIndex, approximated by plain POD with
%        options.nPODModesBeforeSwitch modes;
%   (ii) the remaining time indices, approximated by shifted POD for the
%        two traveling waves (one per half of the spatial domain) plus
%        options.nPODModesAfterSwitch POD modes for the remaining error.
%        Area (ii) is split further into (ii-a) and (ii-b); the shifted
%        modes are computed from the snapshots in area (ii-b) only.
%
% inputs:
% - discretization: struct array containing discretization parameters
% - offline: struct containing parameters for the mode decomposition; the
%   fields read here are
%     n_modes                - number of transformed modes per traveling wave
%     nTimestepsPerParam     - parameter sample k contributes
%                              nTimestepsPerParam(k)+1 consecutive columns
%                              of snapshotMatrix
%     shift_matrix_generator - passed on to init_shift_matrix
% - snapshotMatrix: snapshot matrix of FOM solution (one column per
%   snapshot; the snapshots of all parameter samples placed side by side)
% - shift: matrix containing path values for each time step (rows) and
%   each traveling wave
% - options: struct containing parameters for the mode decomposition:
%     nPODModesBeforeSwitch - number of POD modes in area (i)
%     nPODModesAfterSwitch  - number of POD modes in area (ii)
%     switchTimeIndex       - last time index of area (i)
%   and, only if nPODModesAfterSwitch > 0:
%     extrapolationMethod   - 'projection' or 'polynomial'
%     extrapolationPolOrder - degree of the extrapolation polynomial
%     shiftedPODstart       - fraction of area (ii) that belongs to
%                             area (ii-a)
%
% outputs:
% - shiftedPODSystem: struct returned by init_shift_matrix, extended by
%     alpha                       - transformed-mode coefficients (NaN in
%                                   the columns of area (i))
%     UPODbeforeSwitch            - POD modes in area (i)
%     PODcoefficientsBeforeSwitch - POD coefficients in area (i)
%     UPOD, PODcoefficients       - POD modes/coefficients in area (ii)
%     U                           - transformed modes per traveling wave
% - snapshotMatrixApprox: approximation of the snapshot matrix
%
% dependencies:
% - init_shift_matrix (in LIB/SHIFT)
% - shift_matrix (in LIB/SHIFT)
%
%--------------------------------------------------------------------------
% version 1.0 (July 27, 2021)
% authors:
% - Felix Black (TU Berlin), black@math.tu-berlin.de
% - Philipp Schulze (TU Berlin), pschulze@math.tu-berlin.de
% - Benjamin Unger (U Stuttgart), benjamin.unger@simtech.uni-stuttgart.de
%--------------------------------------------------------------------------

%% setup

% general settings
n_steps_space = size(snapshotMatrix, 1) ; % number of rows of snapshot matrix
space_midpoint_ind = floor(n_steps_space/2) ; % index corresponding to middle of the computational domain
n_steps_time = size(snapshotMatrix, 2) ; % number of snapshots
n_frames = 2 ; % number of traveling waves
n_modes = offline.n_modes ; % number of transformed modes per traveling wave
% number of time steps per parameter value, forced into a row vector so
% that the concatenations below work for row and column inputs alike
nTimestepsPerParam = reshape(offline.nTimestepsPerParam, 1, []) ;
n_params = length(nTimestepsPerParam) ; % number of parameter values

shift_matrix_generator = offline.shift_matrix_generator ; % function used for creating the shift matrices

% method-specific options
nPODModesBeforeSwitch = options.nPODModesBeforeSwitch ; % number of POD modes in area (i)
nPODModesAfterSwitch = options.nPODModesAfterSwitch ; % number of POD modes in area (ii)
switchTimeIndex = options.switchTimeIndex ; % time index between areas (i) and (ii)

if nPODModesAfterSwitch > 0 % if POD modes are used in area (ii)
    extrapolationMethod = options.extrapolationMethod; % method used for extrapolating the coefficients in area (ii-a)
    extrapolationPolOrder = options.extrapolationPolOrder; % degree of the extrapolation polynomial
    shiftedPODstart = options.shiftedPODstart; % time point between areas (ii-a) and (ii-b)
else
    % if no POD modes are used, the coefficients are obtained by
    % projection in the co-moving frame and area (ii-a) vanishes, i.e.,
    % area (ii) consists only of area (ii-b)
    extrapolationMethod = 'projection';
    shiftedPODstart = 0;
end

% number of time steps in area (ii), per parameter value
nTimestepsAfterSwitch = nTimestepsPerParam + 1 - switchTimeIndex;
% first time index of area (ii-b), per parameter value and relative to
% the start of area (ii) of that parameter value
firstSPODIndex = floor(nTimestepsAfterSwitch*shiftedPODstart) + 1;

% initialization and precomputations
snapshotMatrixApprox = zeros(size(snapshotMatrix)); % approximation of snapshot matrix

% column offset of parameter value k in snapshotMatrix (parameter value k
% occupies nTimestepsPerParam(k)+1 consecutive columns)
paramColumnOffsets = cumsum([0, nTimestepsPerParam(1:end-1) + 1]);
% column offset of parameter value k within the area (ii) data
afterSwitchOffsets = cumsum([0, nTimestepsAfterSwitch(1:end-1)]);

beforeSwitchIdx = []; % columns of snapshotMatrix belonging to area (i)
afterSwitchIdx = [];  % columns of snapshotMatrix belonging to area (ii)
% sPODindicesParam{k}: columns of the area (ii) data that belong to
% area (ii-b) of parameter value k
sPODindicesParam = cell(1, n_params);
for k = 1:n_params
    beforeSwitchIdx = [beforeSwitchIdx, (1:switchTimeIndex) + paramColumnOffsets(k)]; %#ok<AGROW>
    afterSwitchIdx = [afterSwitchIdx, ((switchTimeIndex+1):(nTimestepsPerParam(k)+1)) + paramColumnOffsets(k)]; %#ok<AGROW>
    sPODindicesParam{k} = (firstSPODIndex(k):nTimestepsAfterSwitch(k)) + afterSwitchOffsets(k);
end
% area (ii-b) columns of all parameter values, relative to the area (ii) data
sPODindices = [sPODindicesParam{:}];

% snapshot data in area (i)
dataBeforeSwitch = snapshotMatrix(:,beforeSwitchIdx);
% snapshot data in area (ii)
dataAfterSwitch = snapshotMatrix(:,afterSwitchIdx);

% construct the shift matrices for each time step in area (ii)
shiftedPODSystem = init_shift_matrix(discretization, shift_matrix_generator, shift(afterSwitchIdx,:)) ;
% construct the shift matrices shifting in the opposite direction for each time step in area (ii)
backShift = init_shift_matrix(discretization, shift_matrix_generator, -shift(afterSwitchIdx,:)) ;


%% initialize result variables
modes = cell(n_frames,1); % modes per traveling wave
coefficients = cell(n_frames,1); % corresponding coefficients

% coefficients are also stored in shiftedPODSystem struct; the columns of
% area (i) stay NaN
shiftedPODSystem.alpha = NaN(sum(n_modes), n_steps_time) ;

%% compute approximation before switch (simple POD)

% compute POD approximation of snapshot data in area (i)
[U,S,V] = svds(dataBeforeSwitch, nPODModesBeforeSwitch) ;
snapshotMatrixApprox(:,beforeSwitchIdx) = U*S*V';

shiftedPODSystem.UPODbeforeSwitch = U; % POD modes in area (i)
shiftedPODSystem.PODcoefficientsBeforeSwitch = S*V'; % POD coefficients in area (i)

%% compute approximation after switch per frame

% initialization
dataAfterSwitchApprox = zeros(size(dataAfterSwitch)); % approximation of snapshots in area (ii)

for i=1:n_frames
    % rows of the spatial domain assigned to the current traveling wave;
    % the last frame extends to the final row so that no row is dropped
    % when n_steps_space is odd
    if i == n_frames
        frameRows = ((i-1)*space_midpoint_ind+1):n_steps_space;
    else
        frameRows = ((i-1)*space_midpoint_ind+1):(i*space_midpoint_ind);
    end

    % roughly split the snapshot matrix in the middle and fill the rows
    % outside the current frame with zeros
    frameSnapshotMatrix = zeros(size(dataAfterSwitch));
    frameSnapshotMatrix(frameRows,:) = dataAfterSwitch(frameRows,:) ;

    % shift the snapshots of the current frame into the co-moving reference frame
    shiftedFrameSnapshotMatrix = shift_matrix(shiftedPODSystem.shift_matrices{i,1}, frameSnapshotMatrix);
    % the output of shift_matrix is a 3D array and thus has to be reshaped
    % back into a matrix with one column per snapshot
    shiftedFrameSnapshotMatrix = reshape(shiftedFrameSnapshotMatrix,n_steps_space,[]) ;

    % compute modes via SVD of the snapshots in area (ii-b)
    [modes{i},S,V] = svds(shiftedFrameSnapshotMatrix(:,sPODindices),n_modes(i));

    % compute coefficients
    switch extrapolationMethod
        case 'projection'
            % compute coefficients via orthogonal projection of the
            % snapshots onto the modes
            coefficients{i} = modes{i}'*shiftedFrameSnapshotMatrix;
        case 'polynomial'
            % in area (ii-b), the coefficients are obtained from the SVD
            % and in area (ii-a) the coefficients are obtained by
            % extrapolating the coefficients determined for area (ii-b)
            extrapolatedCoefficients = zeros(n_modes(i),size(shiftedFrameSnapshotMatrix,2)); % temporary variable for the coefficients
            extrapolatedCoefficients(:,sPODindices) = S*V'; % coefficients in area (ii-b)
            for k=1:n_params
                % area (ii-a) of parameter value k: from the first area (ii)
                % column of this parameter value up to the column just
                % before its area (ii-b); starting at afterSwitchOffsets(k)+1
                % leaves the area (ii-b) coefficients of parameter value
                % k-1 untouched
                extrapolationIdx = (afterSwitchOffsets(k)+1):(afterSwitchOffsets(k)+firstSPODIndex(k)-1);
                if isempty(extrapolationIdx)
                    continue % nothing to extrapolate for this parameter value
                end
                fitIdx = sPODindicesParam{k}; % support points of the fit (area (ii-b))
                for j=1:n_modes(i) % loop over coefficients corresponding to current traveling wave
                    % fit a polynomial to the coefficient values in area
                    % (ii-b); the centered/scaled form (mu) keeps the fit
                    % well-conditioned for large column indices
                    [p,~,mu] = polyfit(fitIdx,extrapolatedCoefficients(j,fitIdx),extrapolationPolOrder);
                    % evaluate the extrapolation polynomial within area (ii-a)
                    extrapolatedCoefficients(j,extrapolationIdx) = polyval(p,extrapolationIdx,[],mu);
                end
            end
            % transfer values of extrapolatedCoefficients to cell array
            % coefficients
            coefficients{i} = extrapolatedCoefficients;
        otherwise
            error('subspaceDecompositionSPODswitch:unknownExtrapolationMethod', ...
                'Unknown extrapolation method ''%s''; expected ''projection'' or ''polynomial''.', ...
                extrapolationMethod);
    end
    % approximation of the current traveling wave in the co-moving frame
    shiftedFrameApprox = modes{i}*coefficients{i} ;

    % shift approximation back into the lab frame
    frameApprox = shift_matrix(backShift.shift_matrices{i,1}, shiftedFrameApprox);
    frameApprox = reshape(frameApprox,n_steps_space,[]) ;

    % update the approximation of the snapshot data in area (ii)
    dataAfterSwitchApprox = dataAfterSwitchApprox + frameApprox ;
    % transfer coefficients into the shiftedPODSystem struct
    shiftedPODSystem.alpha(sum(n_modes(1:i-1))+1:sum(n_modes(1:i)),afterSwitchIdx) = coefficients{i} ;
end

% error for snapshot data in area (ii)
deviation = dataAfterSwitch - dataAfterSwitchApprox;

% compute POD modes for the error; with no POD modes requested, use an
% empty basis so that the correction below is exactly zero
if nPODModesAfterSwitch > 0
    [U, ~, ~] = svds(deviation, nPODModesAfterSwitch) ;
else
    U = zeros(n_steps_space, 0) ;
end
% compute POD coefficients for the error
PODCoefficients = U'*deviation ;
% update the approximation of the snapshots in area (ii) by incorporating
% the POD modes
dataAfterSwitchApprox = dataAfterSwitchApprox + U*PODCoefficients ;
% update the approximation of the complete snapshot matrix
snapshotMatrixApprox(:,afterSwitchIdx) = dataAfterSwitchApprox;

shiftedPODSystem.UPOD = U ; % POD modes in area (ii)
shiftedPODSystem.PODcoefficients = PODCoefficients; % POD coefficients in area (ii)
shiftedPODSystem.U = modes; % transformed modes in area (ii)