% mkfig_LinGaussEncodersDists.m
%
% Compare optimal linear Gaussian encoders  y = A*x + n,  with stimulus
% x ~ N(0,Q), noise n ~ N(0,R), and power constraint trace(cov(y)) = c.
% Each encoder minimises the loss sum_i lambda_i(L)^p of the posterior
% covariance L for a different power p. When R and Q commute (as here,
% R = noiseFrac*Q), the optimal output covariance is proportional to
% (R*Q)^(p/(p+1)):
%   infomax (p -> 0)  : exponent 0    (isotropic output)
%   stdmin  (p = 1/2) : exponent 1/3  (min sum of posterior std devs)
%   varmin  (p = 1)   : exponent 1/2  (min total posterior variance)
%   p = 2             : exponent 2/3  (min sum of squared posterior variances)
%
% Top row: stimulus / response densities with samples.
% Bottom row: posterior densities for each encoder.
% The final lines print each loss, normalised to the best encoder.

logdet = @(A) 2*sum(log(diag(chol(A))));  % log-determinant of an SPD matrix
symm   = @(C) (C + C')/2;                 % strip round-off asymmetry (mvnpdf/chol need symmetric input)

alph = .9;                  % corr coefficient
d = 2;                      % dimensionality
Q = [1 alph; alph 1]*.5;    % stimulus covariance
[u,s,v] = svd(Q);           % only used by the commented-out alternative R below
% R = u*diag([.1 .1])*u';
% c = trace(Q)*(7);         % constraint (arbitrary)
noiseFrac = .1;             % noise covariance as a fraction of Q (not named 'eps': that shadows the builtin)
R = noiseFrac*Q;            % noise covariance
c = trace(Q); %*(1-6*noiseFrac);   % total output-variance budget
I = eye(d);

% Encoder whose output covariance A*Q*A' + R equals (c/trace(M))*M.
% Requires (c/trace(M))*M - R to be positive semidefinite.
optEncoder = @(M) sqrtm(c/trace(M)*M - R)/sqrtm(Q);

% infomax (M = I gives c/d*I - R, i.e. isotropic output)
Ami = optEncoder(I);
% varmin (p = 1)
M1 = sqrtm(R*Q);
Amv = optEncoder(M1);
% stdmin (p = 1/2  ->  exponent p/(p+1) = 1/3)
M2 = (R*Q)^(1/3);
Ams = optEncoder(M2);
% p = 2  ->  exponent 2/3
M4 = (R*Q)^(2/3);
A4 = optEncoder(M4);

Cy1 = Ami*Q*Ami' + R;   % always isotropic: (c/d)*I
Cy2 = Amv*Q*Amv' + R;
Cy3 = Ams*Q*Ams' + R;
Cy4 = A4*Q*A4' + R;
c - [trace(Cy1), trace(Cy2), trace(Cy3), trace(Cy4)]   % sanity check: should be ~0

%% ---- Make figs ------
nsamps = 100;
x = mvnrnd(zeros(nsamps,d),Q);   % one stimulus sample per row

% Samples are stored as rows, so y = A*x becomes y_row = x_row*A'.
% (The transpose matters whenever A is not symmetric, e.g. for general R.)
NOISEON = 1;
if NOISEON  % Noisy encoding
    y1 = x*Ami' + mvnrnd(zeros(nsamps,d),R);
    y2 = x*Amv' + mvnrnd(zeros(nsamps,d),R);
    y3 = x*Ams' + mvnrnd(zeros(nsamps,d),R);
    y4 = x*A4'  + mvnrnd(zeros(nsamps,d),R);
else        % Noiseless encoding
    y1 = x*Ami';
    y2 = x*Amv';
    y3 = x*Ams';
    y4 = x*A4';
end

% dothis = ['axis equal; axis square; axis(4*[-1 1 -1 1])'];
%
% subplot(151); plot(x(:,1),x(:,2),'k.');   eval(dothis);
% subplot(152); plot(y1(:,1),y1(:,2),'k.'); eval(dothis);
% subplot(154); plot(y2(:,1),y2(:,2),'k.'); eval(dothis);
% subplot(153); plot(y3(:,1),y3(:,2),'k.'); eval(dothis);
% subplot(155); plot(y4(:,1),y4(:,2),'k.'); eval(dothis);

%% Make density plots
colormap bone;
xrnge = 3;
dx = .02;
x0 = -xrnge+dx/2:dx:xrnge;
[xx,yy] = meshgrid(x0);
nd = size(xx,1);

% Zero-mean Gaussian density with covariance C, evaluated on the xx/yy grid.
gridpdf = @(C) reshape(mvnpdf([xx(:),yy(:)],[0 0],symm(C)),[nd nd]);

Dx = gridpdf(Q);
D1 = gridpdf(Cy1);
D2 = gridpdf(Cy2);
D3 = gridpdf(Cy3);
D4 = gridpdf(Cy4);

%dothis = ['axis image; axis xy; axis(2.5*[-1 1 -1 1]);set(gca,''ytick'',2*[-1 0 1],''xtick'',2*[-1 0 1]);'];
dothis = ['axis image; axis xy; axis(2*[-1 1 -1 1]);set(gca,''ytick'',[],''xtick'',[]);'];
ms = 8;

subplot(251);
imagesc(x0,x0,Dx); hold on;
plot(x(:,1),x(:,2),'b.','markersize',ms); hold off;
eval(dothis);
xlabel('stim axis 1'); ylabel('stim axis 2');

% Response panels, columns 2-5, ordered by increasing loss power:
% infomax, stdmin, varmin, p = 2.
respPanels = {D1, y1; D3, y3; D2, y2; D4, y4};
for k = 1:size(respPanels,1)
    subplot(2,5,k+1);
    imagesc(x0,x0,respPanels{k,1}); hold on;
    plot(respPanels{k,2}(:,1),respPanels{k,2}(:,2),'r.','markersize',ms); hold off;
    eval(dothis);
end

figdims = [8 6];
set(gcf,'papersize',[figdims+.25], ...
    'paperposition', [.1 .1 figdims+.1]);
%print -dpdf figs/BEC_LinGaussianEncoderDists.pdf

%% Compute posterior covs
% Posterior covariance of x given y = A*x + n:  inv(A'*inv(R)*A + inv(Q)).
postCov = @(A) symm(inv(A'/R*A + inv(Q)));
L1 = postCov(Ami);
L2 = postCov(Amv);
L3 = postCov(Ams);
L4 = postCov(A4);

dothis = ['axis image; axis xy; axis(1.5*[-1 1 -1 1]);set(gca,''ytick'',[],''xtick'',[]);'];
DL1 = gridpdf(L1);
DL2 = gridpdf(L2);
DL3 = gridpdf(L3);
DL4 = gridpdf(L4);

% Posterior panels, columns 2-5 of the bottom row, same order as above.
postPanels = {DL1, DL3, DL2, DL4};
for k = 1:numel(postPanels)
    subplot(2,5,k+6);
    imagesc(x0,x0,postPanels{k});
    eval(dothis);
end

%% Losses for each encoder (order: infomax, varmin, stdmin, p = 2)
MIs   = .5*(logdet(Q) - [logdet(L1),logdet(L2),logdet(L3),logdet(L4)]);
vrs   = [trace(L1),trace(L2),trace(L3),trace(L4)];
stds  = [trace(sqrtm(L1)),trace(sqrtm(L2)),trace(sqrtm(L3)),trace(sqrtm(L4))];
kurts = [trace(L1^2),trace(L2^2),trace(L3^2),trace(L4^2)];

MIs = MIs/max(MIs)
vrs = vrs/min(vrs)
stds = stds/min(stds)
ks = kurts./min(kurts)