% Date: 11.09.2023
% Project: Signatures  
% Script: Heterozygosity plots (Fogarty & Otto, Figure 2)
% Author: Laurel Fogarty

% Figure 2 is a two-panel figure illustrating the change in heterozygosity,
% from the start to the end of the simulation, of a neutral genetic allele
% hitchhiking on a cultural trait under selection.

% Dependency: hatchfill.m written by Neil Tandon available at https://de.mathworks.com/matlabcentral/fileexchange/30733-hatchfill

% Start from a clean workspace. clearvars removes variables without also
% flushing compiled functions/MEX files the way `clear all` does.
clearvars
close all

% Shared model parameters
nGen = 10000;   % maximum number of generations iterated per parameter set
N = 10^6;       % population size; one individual corresponds to frequency 1/N

SC = 0.1;       % selection coefficient favouring the cultural trait C
x0 = 1/N;       % initial frequency of C; the model loops below set p0 = 1/N directly

r_vec = (0:0.000001:0.0001);   % recombination rates for the acultural check model

%% Acultural haploid model (calculated for checks and balances)
% Baseline without cultural transmission: the neutral allele A/a hitchhikes
% on the selected trait C/c and is separated from it at rate r.
% State: p = freq(C), Q = freq(A | C), R = freq(A | c).
% For each r in r_vec the model starts from a single C carrier (p = 1/N)
% that carries A (Q = 1), with A at frequency R0 = 0.5 among c individuals.
% The final heterozygosity ratio H/H0 = Q(1-Q)/(R0(1-R0)) is stored in
% HH0_r(rval).
% NOTE(review): HH0_r is not plotted anywhere in this script.

HH0_r = zeros(numel(r_vec),1);

% Initial conditions; identical for every r, so set once.
p0 = 1/N;
Q0 = 1;
R0 = 0.5;

for rval = 1:numel(r_vec)

    r = r_vec(rval);

    p = p0;
    Q = Q0;
    R = R0;

    for t = 1:nGen

        % Find P'
        % WBAR is the mean fitness when C carriers have fitness 1+SC.
        WBAR =1+p*SC;
        Pprime = p*(1+SC)*(1+p*SC)/(WBAR^2);

        % Find Q' and R'
        % f1..f10 are the frequencies of the ten unordered pairings of the
        % types AC (p*Q), aC (p*(1-Q)), Ac ((1-p)*R) and ac ((1-p)*(1-R)).
        f1 = (p^2)*(Q^2);
        f2 = 2*(p^2)*Q*(1 - Q);
        f3 = 2*p*(1 - p)*Q*R;
        f4 = 2*p*(1 - p)*Q*(1 - R);
        f5 = (p^2)*((1-Q)^2);
        f6 = 2*p*(1 - p)*(1 - Q)*R;
        f7 = 2*p*(1 - p)*(1 - Q)*(1 - R);
        f8 = ((1 - p)^2)*(R^2);
        f9 = 2*((1 - p)^2)*R*(1 - R);
        f10 = ((1 - p)^2)*((1 - R)^2);

        % r moves A between C and c only in the AC-ac (f4) and aC-Ac (f6) pairings.
        ACprime = ((f1*((1+SC)^2) + f2*((1+SC)^2)/2 + f3*(1+SC)/2 + f4*((1 - r)/2)*(1+SC) + f6*(r/2)*(1+SC)))/(WBAR^2);
        Acprime = (f3*(1+SC)/2 + f4*(1+SC)*r/2 + f6*(1+SC)*(1-r)/2 + f8 + f9/2)/(WBAR^2);

        % Q' uses the closed-form haploid recursion; ACprime/Pprime is the
        % commented-out alternative.
        Qprime = ((1+p*SC)*Q + r*(1-p)*(R-Q))/(1+p*SC); %ACprime/Pprime;
        Rprime = Acprime/(1-Pprime);

        % Find frequency of A allele (diagnostic only; not used further)
        freqAprime = ACprime+Acprime;

        % update and record everything
        p = Pprime; % update p
        Q = Qprime;
        R = Rprime;

        % Stop once C is within 1e-10 of fixation, as the gene-culture models
        % do. Without this, 1-Pprime can round to 0, making Rprime
        % non-finite; through r*(1-p)*(R-Q) that would then turn Q into NaN.
        if Pprime>1-10^-10
            break
        end
    end

    HH0_r(rval) = Q*(1-Q)/(R0*(1-R0));
end

%% gene-culture system, vertical transmission with affinity bias
% Formulation (i), plotted in panel A. For every (beta_1, beta_2) pair, the
% joint dynamics of the selected cultural trait C/c and the neutral allele A/a
% are iterated for up to nGen generations. The final heterozygosity ratio
% H/H0 = Q(1-Q)/(R0(1-R0)) is stored in HH0_cv1(b1,b2).
% State: p = freq(C), Q = freq(A | C), R = freq(A | c).

b1_vec = 0.0001:0.0001:0.1;
b2_vec = 0.0001:0.0001:0.1;

% NOTE(review): 1000 x 1000 = 10^6 parameter pairs, each iterated until C is
% near fixation (or nGen), so this section takes a long time to run.
HH0_cv1 = zeros(numel(b1_vec),numel(b2_vec));

for b1 = 1:numel(b1_vec)
    beta_1 = b1_vec(b1);
    
    for b2 = 1:numel(b2_vec)
        beta_2 = b2_vec(b2);
        
        % Start from a single C carrier (p = 1/N) that carries A (Q = 1);
        % A has frequency 0.5 among c individuals.
        p0 = 1/N;
        Q0 = 1;
        R0 = 0.5;
        
        p = p0;
        Q = Q0;
        R = R0;
        
        for t = 1:nGen
            
            % Find P'
            
            % WBAR is the mean fitness when C carriers have fitness 1+SC.
            WBAR =1+p*SC;
            % Closed-form p'. This is algebraically equal to ACprime+aCprime
            % computed below. Pprime1 is not used afterwards (kept as a check).
            Pprime1 = (p*(1+SC)*((1+p*SC)+(1-p)*(R-Q)*(beta_1-beta_2)))/(WBAR^2);
            
            % Find Q', R', p'
            % f1..f10: frequencies of the ten unordered pairings of the types
            % AC (p*Q), aC (p*(1-Q)), Ac ((1-p)*R) and ac ((1-p)*(1-R)):
            % f1 AC-AC, f2 AC-aC, f3 AC-Ac, f4 AC-ac, f5 aC-aC,
            % f6 aC-Ac, f7 aC-ac, f8 Ac-Ac, f9 Ac-ac, f10 ac-ac
            
            f1 = (p^2)*(Q^2);
            f2 = 2*(p^2)*Q*(1 - Q);
            f3 = 2*p*(1 - p)*Q*R;
            f4 = 2*p*(1 - p)*Q*(1 - R);
            f5 = (p^2)*((1-Q)^2);
            f6 = 2*p*(1 - p)*(1 - Q)*R;
            f7 = 2*p*(1 - p)*(1 - Q)*(1 - R);
            f8 = ((1 - p)^2)*(R^2);
            f9 = 2*((1 - p)^2)*R*(1 - R);
            f10 = ((1 - p)^2)*((1 - R)^2);
            
            % Offspring type frequencies. Each pairing is weighted by
            % (1+SC)^2, (1+SC) or 1 for two, one or no C parents, and the
            % total is normalised by WBAR^2.
            % beta_1 enters only the AC-ac (f4) and aC-Ac (f6) pairings of
            % A-bearing offspring, in the same way r does in the acultural
            % model. beta_2 plays that role for a-bearing offspring.
            ACprime = ((f1*((1+SC)^2) + f2*((1+SC)^2)/2 + f3*(1+SC)/2 + f4*((1 - beta_1)/2)*(1+SC) + f6*(beta_1/2)*(1+SC)))/(WBAR^2);
            Acprime = (f3*(1+SC)/2 + f4*(1+SC)*beta_1/2 + f6*(1+SC)*(1-beta_1)/2 + f8 + f9/2)/(WBAR^2);
            
            aCprime = (f2*((1+SC)^2)/2 + f4*((beta_2)/2)*(1+SC) + f5*((1+SC)^2) + f6*(1-beta_2)*(1+SC)/2 + f7*(1+SC)/2)/(WBAR^2);
            acprime = (f4*(1+SC)*(1-beta_2)/2 + f6*(1+SC)*(beta_2)/2 + f7*(1+SC)/2 + f9/2 + f10)/(WBAR^2);
            
            % p' = freq(C), Q' = freq(A | C), R' = freq(A | c)
            Pprime = ACprime+aCprime;
            Qprime = ACprime/Pprime;
            Rprime = Acprime/(1-Pprime);
            
            % update and record everything
            
            p = Pprime; % update p
            Q = Qprime;
            R = Rprime; 
            
            % Stop once C is within 1e-10 of fixation. This also keeps
            % 1-Pprime away from 0 in the Rprime division.
            if Pprime>1-10^-10
                break
            end
            
        end
        
        % Heterozygosity relative to the start (equals 4Q(1-Q) for R0 = 0.5)
        HH0_cv1(b1,b2) = Q*(1-Q)/(R0*(1-R0));
    end
end

% Panel A: heterozygosity ratio H/H0 under formulation (i) (affinity bias),
% contoured against log10 of the two transmission parameters.
figure(1)
subplot(1,2,1)
hold on
contourf(log10(b1_vec),log10(b2_vec),HH0_cv1',0:0.05:1)
caxis([0 1])
set(gca,'YDir','normal')
box on

% The data are log10-transformed, so label the axes as such.
ylabel('$\log_{10}(\beta_2)$','interpreter','latex','FontSize',16)
xlabel('$\log_{10}(\beta_1)$','interpreter','latex','FontSize',16)

% HH0_cv1 = Q(1-Q)/(R0(1-R0)), which equals 4Q(1-Q) for R0 = 0.5.
% This product lies in [0,1], matching caxis above. The previous label
% 4Q/(1-Q) did not match the plotted quantity.
cb = colorbar();
ylabel(cb,'$4\hat{Q}(1-\hat{Q})$','interpreter','latex','Rotation',270)

% beta_1 = beta_2 reference diagonal across the simulated beta range
bRange = b1_vec([1 end]);
plot(log10(bRange),log10(bRange),'r-','LineWidth',1.5)
title('(A) Formulation (i)')


%% gene-culture system, vertical transmission with cultural trait bias
% Formulation (ii), plotted in panel B. For every (gamma_1, gamma_2) pair,
% iterate the joint dynamics of the cultural trait C/c and the neutral allele
% A/a for up to nGen generations. Store the final heterozygosity ratio in
% HH0_cv2 and the final offspring frequencies [AC Ac aC ac] in endPoints.
% State: p = freq(C), Q = freq(A | C), R = freq(A | c).

nGen = 10000;

g1_vec = 0:0.01:1; %0.5+(0.1:0.001:0.4);
g2_vec = 0:0.01:1; %0.5-(-0.1:0.001:0.4);%ones(1,numel(g1_vec)).*0.5; %;

HH0_cv2 = zeros(numel(g1_vec),numel(g2_vec));
% endPoints(g1,g2,:) = [ACprime Acprime aCprime acprime] at the last generation
endPoints = zeros(numel(g1_vec),numel(g2_vec),4);

for g1 = 1:numel(g1_vec)
    gamma_1 = g1_vec(g1);
    
    for g2 = 1:numel(g2_vec)
        gamma_2 = g2_vec(g2);
        
        % Start from a single C carrier (p = 1/N) that carries A (Q = 1);
        % A has frequency 0.5 among c individuals.
        p0 = 1/N;
        Q0 = 1;
        R0 = 0.5;
        
        p = p0;
        Q = Q0;
        R = R0;
        
        for t = 1:nGen
            
            % Find P'
            
            % WBAR is the mean fitness when C carriers have fitness 1+SC.
            WBAR =1+p*SC;
            %Pprime = (p*(1+SC)*((1+SC)*p+(1-p)*(gamma_1*(Q+R)+gamma_2*(2-R-Q))))/(WBAR^2);
            
            % Find Q'
            % f1..f10: frequencies of the ten unordered pairings of the types
            % AC (p*Q), aC (p*(1-Q)), Ac ((1-p)*R) and ac ((1-p)*(1-R)):
            % f1 AC-AC, f2 AC-aC, f3 AC-Ac, f4 AC-ac, f5 aC-aC,
            % f6 aC-Ac, f7 aC-ac, f8 Ac-Ac, f9 Ac-ac, f10 ac-ac
            
            f1 = (p^2)*(Q^2);
            f2 = 2*(p^2)*Q*(1 - Q);
            f3 = 2*p*(1 - p)*Q*R;
            f4 = 2*p*(1 - p)*Q*(1 - R);
            f5 = (p^2)*((1-Q)^2);
            f6 = 2*p*(1 - p)*(1 - Q)*R;
            f7 = 2*p*(1 - p)*(1 - Q)*(1 - R);
            f8 = ((1 - p)^2)*(R^2);
            f9 = 2*((1 - p)^2)*R*(1 - R);
            f10 = ((1 - p)^2)*((1 - R)^2);
            
            % Offspring type frequencies. C-C pairings give only C offspring
            % and c-c pairings only c offspring.
            % In C-c pairings, gamma_1 is the share of A-bearing offspring
            % (f3, half of f4, half of f6) that take C. gamma_2 is the same
            % share for a-bearing offspring (half of f4, half of f6, f7).
            % Pair weights are (1+SC)^2, (1+SC) or 1 for two, one or no C
            % parents, normalised by WBAR^2.
            ACprime = (f1*((1+SC)^2)+(f2/2)*(1+SC)^2+f3*(1+SC)*gamma_1+(f4/2)*(1+SC)*gamma_1+(f6/2)*(1+SC)*gamma_1)/(WBAR^2);
            Acprime = (f3*(1+SC)*(1-gamma_1) + f4*(1+SC)*(1-gamma_1)/2 + f6*(1+SC)*(1-gamma_1)/2 + f8 + f9/2)/(WBAR^2);
            
            aCprime = ((f2/2)*((1+SC)^2)+(f4/2)*(1+SC)*gamma_2+f5*((1+SC)^2)+(f6/2)*gamma_2*(1+SC)+f7*(1+SC)*gamma_2)/(WBAR^2);
            acprime = ((f4/2)*(1-gamma_2)*(1+SC)+(f6/2)*(1-gamma_2)*(1+SC)+f7*(1+SC)*(1-gamma_2)+f9/2+f10)/(WBAR^2);
            
            % p' = freq(C), Q' = freq(A | C), R' = freq(A | c)
            Pprime = ACprime+aCprime;
            Qprime = ACprime/Pprime;
            Rprime = Acprime/(1-Pprime);
            
            p = Pprime; % update p
            Q = Qprime;
            R = Rprime;
            
            % Stop once C is within 1e-10 of fixation or has fallen below
            % 1e-10 (lost). This keeps both Pprime and 1-Pprime away from 0 in
            % the divisions above.
            % NOTE(review): these thresholds are far stricter than one
            % individual (1/N); confirm that is intended.
            
            if Pprime>1-10^-10
                break
            elseif Pprime<10^-10
                break
            end
            
        end
        endPoints(g1,g2,:) = [ACprime Acprime aCprime acprime]; 
        % Heterozygosity relative to the start (equals 4Q(1-Q) for R0 = 0.5)
        HH0_cv2(g1,g2) = Q*(1-Q)/(R0*(1-R0));
    end
end

% Final-generation frequency of the cultural trait C (AC + aC) and of the
% allele A (AC + Ac) for every (gamma_1, gamma_2) pair.
endPoints_C = sum(endPoints(:,:,[1 3]), 3);
endPoints_A = sum(endPoints(:,:,[1 2]), 3);

% Runs in which C finished below 0.001 (effectively lost) are blanked out
% so they are left out of the heterozygosity contours.
lostC = endPoints_C < 0.001;
HH0_cv2(lostC) = NaN;

% Panel B: heterozygosity ratio H/H0 under formulation (ii) (cultural-trait
% bias) on linear gamma axes. The grey axes background shows through the
% cells set to NaN (C lost), which contourf leaves unfilled.
figure(1)
subplot(1,2,2)
contourf(g1_vec,g2_vec,HH0_cv2',0:0.05:1)
caxis([0 1])
set(gca,'YDir','normal')
ax = gca;
ax.Color = '#C1C7C9';   % semicolon added: previously echoed the axes object

% HH0_cv2 = Q(1-Q)/(R0(1-R0)), which equals 4Q(1-Q) for R0 = 0.5.
% The previous label 4Q/(1-Q) did not match the plotted quantity.
cb = colorbar();
ylabel(cb,'$4\hat{Q}(1-\hat{Q})$','interpreter','latex','Rotation',270)

%% patch code
% Overlay hatched patches on panel B marking where the final frequency of
% allele A (endPoints_A) ended above 0.501 (single hatching) or below 0.499
% (cross hatching). For each gamma_2 row, the outline runs up through the
% first flagged gamma_1 and back down through the last one. Each region is
% therefore drawn as if it had no gaps within a row. Requires hatchfill.m.

regionMasks = {endPoints_A' > 0.501, endPoints_A' < 0.499};  % rows: gamma_2, columns: gamma_1
hatchStyles = {{'single', 95, 4}, {'cross'}};

for k = 1:numel(regionMasks)
    mask = regionMasks{k};
    rowHit = any(mask, 2);

    % Column index of the first / last flagged gamma_1 in each row
    % (max returns the first occurrence of the maximum).
    [~, firstCol] = max(mask, [], 2);
    [~, lastFromRight] = max(fliplr(mask), [], 2);
    lastCol = size(mask, 2) + 1 - lastFromRight;

    % Left edge bottom-to-top, then right edge top-to-bottom. Rows with no
    % flagged cell contribute no vertices.
    X = [g1_vec(firstCol(rowHit)) flip(g1_vec(lastCol(rowHit)))];
    Y = [g2_vec(rowHit) flip(g2_vec(rowHit))];

    ht = patch(X, Y, 'red');
    hatchfill(ht, hatchStyles{k}{:})
end

xlim([0 1])
ylim([0 1])

%% Panel B: gamma_1 = gamma_2 reference diagonal, axis labels and title
hold on

unitSpan = [0,1];
plot(unitSpan, unitSpan, 'r-', 'LineWidth', 1.5)

xlabel('$\gamma_1$', 'interpreter', 'latex', 'FontSize', 16)
ylabel('$\gamma_2$', 'interpreter', 'latex', 'FontSize', 16)

title('(B) Formulation (ii)')

% Diagnostic figure: final frequency of the cultural trait C (left) and of
% allele A (right) over the (gamma_1, gamma_2) grid.
figure
axC = subplot(1,2,1);
contourf(axC, g1_vec, g2_vec, endPoints_C', 'k', 'LineWidth', 1.5)

axA = subplot(1,2,2);
hold(axA, 'on')
contourf(axA, g1_vec, g2_vec, endPoints_A', 'k', 'LineWidth', 1.5)
axA.YDir = 'normal';

