clear
clc
load Progress_mhw_event_All.mat
load('E:\Data\SST\sst_1982.mat','lon','lat');
%% Count events
% Each unique (lat, lon) event location is one grid point; idx maps every
% event (row of mhw_lat_all / mhw_lon_all) to its grid point.
[uniqRows, ~, idx] = unique([mhw_lat_all,mhw_lon_all], 'rows');
counts = accumarray(idx, 1);
result = [uniqRows counts];   % [lat lon nEvents] -> [row col nEvents] below
% Convert lat/lon to row/column subscripts of the regular grid loaded from
% the SST file (origin and spacing taken from lat/lon themselves).
% round() removes floating-point residue (e.g. 12.0000000001), which
% sub2ind would otherwise reject as a non-integer subscript.
lat0 = double(lat(1));  dlat = double(lat(2)) - lat0;
lon0 = double(lon(1));  dlon = double(lon(2)) - lon0;
result(:,1) = round((result(:,1) - lat0) ./ dlat) + 1;
result(:,2) = round((result(:,2) - lon0) ./ dlon) + 1;
[Lon,Lat]=meshgrid(lon,lat);   % rows follow lat, columns follow lon
grid_ind = sub2ind(size(Lon),result(:,1),result(:,2));
% Number of events per grid point (0 where no event occurred)
event_num=zeros(size(Lon));
event_num(grid_ind)=result(:,3);
% Per-event mean Chl anomaly over columns 16:65 of Chl_Progress_all
mhw_mean_chl_anomaly=mean(Chl_Progress_all(:,16:65),2,'omitmissing');
% Grid-point mean over events; skip events whose own mean is missing so a
% single all-missing event does not blank the whole grid point.
mean_chl_anomaly_event_grid = accumarray(idx, mhw_mean_chl_anomaly,[],@(x) mean(x,'omitmissing'));
chl_anomaly_map=nan(size(Lon));
chl_anomaly_map(grid_ind)=mean_chl_anomaly_event_grid;
%% Grid-point mean Chl anomaly trajectory (cached in curve_grid.mat)
% curve_grid.mat holds curve, mu, mean_chl_anomaly_event_grid and
% chl_curve_grid. If the file can be found it is loaded; this replaces the
% mean_chl_anomaly_event_grid computed in the previous section. Otherwise
% the variables are rebuilt from the event data and saved.
if exist('curve_grid.mat','file') == 2
    load curve_grid.mat
else
    % Grid-point mean trajectory
    event_curve = Chl_Progress_all(:,16:65);   % [nEvent x 50]
    T = size(event_curve,2);
    chl_curve_grid = nan(max(idx), T);
    for t = 1:T
        chl_curve_grid(:,t) = accumarray(idx, event_curve(:,t), [], @(x) mean(x,'omitnan'));
    end
    % Valid grid points: at least 5 events, a non-NaN grid mean, and at
    % least 40 non-NaN values in the mean trajectory
    valid = counts >= 5 & ~isnan(mean_chl_anomaly_event_grid);
    valid = valid & sum(~isnan(chl_curve_grid),2) >= 40;

    mu    = mean_chl_anomaly_event_grid(valid);
    curve = chl_curve_grid(valid,:);
    save('curve_grid','curve',"mu","mean_chl_anomaly_event_grid","chl_curve_grid")
end
%% Trajectory dissimilarity between Chl-anomaly bins
% Groups grid points by their event-mean Chl anomaly (mu), subsamples each
% group, and stores for every pair of bins the median of (1 - Pearson r)
% between the grid-point trajectories (rows of curve).
rng(1)   % reproducible subsampling
% 2. Group into bins using FIXED edges (not quantiles); mu values outside
%    [edges(1), edges(end)] get NaN from discretize and are left out.
edges = [-30 -20 -10 0 10 20];
nBin = numel(edges) - 1;
bin_id = discretize(mu, edges);
% 3. Randomly sample at most Nmax grid points per bin
Nmax = 200;   % Tunable, 150~300 is reasonable
bin_indices = cell(nBin,1);
for b = 1:nBin
    ind = find(bin_id == b);
    if numel(ind) > Nmax
        ind = ind(randperm(numel(ind), Nmax));
    end
    bin_indices{b} = ind;
end
% 4. Compute trajectory dissimilarity matrix between bins
M_traj = nan(nBin, nBin);   % Median trajectory dissimilarity
M_n    = nan(nBin, nBin);   % Number of finite pairs behind each median
for i = 1:nBin
    ind_i = bin_indices{i};
    for j = i:nBin
        ind_j = bin_indices{j};
        if isempty(ind_i) || isempty(ind_j)
            continue
        end
        % corr on matrices correlates every column of the first argument
        % with every column of the second; 'rows','pairwise' drops NaN rows
        % separately for each column pair, exactly like corr(x, y) per pair.
        R = corr(curve(ind_i,:)', curve(ind_j,:)', 'rows','pairwise');
        if i == j
            % Same bin: each unordered pair once, no self-pairs
            R = R(triu(true(numel(ind_i)), 1));
        else
            % Different bins: full cross-combination
            R = R(:);
        end
        dij = 1 - R(isfinite(R));

        if ~isempty(dij)
            M_traj(i,j) = median(dij);
            M_traj(j,i) = M_traj(i,j);

            M_n(i,j) = numel(dij);
            M_n(j,i) = M_n(i,j);
        end
    end
end
%%  Plot heatmap
% Lower-triangle heatmap of M_traj with each visible cell annotated.
nBin = size(M_traj,1);
assert(numel(edges) == nBin + 1, 'edges must have exactly one more entry than there are bins');
% Lower triangle only (M_traj is symmetric); the upper triangle is blanked
M_plot = M_traj;
M_plot(triu(true(nBin),1)) = NaN;
% Ticks sit on the cell boundaries 0.5, 1.5, ..., nBin+0.5 and carry the
% bin edge values as labels.
tick_pos = (0:nBin) + 0.5;
% Plotting
figure('Color','w');
imagesc(M_plot, 'AlphaData', ~isnan(M_plot));   % NaN cells drawn transparent
xlabel('Event-mean chlorophyll anomaly');
ylabel('Event-mean chlorophyll anomaly');
set(gca, 'YDir','normal', ...
    'XTick',tick_pos, 'YTick',tick_pos, ...
    'XTickLabel',edges, 'YTickLabel',edges, ...
    'FontSize',12, 'LineWidth',1,'TickLength',[0 0],'FontWeight','bold');
colormap("parula")
cb = colorbar;
cb.Location = 'south';
cb.Label.String = 'Trajectory Distance';
cb.Label.FontSize = 12;
clim([0.8 1])
% Annotate values: column j is the x position, row i the y position
for i = 1:nBin
    for j = 1:nBin
        if ~isnan(M_plot(i,j))
            text(j, i, sprintf('%.2f', M_plot(i,j)), ...
                'HorizontalAlignment','center', ...
                'FontSize',10, 'Color','k','FontWeight','bold');
        end
    end
end
text(0,1.035,"(E)","FontSize",13,'FontWeight','bold','Units','normalized')
% exportgraphics(gca,'traject_similarity.png','resolution',500)
%% Global map of grid-point mean Chl anomaly
figure('Position',[10 10 800 500]);
% Blank grid points with fewer than 5 events (event_num is 0 where none)
chl_anomaly_map(event_num<5)=nan;
m_proj('robinson','lon',[0 360]);
m_pcolor(Lon,Lat,chl_anomaly_map);
hold on
shading interp
m_gshhs('cc','patch',[240 240 240]./256);
% Study regions outlined on the map: [lon_min lon_max lat_min lat_max]
region_boxes = [  1  31  31  45;    % Mediterranean Sea
                 32  40 -35 -18;    % Mozambique Channel
                150 200  -5   5;    % Western Equatorial Pacific
                285 305  35  42;    % Gulf Stream Extension
                280 305 -57 -40;    % Patagonian Shelf Region
                240 280  -5   5];   % Eastern Equatorial Pacific
for r = 1:size(region_boxes,1)
    m_plot_rectangle(region_boxes(r,1),region_boxes(r,2),region_boxes(r,3),region_boxes(r,4));
end
m_grid('tickdir','out','linewi',2,'linestyle','none','gridcolor','k','fontsize',13,'gridlinewidth',1,'fontweight','bold');
clim([-40 40])
colormap(nclCM('GMT_polar'))
c1 = colorbar("southoutside");
c1.FontSize = 13;
c1.FontWeight = 'bold';
c1.Label.String = "Chl Anomaly(%)";
c1.Ticks = [-40 -20 0 20 40];
text(0,1.035,"(D)","FontSize",13,'FontWeight','bold','Units','normalized')
% exportgraphics(gca,'chl_event_map_global.png','Resolution',500)
%% Regional composites of SST, MLD and Chl anomalies
% For each study box, average the event rows located strictly inside the
% box, plot the three composites with my_yyaxis and export one PNG.
txt={'(A) Mediterranean Sea','(B) Western Equatorial Pacific','(C) Gulf Stream Extension','(F) Mozambique Channel','(G) Eastern Equatorial Pacific','(H) Patagonian Shelf Region'};
% One row per entry of txt: [lon_min lon_max lat_min lat_max]
box_lim=[1 31 31 45;150 200 -5 5;285 305 35 42;...
    32 40 -35 -18;240 280 -5 5;280 305 -57 -40];
assert(numel(txt) == size(box_lim,1), 'txt and box_lim must describe the same regions');
col_win = 16:65;   % columns of the *_Progress_all matrices composited (50 values)
line_colors = [241, 90, 36;40, 82, 148;34, 139, 115]./255;   % SST, MLD, Chl
for i = 1:size(box_lim,1)
    x1=box_lim(i,1);x2=box_lim(i,2);y1=box_lim(i,3);y2=box_lim(i,4);
    % Events strictly inside the box (points on the edges are excluded)
    in_box = mhw_lat_all>y1 & mhw_lat_all<y2 & mhw_lon_all<x2 & mhw_lon_all>x1;
    chl_process_here=mean(Chl_Progress_all(in_box,col_win),1,"omitmissing");
    mld_process_here=mean(MLD_Progress_all(in_box,col_win),1,'omitmissing');
    sst_process_here=mean(SST_Progress_all(in_box,col_win),1,'omitmissing');
    f1=figure('Position',[10 10 530 350]);
    % Columns of the data matrix passed to my_yyaxis: SST, MLD, Chl
    [h1,h2,yn]=my_yyaxis(3,1:numel(col_win),[sst_process_here;mld_process_here;chl_process_here]',[],{},line_colors,1);
    xlim([1 numel(col_win)])
    set(h1,'fontsize',16);set(h2,'fontsize',16);
    set(gca,'xtick',[1,10:10:50])
    text(-0.13,1.05,txt{i},'Units','normalized','FontSize',16,'FontWeight','bold')
    exportgraphics(f1,['process_',txt{i},'.png'],'Resolution',350)
end
%% Custom functions
function m_plot_rectangle(x1,x2,y1,y2)
% M_PLOT_RECTANGLE  Outline the box [x1,x2] x [y1,y2] on the current m_map axes.
%   Issues four separate m_plot calls (black, linewidth 1.5), one per edge,
%   in the order bottom, left, right, top. Callers in this script pass
%   longitudes as x1/x2 and latitudes as y1/y2.
edge_x = {[x1,x2], [x1,x1], [x2,x2], [x1,x2]};
edge_y = {[y1,y1], [y1,y2], [y1,y2], [y2,y2]};
for k = 1:numel(edge_x)
    m_plot(edge_x{k},edge_y{k},'color','k','linewidth',1.5)
end
end