% ---- Session reset -------------------------------------------------------
% Clear workspace variables, clear the command window, close open figures.
clear; clc; close all

% ---- Parallel pool -------------------------------------------------------
% gcp('nocreate') returns empty when no pool is running; only then start one.
if isempty(gcp('nocreate'))
    parpool
end

% ---- Algorithm parameters ------------------------------------------------
N_p = 9;                      % Patch side length in pixels (should be an odd number)
N_p_radius = floor(N_p / 2);  % Half-width of a patch around its centre pixel
N_l = 80000;                  % Number of patches in the library
sigma_n = 1.0;                % Weight parameter of LBNLM (scales the exponential weights)
ratio = 2.0;                  % Super-resolution (upsampling) factor

% ---- Test set ------------------------------------------------------------
% Each entry is the base name of one image pair; sub-image file names are
% derived from it later in the script.
image_list = { ...
    '0_1',  '0_2',  '0_3',  '0_4', ...
    '5_1',  '5_2',  '5_3',  '5_4', ...
    '15_1', '15_2', '15_3', '15_4', ...
    '45_1', '45_2', '45_3', '45_4', '45_5', '45_6', ...
    '60_1', '60_2', '60_3', '60_4'};

% ---- Result accumulators (one column per image pair) ---------------------
% 9 rows for the in-sample sub-images, 3 rows for the out-of-sample ones.
improve_psnr_insample_all    = zeros(9, numel(image_list));
improve_psnr_outofsample_all = zeros(3, numel(image_list));
improve_ssim_insample_all    = zeros(9, numel(image_list));
improve_ssim_outofsample_all = zeros(3, numel(image_list));

% ---- Patch library -------------------------------------------------------
% The .mat file is loaded straight into the workspace; the rest of the script
% reads Patch_center and Patch_dictionary_high from it.
library_name = sprintf('./Libraries/Library_Pooled_P%d_S%d_paired_shift.mat', N_p, N_l);
load(library_name);

% ==========================================================================
% Main evaluation loop.
%
% For every image pair in image_list, 12 sub-images are processed
% (9 "InSample" + 3 "OutSample"). Each low-resolution sub-image is
%   1. upsampled bicubically by `ratio` (baseline, saved as *_BI.bmp),
%   2. reconstructed patch-by-patch as a weighted average of library patches
%      from the class whose centre is nearest to the patch,
%   3. compared against the ground truth (*_GT.bmp) with PSNR and SSIM.
%
% Reads from the workspace: image_list, N_p, N_p_radius, N_l, sigma_n, ratio,
%   Patch_center, Patch_dictionary_high, and the improve_*_all accumulators.
% Writes: BMP files under ./Result/LBNLM/ and ./Result/LBNLM_or/, and one
%   column per image pair in each improve_*_all accumulator.
%
% NOTE(review): sq_dist is defined outside this file. From its use below it
%   is assumed to return a (#columns of arg 1) x (#columns of arg 2) matrix
%   of pairwise squared distances - confirm against its implementation.
% ==========================================================================

% The class layout of the library does not depend on the test image, so it is
% derived once instead of once per image pair.
[~, n_class] = size(Patch_center);
n_per_class = round(N_l / n_class);
if n_class * n_per_class > size(Patch_dictionary_high, 2)
    error('LBNLM:librarySize', ...
        'Library holds %d patches, but %d classes x %d patches per class are required.', ...
        size(Patch_dictionary_high, 2), n_class, n_per_class);
end

% imwrite does not create missing folders; make sure both output folders
% exist so a run cannot fail after the expensive reconstruction step.
out_dirs = {'./Result/LBNLM/', './Result/LBNLM_or/'};
for d = 1:numel(out_dirs)
    if exist(out_dirs{d}, 'dir') ~= 7
        mkdir(out_dirs{d});
    end
end

% Index of the centre pixel inside an N_p x N_p patch.
center_idx = N_p_radius + 1;

for kk = 1:length(image_list)
    image_name = image_list{kk};
    fprintf("Testing image pair: %s.\n", image_name);
    
    % Row 1: delta PSNR, row 2: delta SSIM, row 3: inference time [s];
    % one column per sub-image.
    results_mat = zeros(3, 12);
    
    for i = 1:12
        
        % Sub-images 1..9 are in-sample, 10..12 are out-of-sample (numbered 1..3).
        if i <= 9
            basename = sprintf('%s_InSample_%d', image_name, i);
        else
            basename = sprintf('%s_OutSample_%d', image_name, i - 9);
        end
        fprintf('Reconstruct subimage name: %s.\n', basename);
        
        % Load the low-resolution input, converting colour images to grayscale.
        filename = ['./SubImages/' basename '.bmp'];
        image_low = imread(filename, 'bmp');
        if ~ismatrix(image_low)
            image_low = rgb2gray(image_low);
        end
        image_low = double(image_low);
        
        % Bicubic upsampling: the baseline and the input of the patch search.
        image_up = imresize(uint8(image_low), ratio, 'bicubic');
        image_up = double(image_up);
        filename = ['./Result/LBNLM/' basename '_BI.bmp'];
        imwrite(uint8(image_up), filename, 'bmp');
        
        % Load ground truth
        filename = ['./SubImages/' basename '_GT.bmp'];
        image_high = imread(filename, 'bmp');
        if ~ismatrix(image_high)
            image_high = rgb2gray(image_high);
        end
        image_high = double(image_high);
        filename = ['./Result/LBNLM/' basename '_GT.bmp'];
        imwrite(uint8(image_high), filename, 'bmp');
        
        % psnr/ssim need equally sized images; fail before the costly
        % reconstruction instead of after it.
        if ~isequal(size(image_high), size(image_up))
            error('LBNLM:sizeMismatch', ...
                'Ground truth of %s is %dx%d but the upsampled image is %dx%d.', ...
                basename, size(image_high, 1), size(image_high, 2), ...
                size(image_up, 1), size(image_up, 2));
        end
        
        % ---- Reconstruction (timed) --------------------------------------
        tic;
        [n_height, n_width] = size(image_up);
        
        % Patches centred in column ii only touch columns ii-N_p_radius ..
        % ii+N_p_radius, so each parfor iteration returns an N_p-wide band
        % rather than a full-size image (memory H*N_p*W instead of H*W*W).
        band_all = zeros(n_height, N_p, n_width);
        parfor ii = 1 : n_width
            
            % Image columns covered by patches centred in column ii, and the
            % matching columns inside the patch / band.
            ii_begin = max([1, ii - N_p_radius]);
            ii_end = min([n_width, ii + N_p_radius]);
            col_idx = (center_idx + (ii_begin - ii)):(center_idx + (ii_end - ii));
            
            band = zeros(n_height, N_p);
            x_test_all = zeros(N_p * N_p, n_height);
            Patch_recon_all = zeros(N_p * N_p, n_height);
            
            % Extract one vectorised patch per row of this column. Pixels
            % outside the image keep the fill value 127.
            for jj = 1 : n_height
                Patch_current = zeros(N_p, N_p) + 127.0;
                
                jj_begin = max([1, jj - N_p_radius]);
                jj_end = min([n_height, jj + N_p_radius]);
                row_idx = (center_idx + (jj_begin - jj)):(center_idx + (jj_end - jj));
                Patch_current(row_idx, col_idx) = image_up(jj_begin:jj_end, ii_begin:ii_end);
                
                x_test_all(:, jj) = Patch_current(:);
            end
            
            % Identify the closest class centre for every test patch.
            center_dist = sq_dist(x_test_all, Patch_center);
            [~, k_min] = min(center_dist, [], 2);
            
            for k = 1:n_class
                
                k_index = find(k_min == k);
                
                if ~isempty(k_index)
                    % Library patches of class k occupy a contiguous column range.
                    dict_k = Patch_dictionary_high(:, (k - 1) * n_per_class + 1:k * n_per_class);
                    
                    w_test_train = sq_dist(x_test_all(:, k_index), dict_k);
                    % Shift so the smallest distance of each test patch is 0;
                    % its largest weight is then exp(0) = 1, which keeps the
                    % row sum away from zero.
                    w_test_train = w_test_train - min(w_test_train, [], 2);
                    w_test_train = exp(-w_test_train ./ (2 * N_p * N_p * sigma_n * sigma_n));
                    
                    % Normalization: the weights of each test patch sum to 1.
                    w_test_train = w_test_train ./ sum(w_test_train, 2);
                    
                    % Weighted average of the whole HR patch
                    Patch_recon_all(:, k_index) = (w_test_train * dict_k')';
                end
            end
            
            % Accumulate every reconstructed patch into the band.
            for jj = 1 : n_height
                jj_begin = max([1, jj - N_p_radius]);
                jj_end = min([n_height, jj + N_p_radius]);
                row_idx = (center_idx + (jj_begin - jj)):(center_idx + (jj_end - jj));
                
                Patch_recon = reshape(Patch_recon_all(:, jj), N_p, N_p);
                band(jj_begin:jj_end, col_idx) = band(jj_begin:jj_end, col_idx) + Patch_recon(row_idx, col_idx);
                
                % Border compensation: add 127 to the centre pixel once for
                % every patch position that fell outside the image, so the
                % later division by N_p^2 still averages N_p^2 terms.
                n_missing = N_p * N_p - (jj_end - jj_begin + 1) * (ii_end - ii_begin + 1);
                band(jj, center_idx) = band(jj, center_idx) + 127.0 * n_missing;
            end
            
            band_all(:, :, ii) = band;
        end
        
        % Scatter-add the per-column bands back into image coordinates.
        image_filtered = zeros(n_height, n_width);
        for cc = 1 : n_width
            c_begin = max([1, cc - N_p_radius]);
            c_end = min([n_width, cc + N_p_radius]);
            band_cols = (center_idx + (c_begin - cc)):(center_idx + (c_end - cc));
            image_filtered(:, c_begin:c_end) = image_filtered(:, c_begin:c_end) + band_all(:, band_cols, cc);
        end
        
        % Average the N_p^2 overlapping estimates and clamp to the 8-bit range.
        image_filtered = round(image_filtered ./ double(N_p * N_p));
        image_filtered(image_filtered < 0) = 0;
        image_filtered(image_filtered > 255) = 255;
        
        time_interval = toc;
        
        % ---- Quality metrics against the ground truth --------------------
        psnr_up = psnr(uint8(image_high), uint8(image_up));
        psnr_filtered = psnr(uint8(image_high), uint8(image_filtered));
        
        ssim_up = ssim(uint8(image_high), uint8(image_up));
        ssim_filtered = ssim(uint8(image_high), uint8(image_filtered));
        
        results_mat(1, i) = psnr_filtered - psnr_up;
        results_mat(2, i) = ssim_filtered - ssim_up;
        results_mat(3, i) = time_interval;
        
        % Save result
        filename = ['./Result/LBNLM_or/' basename '_LBNLM_or_pooled.bmp'];
        imwrite(uint8(image_filtered), filename, 'bmp');
        
    end
    
    fprintf('Increase of PSNR in in-sample images: %.3f dB.\n', mean(results_mat(1, 1:9)));
    fprintf('Increase of SSIM in in-sample images: %.3f.\n', mean(results_mat(2, 1:9)));
    fprintf('Increase of PSNR in out-of-sample images: %.3f dB.\n', mean(results_mat(1, 10:12)));
    fprintf('Increase of SSIM in out-of-sample images: %.3f.\n', mean(results_mat(2, 10:12)));
    fprintf('Average inference time: %.3f s.\n\n', mean(results_mat(3,:)));
    
    % Keep the individual sub-image improvements (the accumulators have 9 and
    % 3 rows for exactly this purpose). Every column holds the same number of
    % values, so the mean over the whole accumulator is still the mean of the
    % per-pair means.
    improve_psnr_insample_all(:, kk) = results_mat(1, 1:9).';
    improve_ssim_insample_all(:, kk) = results_mat(2, 1:9).';
    
    improve_psnr_outofsample_all(:, kk) = results_mat(1, 10:12).';
    improve_ssim_outofsample_all(:, kk) = results_mat(2, 10:12).';
    
end

% Final report: average each accumulator over all of its entries (every
% sub-image row of every image pair) and print the four overall improvements.
overall_psnr_in  = mean(improve_psnr_insample_all(:));
overall_ssim_in  = mean(improve_ssim_insample_all(:));
overall_psnr_out = mean(improve_psnr_outofsample_all(:));
overall_ssim_out = mean(improve_ssim_outofsample_all(:));

fprintf("All testing images completed.\n");
fprintf("The improvement of the PSNR for the in-sample images is %.4f.\n", overall_psnr_in);
fprintf("The improvement of the SSIM for the in-sample images is %.4f.\n", overall_ssim_in);
fprintf("The improvement of the PSNR for the out-of-sample images is %.4f.\n", overall_psnr_out);
fprintf("The improvement of the SSIM for the out-of-sample images is %.4f.\n\n", overall_ssim_out);

