function [plot_handle] = inspect_MTM_prediction_results( Y_test_pred, Y_layer_test_pred, Y_test, Y_layer_test, X_test, mtm_params)
% [plot_handle] = inspect_MTM_prediction_results( Y_test_pred, Y_layer_test_pred, Y_test, Y_layer_test, X_test, mtm_params)
% Plot MTM prediction results.
%
% Opens a full-screen figure with two panels:
%   left  - every row of X_test plotted as a trace against time in ms;
%   right - observed vs. predicted depth per channel, coloured by layer
%           label, with correctly labelled channels circled, a dashed
%           border at every value of mtm_params.layer_bins, and the depth
%           RMSE and layer accuracy written on the axes.
%
% Inputs:
%   Y_test_pred       - predicted depth per channel (axis labelled in \mum).
%   Y_layer_test_pred - predicted layer label per channel; every label must
%                       be one of [1,2,4,5,6].
%   Y_test            - observed depth per channel (same number of elements
%                       as Y_test_pred; row or column vectors both accepted).
%   Y_layer_test      - observed layer label per channel (same number of
%                       elements as Y_layer_test_pred); labels must be one
%                       of [1,2,4,5,6].
%   X_test            - matrix with one trace per row and one time sample
%                       per column (axis labelled as LFP in \muV).
%   mtm_params        - struct with fields:
%                         layer_bins - depths at which a dashed layer
%                                      border is drawn
%                         n_ch       - number of channels (x-axis extent of
%                                      the prediction panel)
%                         S_freq     - sampling frequency used to convert
%                                      sample index to ms (presumably Hz)
%
% Output:
%   plot_handle       - handle of the created figure.

%% ---- set pars and initialize ----

% unpack useful params
layer_bins=mtm_params.layer_bins;
n_ch=mtm_params.n_ch;
S_freq=mtm_params.S_freq;

% force column vectors so that element-wise differences/comparisons between
% predictions and observations never broadcast a row against a column
Y_test_pred=Y_test_pred(:);
Y_layer_test_pred=Y_layer_test_pred(:);
Y_test=Y_test(:);
Y_layer_test=Y_layer_test(:);

% layer labels and one colour per label, taken at evenly spaced rows of
% the flipped parula colormap
all_labels=[1,2,4,5,6];
n_layers=numel(all_labels);
C = flipud(parula);
col_step=floor(size(C,1)/n_layers);
layer_cols = C(1:col_step:col_step*n_layers,:);


%% --- plot prediction results ---

plot_handle=figure('units','normalized','outerposition',[0 0 1 1]);

% plot results subplot
subplot(1,2,2)
predicted_labels=unique(Y_layer_test_pred);
true_labels=unique(Y_layer_test);

% draw channel observations and prediction
hold on;
for idx_tl=1:numel(true_labels)
    tl=true_labels(idx_tl);
    tl_col=label_colour(tl,all_labels,layer_cols);
    selected_indexes=find(tl==Y_layer_test);
    plot(selected_indexes,Y_test(selected_indexes),'.','MarkerSize',20*2,'Color',[0,0,0]);
    plot(selected_indexes,Y_test(selected_indexes),'d','LineWidth',1*2,'Color',tl_col);
end
for idx_pl=1:numel(predicted_labels)
    pl=predicted_labels(idx_pl);
    pl_col=label_colour(pl,all_labels,layer_cols);
    selected_indexes=find(pl==Y_layer_test_pred);
    plot(selected_indexes,Y_test_pred(selected_indexes),'.','MarkerSize',30*2,'Color',pl_col);
end
% circle the channels whose layer was predicted correctly
selected_indexes=find(Y_layer_test_pred==Y_layer_test);
plot(selected_indexes,Y_test_pred(selected_indexes),'o','MarkerSize',6.5*2,'Color',[0,0,0]);

% draw layer borders used for prediction (one dashed line per bin value)
for idx_b=1:numel(layer_bins)
    plot([1,n_ch],[layer_bins(idx_b),layer_bins(idx_b)],'--','Color',[0.5,0.5,0.5],'LineWidth',2)
end

% adjust axis (x extent follows the configured channel count)
xlabel('channel number')
ylabel('cortical depth (\mum)')
ylim([-100,1600])
xlim([0,n_ch])

% write metrics
ylimit=get(gca,'ylim');
xlimit=get(gca,'xlim');
rmse=sqrt(mean((Y_test_pred-Y_test).^2));
% accuracy is the fraction of label pairs that agree
acc=sum(Y_layer_test_pred==Y_layer_test)/numel(Y_layer_test);
text(0.35*xlimit(2),0.80*ylimit(2),...
    ['rmse (depth) = ',num2str(rmse,'%.0f'),' \mum'],'FontSize',14);
text(0.35*xlimit(2),0.75*ylimit(2),...
    ['accuracy (layer) = ',num2str(acc,'%.2f')],'FontSize',14);
set(gca, 'YDir','reverse');
title('Predicted vs. observed depths and layer attributions')
axis square;
set(gca,'fontsize',12)

% plot test data subplot
subplot(1,2,1)
CG=colormap('parula');
n_traces=size(X_test,1);
% every other colormap row is used per trace; extend the map when there
% are more traces than rows available so indexing never runs out of range
if 2*n_traces-1 > size(CG,1)
    CG=parula(2*n_traces-1);
end
trace_time = (1:size(X_test, 2))./S_freq*1000; % in ms
% plot each channel of the input observed VEP pattern
hold on;
for i = 1 : n_traces
    clr = CG(2*(i-1)+1,:);
    plot(trace_time, X_test(i,:), 'Color', clr, 'LineWidth', 3 )
end
xlabel('time from stimulus onset (ms)')
ylabel('LFP (\muV)')
title('VEP dynamics across channels')
axis square;
set(gca,'fontsize',12)

end

function col = label_colour(label, all_labels, layer_cols)
% Return the row of layer_cols assigned to a layer label; error if unknown.
[is_known, pos] = ismember(label, all_labels);
if ~is_known
    error('inspect_MTM_prediction_results:unknownLabel', ...
        'Layer label %g is not one of [%s].', label, num2str(all_labels));
end
col = layer_cols(pos,:);
end

