function [fig, output] = LaughlinReRevisited()
%% Memming re-doing the Laughlin analysis one more time
% Given p(x) extracted from the image, find an appropriate functional
% transformation y = f(x,theta) and find the BEC (Bayesian efficient coding) theta.
% Laughlin analyzed the LMC neuron, which responds with a graded potential.
% The infomax solution does histogram equalization...

%% Load data
raw = load('data/laughlin_memming');
xr = [-1, 1] * 2;
raw.commonx(end) = 1; % FIX: clamp the last point (at 1.0157) so it stays inside the figure

%% Setup
% Prior distribution is a 1D normal distribution
prior.type = 'normal';
prior.mu = raw.theta_ML(1);
prior.sigma = raw.theta_ML(2);
prior.cdf = @(x) normcdf(x, prior.mu, prior.sigma);
prior.pdf = @(x) normpdf(x, prior.mu, prior.sigma); % was normcdf, which made prior.pdf a CDF
prior.invcdf = @(p) norminv(p, prior.mu, prior.sigma);

%% Nonlinearity is quantization, initialize randomly
%nQuantizationLevels = 5;
%qB = sort(rand(nQuantizationLevels-1, 1)) * 2 - 1; % quantization boundaries
%qB = norminv((1:nQuantizationLevels-1)/nQuantizationLevels, prior.mu, prior.sigma);
%qB = qB(:)
%mse0 = lossFunction(prior, qB, 'MSE')
%ent0 = lossFunction(prior, qB, 'Entropy')
%[xhat, dhat] = minDistortion(prior, qB, @(x,y) abs(x-y));
%lfh = @(qB) lossFunction(prior, qB, 'mse');
%[qBnew, fval] = fminunc(lfh, qB);
%qBnew = sort(qBnew, 'ascend');

%% Randomly sample a few response curves from the "total scatter" given by the error bars
% I need strictly increasing samples, so do a little quick and dirty rejection sampling...
% nX = numel(raw.commonx);
% kSample = 1;
% tsR = nan(nX, 20);
% while true
%     %tsRes = raw.ebly + rand(nX,1) .* (raw.ebuy - raw.ebly);
%     tsRes = raw.ry + randn(nX,1) .* (raw.ebuy - raw.ebly);
%     if all(diff(tsRes) > 0)
%         tsR(:, kSample) = tsRes;
%         if kSample == size(tsR, 2)
%             break;
%         end
%         kSample = kSample + 1;
%     end
% end
% plot(raw.commonx, tsR, 'x-', 'Color', [0.5, 0, 1]);

%% Probably better to do a multivariate normal sampling with temporal correlations!
fig = figure(4021); clf;
subplot(1,2,1); hold all;
nX = numel(raw.commonx);
K = toeplitz(exp(-0.5*(abs(1 - (1:nX))/5).^2)); % some arbitrary smoothness
sd = (raw.ebuy - raw.ebly) / 2; % assume error bars are 1 SD (symmetric)
% Note that the error bars from the figure are NOT symmetric, so it's
% something else...
K = K .* (sd(:) * ones(1, nX)) .* (ones(nX, 1) * sd(:)');
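% For reference, the covariance assembled above is a squared-exponential
% kernel with a length scale of 5 samples, scaled by the per-point spread:
%   K(i,j) = sd(i) * sd(j) * exp(-(i-j)^2 / (2 * 5^2))
% A minimal equivalent construction (a sketch, not used below):
%   [ii, jj] = meshgrid(1:nX);
%   Kalt = (sd(:) * sd(:)') .* exp(-0.5 * ((ii - jj) / 5).^2);
%   assert(max(abs(K(:) - Kalt(:))) < 1e-12);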
tsR = mvnrnd(raw.ry, K, 500);
% some rejection sampling for the constraints
idx = all(diff(tsR, 1, 2) > 0, 2) & all(tsR > 0, 2) & all(tsR < 1, 2);
idx = idx & all(tsR >= ones(size(tsR,1), 1) * raw.ebly(:)', 2); % restrict to bar range!
idx = idx & all(tsR <= ones(size(tsR,1), 1) * raw.ebuy(:)', 2);
tsR = tsR(idx, :);
%plot(raw.commonx, tsR, '-', 'Color', [0.7, 0.5, 1]);
fprintf('# of samples left: %d\n', size(tsR,1));

%%
set(gca, 'TickDir', 'out');
plot(raw.commonx, raw.ry, 'ko-', 'LineWidth', 2, 'DisplayName', 'response');
llh = line([raw.commonx, raw.commonx]', [raw.ebuy, raw.ebly]', 'Color', 'k');
llhs = get(llh, 'Annotation');
for k = 1:numel(llhs) % exclude the error bars from the legend
    set(get(llhs{k}, 'LegendInformation'), 'IconDisplayStyle', 'off');
end
%plot(raw.cdfx, normcdf(raw.cdfx, prior.mu, prior.sigma), 'k', 'LineWidth', 2, 'DisplayName', 'Normal fit');
xlim([-1, 1]);
xlabel('contrast');
ylabel('cdf or normalized response');

%%
%f = @(theta,x) normcdf(x, theta(1), theta(2));
%finv = @(theta,y) norminv(y, theta(1), theta(2));
%r_LS = lsqcurvefit(f, [0; 0.5], raw.commonx, raw.ry);
% Try a Naka-Rushton fit; theta = [a, p, c]
f = @(theta,x) abs(x-theta(1)).^theta(2) ./ (abs(x-theta(1)).^theta(2) + theta(3));
df = @(theta,x) theta(2) .* theta(3) .* abs(x-theta(1)).^(theta(2)-1) ./ (abs(x-theta(1)).^theta(2) + theta(3)).^2;
finv = @(theta,y) (theta(1) + ((y*theta(3))./(1-y)).^(1/theta(2)));
r_LS = lsqcurvefit(f, [-2; 6; 33], raw.commonx, raw.ry);
plot(raw.cdfx, f(r_LS, raw.cdfx), 'b', 'LineWidth', 2, 'DisplayName', 'response (N-R fit)');

%%
%xr = linspace(r_LS(1),3,1000);
%figure; plot(xr, cumsum(df(r_LS, xr)) .* (xr(2)-xr(1)));
%output = truncatedPDFstat(@(x)df(r_LS, x), 0, 0.5);
%output = truncatedPDFstat(@(x)normpdf(x, prior.mu, prior.sigma), 0, 0.5);
%fig = truncatedNormalStat(prior.mu, prior.sigma, 0, 0.5);
%return

%%
cdf_LS = lsqcurvefit(f, [-2; 6; 7], raw.cdfx, raw.cdfy);
plot(raw.cdfx, raw.cdfy, 'k:', 'LineWidth', 2, 'DisplayName', 'CDF');
plot(raw.cdfx, f(cdf_LS, raw.cdfx), 'r', 'LineWidth', 2, 'DisplayName', 'CDF (N-R fit)');
prior2.type = 'sigmoid';
prior2.cdf = @(x) f(cdf_LS, x);
prior2.pdf = @(x) df(cdf_LS, x);
prior2.invcdf = @(x) finv(cdf_LS, x);
prior1 = prior; % keep the normal fit around (unused below)
prior = prior2; % use the Naka-Rushton CDF fit as the prior from here on

%%
lh = legend('show');
set(lh, 'Location', 'NorthWest', 'Box', 'off');
subplot(1,2,2); cla; hold all;
set(gca, 'TickDir', 'out');
nQuantizationLevels = 25;
line([0, 1], [1, 1], 'Color', 'k', 'LineWidth', 1.5, 'LineStyle', ':', 'DisplayName', 'Infomax');

%% assume uniform quantization based on the nonlinearity
uqe = (0:nQuantizationLevels) / nQuantizationLevels;
qe = interp1(raw.ry, raw.commonx, uqe, 'spline', 'extrap');
qe = finv(r_LS, uqe); % overrides the spline edges with the Naka-Rushton inverse
%qe = norminv(uqe, r_LS(1), r_LS(2)); % Use the least squares curve fit
% WARNING! This spline extrapolation doesn't always work! pchip doesn't
% extrapolate very well. So double check!!
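% A minimal double-check sketch for the warning above (uncomment to use).
% Note: the last edge may legitimately be Inf, since finv(., 1) diverges,
% so only the interior edges are required to be finite:
%   assert(all(isfinite(qe(1:end-1))), 'interior quantization edges should be finite');
%   assert(all(diff(qe) > 0), 'quantization edges should be strictly increasing');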
%plot(uqe, qe)
%plot(raw.ry, raw.commonx, 'o')

%% estimate the marginal probability
z2 = prior.cdf(qe);
z2(isnan(z2)) = 1; % the top edge can be Inf, where the N-R cdf returns NaN
r_py = diff([z2, 1]) * nQuantizationLevels;
%z2_rawcdf = interp1(raw.cdfx, raw.cdfy, qe); %, 'linear', 'extrap');
% plot the corresponding marginal probability
xrr = 0:nQuantizationLevels;
xrr = xrr / xrr(end-1);
%ph10 = stairs(xrr, r_py_normalfit, 'r-', 'LineWidth', 2);
%ph10 = plot(xrr, r_py_normalfit, 'r-', 'LineWidth', 2);
%ph11 = stairs(xrr, r_py_rawcdf, 'r:', 'LineWidth', 2);
ph11 = plot(xrr, r_py, 'r:', 'LineWidth', 2, 'DisplayName', 'LMC response');
ylabel('normalized P(Y)');
xlim([0 1]);
set(gca, 'XTick', []);
xlabel('quantized y');
drawnow;

%% Plot the total scatter random draws
% ts_qe = zeros(numel(uqe), size(tsR,1));
% ts_z2 = ts_qe;
% ts_r_py = ts_qe;
% for kSample = 1:size(tsR,1)
%     %ts_qe(:, kSample) = interp1(tsR(kSample, :), raw.commonx, uqe, 'spline', 'extrap');
%     %ts_z2(:, kSample) = normcdf(ts_qe(:, kSample), prior.mu, prior.sigma);
%     ts_qe(:, kSample) = interp1(tsR(kSample, :), raw.commonx, uqe, 'linear', 'extrap');
%     ts_z2(:, kSample) = interp1(raw.cdfx, raw.cdfy, ts_qe(:, kSample), 'linear', 'extrap');
%     ts_z2(ts_z2 > 1) = 1;
%     ts_r_py(:, kSample) = diff([ts_z2(:, kSample); 1]) * nQuantizationLevels;
%     %plot(xrr, ts_r_py, 'x', 'LineWidth', 1, 'Color', [0.7, 0.5, 1]);
% end
% ebh = errorbar(xrr, nanmean(ts_r_py, 2), nanstd(ts_r_py, [], 2), 'r', 'LineWidth', 2);
% ylim([0, 2]);
% drawnow

%%
opts = optimset('display', 'none', 'tolfun', 1e-10, 'maxfunevals', 1e5);

%% resulting marginal response P(Y)
nQuantizationLevels = 25;
for kOptimPlots = 1:100
    lossStr = sprintf('l%.2f', kOptimPlots/50);
    disp(lossStr);
    %for kOptimPlots = 1:4
    %    switch kOptimPlots
    %        case 1
    %            lossStr = 'l0.25';
    %        case 2
    %            lossStr = 'l0.5';
    %        case 3
    %            lossStr = 'l1';
    %        case 4
    %            lossStr = 'l2';
    %        case 5
    %            lossStr = 'l4';
    %        case 6
    %            lossStr = 'MSE';
    %    end
    %qB = sort(rand(nQuantizationLevels-1, 1)) * 2 - 1; % random initial boundaries
    qB = prior.invcdf((1:nQuantizationLevels-1)'/nQuantizationLevels);
    lfh = @(qB) lossFunction(prior, qB, lossStr);
    [qBnew, fval] = fminunc(lfh, qB, opts);
    fprintf('optimized[%s]: %g\n', lossStr, fval);
    qBnew = sort(qBnew, 'ascend');
    py = qProb(prior, qBnew);
    output(kOptimPlots).qBnew = qBnew;
    output(kOptimPlots).py = py;
    output(kOptimPlots).lossStr = lossStr;
    output(kOptimPlots).fval = fval;
    figure(fig);
    subplot(1,2,1);
    plotPY_CDF(qBnew, xr, 'LineWidth', 1.5, 'DisplayName', lossStr);
    subplot(1,2,2); hold all;
    plotPY(py, 'LineWidth', 1.5, 'DisplayName', lossStr);
    drawnow;
end
%lh = legend('Infomax', 'LMC neuron', 'L0.5', 'L1', 'L2', 'L4');
lh = legend('show');
set(lh, 'Location', 'South', 'box', 'off');

%% save figure
set(fig, 'PaperSize', [10,5], 'PaperPosition', [0,0,10,5]);
saveas(fig, [mfilename, '_output_', datestr(now, 30), '.pdf']);
saveas(fig, [mfilename, '_output_', datestr(now, 30), '.png']);
save([datestr(now, 30), '_output'], 'output', 'xr');
end

function ph = plotPY_CDF(qB, xr, varargin)
% plot the quantizer as a piecewise-linear CDF over the range xr
x = [xr(1); qB; xr(end)];
p = (0:(numel(x)-1)) / (numel(x)-1);
%ph = stairs(x, p, varargin{:});
ph = plot(x, p, varargin{:});
end

function ph = plotPY(p, varargin)
% plot marginal probability (scaled to match # of bins)
%p = [p(1); p; p(end)];
p = p * numel(p);
%ph = stairs((0:numel(p)-1)/(numel(p)-2), p, varargin{:});
ph = plot((0:numel(p)-1)/(numel(p)-1), p, varargin{:});
end
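% lossFunction below computes the objective minimized in the optimization
% loop of the main function: the expected distortion of the Bayes-optimal
% reconstruction under the quantizer with boundaries qB,
%   L(qB) = sum_y P(y) * min over xhat of E[ dh(xhat, X) | X in bin y ],
% where dh(x,y) = |x - y|^p for loss strings of the form 'lp'.
% A minimal usage sketch (hypothetical values; assumes the prior struct
% built in the main function):
%   qB0 = prior.invcdf((1:4)'/5);            % 4 boundaries -> 5 levels
%   L2  = lossFunction(prior, qB0, 'l2');    % expected squared error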
function L = lossFunction(prior, qB, typeStr)
%% average stats of the posterior
qB = sort(qB(:), 'ascend');
py = qProb(prior, qB);
if strcmp(prior.type, 'normal')
    qEdges = [-Inf; qB; Inf];
else
    qEdges = [prior.invcdf(1e-8); qB; prior.invcdf(1 - 1e-8)];
    qEdges = unique(sort(qEdges, 'ascend'));
end
parfor ky = 1:(numel(qEdges)-1)
    if strcmp(prior.type, 'normal')
        poststat(ky) = truncatedNormalStat(prior.mu, prior.sigma, qEdges(ky), qEdges(ky+1));
    else
        poststat(ky) = truncatedPDFstat(prior.pdf, qEdges(ky), qEdges(ky+1));
    end
end
switch lower(typeStr)
    case 'mse'
        L = [poststat(:).std].^2 * py;
    case 'entropy'
        L = [poststat(:).entropy] * py;
    otherwise
        p = sscanf(lower(typeStr), 'l%f');
        if isempty(p)
            error('unknown loss function');
        end
        dh = @(x,y) abs(x - y) .^ p; % l-p norm ^ p
        [~, dhat] = minDistortion(prior, qB, dh);
        L = dhat' * py;
end
end

function [xhat, dhat] = minDistortion(prior, qB, dh)
% given a distortion dh(x,y), numerically estimate the Bayes-optimal xhat
% and its expected distortion for each quantized interval
qB = sort(qB(:), 'ascend');
%qEdges = [prior.mu - prior.sigma*5; qB; prior.mu + prior.sigma*5]; % 5 sigma bounds
qEdges = [prior.invcdf(1e-8); qB; prior.invcdf(1 - 1e-8)];
qEdges = sort(unique(qEdges), 'ascend');
xhat = zeros(numel(qB)+1, 1);
dhat = xhat;
for ky = 1:(numel(qB)+1)
    xr = linspace(qEdges(ky), qEdges(ky+1), 200); % row vec
    pdf = prior.pdf(xr);
    pdf = pdf(:) / sum(pdf); % normalize & make column vec
    dxx = zeros(numel(xr), 1);
    for kx = 1:numel(xr)
        dxx(kx) = (dh(xr(kx), xr) * pdf); % expected distortion if we report xr(kx)
    end
    [dhat(ky), midx] = min(dxx);
    xhat(ky) = xr(midx);
    %[xhat(ky), dhat(ky)] = fminunc(@(x) sum(dh(x, xr) .* pdf), (qEdges(ky) + qEdges(ky+1))/2);
end
end

function p = qProb(prior, qB)
% computes the probability assigned to each quantization level;
% used to compute the marginal response P(Y)
p = prior.cdf(qB);
p = diff([0; p(:); 1]);
end

function stat = truncatedNormalStat(mu, sigma, a, b)
% returns various statistics of the normal distribution truncated to [a, b]
assert(a < b, 'range must be non-empty');
a1 = (a - mu)/sigma;
b1 = (b - mu)/sigma;
c = log(2*pi)/2 + 0.5; % entropy constant: log(sqrt(2*pi*e))
Z = normcdf(b1) - normcdf(a1);
stat.mean = mu + (normpdf(a1) - normpdf(b1))/Z*sigma;
if mu < a
    stat.mode = a;
elseif mu <= b
    stat.mode = mu;
else
    stat.mode = b;
end
% guard against Inf * 0 = NaN at infinite truncation edges
if ~isinf(a1)
    a1a = a1 * normpdf(a1);
else
    a1a = 0;
end
if normpdf(b1) == 0
    b1b = 0;
else
    b1b = b1 * normpdf(b1);
end
stat.std = sigma * sqrt(1 + (a1a - b1b)/Z - ((normpdf(a1) - normpdf(b1))/Z)^2);
stat.entropy = c + log(sigma*Z) + (a1a - b1b) / 2 / Z;
end

function stat = truncatedPDFstat(pdfh, a, b)
% returns various statistics of the distribution truncated to [a, b],
% specified by its PDF, via numerical integration on a 500-point grid
assert(a < b, 'range must be non-empty');
assert(isa(pdfh, 'function_handle'), 'give me a PDF');
xr = linspace(a, b, 500);
dx = (xr(2) - xr(1));
p = pdfh(xr);
assert(all(isnumeric(p) & isfinite(p)), 'invalid density values');
assert(all(p >= 0), 'prob density must be non-negative');
p = p / sum(p);
stat.mean = sum(xr .* p);
maval = max(p);
stat.modes = xr(p == maval);
stat.std = sqrt(sum((xr - stat.mean).^2 .* p));
stat.entropy = -sum(p(p > 0) .* log(p(p > 0))) + log(dx); % treat 0*log(0) as 0
end
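% Consistency sketch (commented out): for a truncated standard normal, the
% closed-form statistics from truncatedNormalStat should approximately
% match the numerical ones from truncatedPDFstat, e.g. on [0, 1]:
%   s1 = truncatedNormalStat(0, 1, 0, 1);
%   s2 = truncatedPDFstat(@(x) normpdf(x), 0, 1);
%   fprintf('mean: %g vs %g, std: %g vs %g, entropy: %g vs %g\n', ...
%       s1.mean, s2.mean, s1.std, s2.std, s1.entropy, s2.entropy);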