% =========================================================================
% Simple demo codes for image super-resolution via sparse representation
%
% Reference
%   J. Yang et al. Image super-resolution as sparse representation of raw
%   image patches. CVPR 2008.
%   J. Yang et al. Image super-resolution via sparse representation. IEEE
%   Transactions on Image Processing, Vol 19, Issue 11, pp2861-2873, 2010
%
% Jianchao Yang
% ECE Department, University of Illinois at Urbana-Champaign
% For any questions, send email to jyang29@uiuc.edu
% =========================================================================

clear variables; clc; close all

% (Unused) recipe for producing a half-resolution input from an image `im`:
% im_l = imresize(im, 0.5, 'bicubic');
% imwrite(im_l, [pathstr '/' name '_0.bmp'], 'bmp');

% ---- Super-resolution parameters ---------------------------------------
lambda     = 0.15;   % sparsity regularization weight
overlap    = 4;      % patch overlap; the more overlap the better (patch size 5x5)
up_scale   = 2;      % magnification factor; must match the trained dictionary
maxIter    = 20;     % back-projection iterations (0 = none); not used further in this script
dict_size  = 1024;   % dictionary size (part of the dictionary file name)
patch_size = 5;      % image patch size (part of the dictionary file name)

% ---- Test images -------------------------------------------------------
% Older single-dictionary files, kept for reference:
% load('Dictionary/D_512_0.15_5_11_15_0_58.mat');
% load('Dictionary/D_1024_0.15_5_11_15_22_10.mat');
% Each identifier selects its own dictionary file and its own set of
% in-sample / out-of-sample sub-images.
image_list = { ...
    '0_1',  '0_2',  '0_3',  '0_4', ...
    '5_1',  '5_2',  '5_3',  '5_4', ...
    '15_1', '15_2', '15_3', '15_4', ...
    '45_1', '45_2', '45_3', '45_4', '45_5', '45_6', ...
    '60_1', '60_2', '60_3', '60_4'};

% ---- Result storage ----------------------------------------------------
% Improvement (ScSR minus bicubic) per sub-image: one column per test
% image, 9 rows for in-sample and 3 rows for out-of-sample sub-images.
improve_psnr_insample_all    = zeros(9, numel(image_list));
improve_ssim_insample_all    = zeros(9, numel(image_list));
improve_psnr_outofsample_all = zeros(3, numel(image_list));
improve_ssim_outofsample_all = zeros(3, numel(image_list));

% Scratch table for the 12 sub-images of the current test image; columns
% are [PSNR bicubic, PSNR ScSR, SSIM bicubic, SSIM ScSR].
SR_results = zeros(12, 4);

% Evaluate every test image: load its dictionary, super-resolve each of its
% in-sample and out-of-sample sub-images, and compare ScSR against the
% bicubic baseline in PSNR and SSIM.
n_insample  = 9;                          % sub-images <name>_InSample_1..9
n_outsample = 3;                          % sub-images <name>_OutSample_1..3
n_sub       = n_insample + n_outsample;
in_rows     = 1:n_insample;
out_rows    = n_insample + (1:n_outsample);

out_dir = '../Result/ScSR/Self/';
if ~exist(out_dir, 'dir')
    mkdir(out_dir);                       % imwrite fails if the folder is missing
end

for ii = 1:length(image_list)
    image_name = image_list{ii};
    fprintf("Testing image: %s.\n", image_name);

    % Per-image dictionary. Load into a struct so the .mat file cannot
    % silently overwrite workspace variables (e.g. lambda, patch_size).
    dict_path = ['Dictionary/D_' num2str(dict_size) '_' num2str(lambda) '_' num2str(patch_size) ...
        '_' image_name '.mat'];
    dict = load(dict_path);
    if ~isfield(dict, 'Dh') || ~isfield(dict, 'Dl')
        error('Dictionary file %s must contain variables Dh and Dl.', dict_path);
    end
    Dh = dict.Dh;
    Dl = dict.Dl;

    % Reset per image; columns: [PSNR bicubic, PSNR ScSR, SSIM bicubic, SSIM ScSR].
    SR_results = zeros(n_sub, 4);

    for i = 1:n_sub
        if i <= n_insample
            basename = sprintf('%s_InSample_%d', image_name, i);
        else
            basename = sprintf('%s_OutSample_%d', image_name, i - n_insample);
        end

        % Progress indicator: indices on one line, newline after the last one.
        if i < n_sub
            fprintf('%d ', i);
        else
            fprintf('%d\n', i);
        end

        % Low-resolution input, converted to grayscale if stored as RGB.
        image_low = imread(['../SubImages/' basename '.bmp'], 'bmp');
        if ~ismatrix(image_low)
            image_low = rgb2gray(image_low);
        end
        image_low = double(image_low);

        % Bicubic baseline at the same magnification as ScSR.
        image_up = double(imresize(uint8(image_low), up_scale, 'bicubic'));
        imwrite(uint8(image_up), [out_dir basename '_BI.bmp'], 'bmp');

        % Ground truth, converted to grayscale if stored as RGB.
        image_high = imread(['../SubImages/' basename '_GT.bmp'], 'bmp');
        if ~ismatrix(image_high)
            image_high = rgb2gray(image_high);
        end
        image_high = double(image_high);
        imwrite(uint8(image_high), [out_dir basename '_GT.bmp'], 'bmp');

        % Sparse-representation SR; use up_scale (not a literal) so the
        % result stays comparable with the bicubic baseline above.
        image_filtered = ScSR(image_low, up_scale, Dh, Dl, lambda, overlap);
        imwrite(uint8(image_filtered), [out_dir basename '_ScSR.bmp'], 'bmp');

        % Metrics are computed on the uint8-quantized images.
        gt8 = uint8(image_high);
        up8 = uint8(image_up);
        sr8 = uint8(image_filtered);
        SR_results(i, :) = [psnr(gt8, up8), psnr(gt8, sr8), ...
                            ssim(gt8, up8), ssim(gt8, sr8)];
    end

    % Per-sub-image improvement of ScSR over bicubic.
    d_psnr = SR_results(:, 2) - SR_results(:, 1);
    d_ssim = SR_results(:, 4) - SR_results(:, 3);

    improve_psnr_insample_all(:, ii) = d_psnr(in_rows);
    improve_ssim_insample_all(:, ii) = d_ssim(in_rows);
    improve_psnr_outofsample_all(:, ii) = d_psnr(out_rows);
    improve_ssim_outofsample_all(:, ii) = d_ssim(out_rows);

    improve_psnr_insample = mean(d_psnr(in_rows));
    improve_ssim_insample = mean(d_ssim(in_rows));
    fprintf("The improvement of the PSNR for the in-sample images is %.4f.\n", improve_psnr_insample);
    fprintf("The improvement of the SSIM for the in-sample images is %.4f.\n", improve_ssim_insample);

    improve_psnr_outofsample = mean(d_psnr(out_rows));
    improve_ssim_outofsample = mean(d_ssim(out_rows));
    fprintf("The improvement of the PSNR for the out-of-sample images is %.4f.\n", improve_psnr_outofsample);
    fprintf("The improvement of the SSIM for the out-of-sample images is %.4f.\n\n", improve_ssim_outofsample);
end

fprintf("All testing images completed.\n");

% Overall improvement: average over every (sub-image, test image) entry.
overall_psnr_in  = mean(improve_psnr_insample_all(:));
overall_ssim_in  = mean(improve_ssim_insample_all(:));
overall_psnr_out = mean(improve_psnr_outofsample_all(:));
overall_ssim_out = mean(improve_ssim_outofsample_all(:));

fprintf("The improvement of the PSNR for the in-sample images is %.4f.\n", overall_psnr_in);
fprintf("The improvement of the SSIM for the in-sample images is %.4f.\n", overall_ssim_in);
fprintf("The improvement of the PSNR for the out-of-sample images is %.4f.\n", overall_psnr_out);
fprintf("The improvement of the SSIM for the out-of-sample images is %.4f.\n\n", overall_ssim_out);

