clear; clc; close all

% Evaluate paired LB-NLM super-resolution against a bicubic baseline.
%
% For every test image and every sub-image (n_insample "InSample" crops
% followed by n_outsample "OutSample" crops) this script reads
%   <input_folder><name>.bmp                          low-resolution input
%   <input_folder><name>_GT.bmp                       ground truth
%   <results_folder><name>_LBNLM_<self|pooled>.bmp    LB-NLM output
% upsamples the input bicubically by `ratio`, and prints the mean PSNR/SSIM
% improvement of the LB-NLM output over the bicubic upsample, separately for
% the in-sample and out-of-sample crops, once per training mode.

ratio = 2.0;

image_list = {'0_1', '0_2', '0_3', '0_4', '5_1', '5_2', '5_3', '5_4',...
    '15_1', '15_2', '15_3', '15_4', '45_1', '45_2', '45_3', '45_4', '45_5', '45_6',...
    '60_1', '60_2', '60_3', '60_4'};

n_insample = 9;     % sub-images <image>_InSample_1 .. _n_insample
n_outsample = 3;    % sub-images <image>_OutSample_1 .. _n_outsample
n_sub = n_insample + n_outsample;
in_rows = 1:n_insample;                   % rows of SR_results holding in-sample crops
out_rows = n_insample + (1:n_outsample);  % rows of SR_results holding out-of-sample crops

input_folder = './SubImages/';
results_folder = './Result/LBNLM/';

% Training modes: console label and the suffix of the corresponding result file.
mode_labels = {'Self-Training', 'Pooled-Training'};
mode_suffixes = {'_LBNLM_self.bmp', '_LBNLM_pooled.bmp'};

% Set to true to also print the fraction of sub-images where LB-NLM lowered PSNR.
report_failure_rate = false;

% Per-sub-image improvements (rows) for each test image (columns). Every
% training mode overwrites all columns, so after the script these hold the
% results of the last mode.
improve_psnr_insample_all = zeros(n_insample, length(image_list));
improve_psnr_outofsample_all = zeros(n_outsample, length(image_list));
improve_ssim_insample_all = zeros(n_insample, length(image_list));
improve_ssim_outofsample_all = zeros(n_outsample, length(image_list));

% Scratch table for one test image.
% Columns: PSNR bicubic, PSNR LB-NLM, SSIM bicubic, SSIM LB-NLM.
SR_results = zeros(n_sub, 4);

fprintf("Paired LB-NLM super-resolution.\n");
for kk = 1:numel(mode_labels)
    fprintf("%s:\n", mode_labels{kk});

    count_failure = 0;   % sub-images where bicubic PSNR beat LB-NLM PSNR

    fprintf("Testing image: ");
    for ii = 1:length(image_list)
        image_name = image_list{ii};
        if ii < length(image_list)
            fprintf(" %s,", image_name);
        else
            fprintf(" %s.\n", image_name);
        end

        for i = 1:n_sub
            if i <= n_insample
                basename = sprintf('%s_InSample_%d', image_name, i);
            else
                basename = sprintf('%s_OutSample_%d', image_name, i - n_insample);
            end

            % Load input, ground truth and LB-NLM result (in that order),
            % collapsing any RGB file to a single grayscale channel.
            paths = {[input_folder basename '.bmp'], ...
                     [input_folder basename '_GT.bmp'], ...
                     [results_folder basename mode_suffixes{kk}]};
            images = cell(size(paths));
            for p = 1:numel(paths)
                img = imread(paths{p}, 'bmp');
                if ~ismatrix(img)
                    img = rgb2gray(img);
                end
                images{p} = double(img);
            end
            [image_low, image_high, image_filtered] = images{:};

            % Bicubic baseline, computed on 8-bit data like the stored images.
            image_up = double(imresize(uint8(image_low), ratio, 'bicubic'));

            % psnr/ssim require equal sizes; fail with the offending file named.
            if ~isequal(size(image_up), size(image_high)) || ...
                    ~isequal(size(image_filtered), size(image_high))
                error('Size mismatch for %s: ground truth %s, bicubic %s, LB-NLM %s.', ...
                    basename, mat2str(size(image_high)), ...
                    mat2str(size(image_up)), mat2str(size(image_filtered)));
            end

            % Metrics take (candidate, reference).
            ref = uint8(image_high);
            psnr_up = psnr(uint8(image_up), ref);
            psnr_filtered = psnr(uint8(image_filtered), ref);
            ssim_up = ssim(uint8(image_up), ref);
            ssim_filtered = ssim(uint8(image_filtered), ref);

            SR_results(i, :) = [psnr_up, psnr_filtered, ssim_up, ssim_filtered];

            if psnr_up > psnr_filtered
                count_failure = count_failure + 1;
            end
        end

        improve_psnr_insample_all(:, ii) = SR_results(in_rows, 2) - SR_results(in_rows, 1);
        improve_ssim_insample_all(:, ii) = SR_results(in_rows, 4) - SR_results(in_rows, 3);
        improve_psnr_outofsample_all(:, ii) = SR_results(out_rows, 2) - SR_results(out_rows, 1);
        improve_ssim_outofsample_all(:, ii) = SR_results(out_rows, 4) - SR_results(out_rows, 3);
    end

    fprintf("The improvement of the PSNR for the in-sample images is %.4f.\n", mean(improve_psnr_insample_all(:)));
    fprintf("The improvement of the SSIM for the in-sample images is %.4f.\n", mean(improve_ssim_insample_all(:)));
    fprintf("The improvement of the PSNR for the out-of-sample images is %.4f.\n", mean(improve_psnr_outofsample_all(:)));
    fprintf("The improvement of the SSIM for the out-of-sample images is %.4f.\n\n", mean(improve_ssim_outofsample_all(:)));

    if report_failure_rate
        fprintf("The failure rate is %.4f.\n\n", count_failure / (n_sub * length(image_list)));
    end
end