CALC_LAMBDA_REGTOOLS: Find the optimal hyperparameter by the L-curve criterion (LCC) or by generalized cross-validation (GCV).
lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot);

Output:
  lambdas - "optimal" hyperparameter(s) determined using LCC or GCV

Input:
  imdl   - inverse model (EIDORS struct)
  vh     - homogeneous voltage matrix (of size nVtg x 1)
  vi     - inhomogeneous voltage matrix (of size nVtg x nFrames), including noise(!)
  type   - type of approach used, either:
             'LCC' (default), the L-curve criterion
             'GCV', generalized cross-validation
  doPlot - enables plotting if set to true (default = false)

Example:
  calc_lambda_regtools('unit_test');

NOTE
  If vi contains multiple frames, the returned values contain an "optimal" hyperparameter for each frame. An appropriate lambda can then be determined from a summary statistic (e.g. the median) of these values.

See also: RTv4manual.pdf (all page numbers listed refer to the numbers printed in the upper right corner of the manual; the corresponding PDF page number is 2 higher).

Nomenclature: Jacobian J is A; Prior R (not RtR) is L; Voltage v is b

Fabian Braun, December 2016

CITATION_REQUEST:
  AUTHOR: P C Hansen
  TITLE: Regularization tools version 4.0 for Matlab 7.3.
  JOURNAL: Numerical algorithms
  YEAR: 2007
  VOL: 46
  NUM: 2
  PAGE: S189-194
  DOI: 10.1007/s11075-007-9136-9
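The two criteria can be illustrated without EIDORS or regtools. The following is a minimal sketch, assuming a standard-form Tikhonov problem min ||A*x - b||^2 + lambda^2*||x||^2 built from hypothetical random toy data. It expresses the residual norm, the solution norm and the GCV function through the SVD filter factors f_i = s_i^2/(s_i^2 + lambda^2), picks the GCV minimizer on a lambda grid, and takes the L-curve corner as the point of maximum curvature of the log-log curve (finite differences). This is not how regtools' gcv.m and l_curve.m locate their optimum; it only shows the quantities involved. The last lines check the SVD form of the GCV function against the explicit influence-matrix formula used in the validation plot of calc_lambda_regtools.

```matlab
% Sketch (hypothetical toy data): GCV function and a discrete L-curve corner
% for min ||A*x - b||^2 + lambda^2*||x||^2, evaluated via the SVD.
rng(0);                                      % reproducible toy problem
m = 40; n = 30;
A = randn(m, n) * diag(logspace(0, -4, n));  % ill-conditioned columns
x_true = randn(n, 1);
b = A*x_true + 1e-3*randn(m, 1);             % noisy data

[U, S, ~] = svd(A, 0);                       % economy-size SVD (m > n)
s       = diag(S);                           % singular values, descending
beta    = U' * b;                            % data coefficients
b_perp2 = norm(b - U*beta)^2;                % part of b outside range(A)

% lambda grid restricted to [s_min, s_max], uniform in log(lambda)
lams = logspace(log10(s(end)), log10(s(1)), 200);
rho = zeros(size(lams)); eta = zeros(size(lams)); G = zeros(size(lams));
for k = 1:numel(lams)
    f      = s.^2 ./ (s.^2 + lams(k)^2);            % Tikhonov filter factors
    rho(k) = norm((1 - f) .* beta)^2 + b_perp2;     % ||A*x - b||^2
    eta(k) = norm(f .* beta ./ s)^2;                % ||x||^2
    G(k)   = rho(k) / (m - sum(f))^2;               % GCV function
end

% GCV: minimizer of G on the grid
[~, iG] = min(G);
lambda_gcv = lams(iG);

% L-curve: maximum signed curvature of (log||A*x-b||, log||x||),
% parametrized by t = log(lambda) with uniform step h
t  = log(lams);
h  = t(2) - t(1);
xr = log(sqrt(rho));
ye = log(sqrt(eta));
dx  = gradient(xr, h);  dy  = gradient(ye, h);
ddx = gradient(dx, h);  ddy = gradient(dy, h);
kappa = (dx .* ddy - ddx .* dy) ./ (dx.^2 + dy.^2).^(3/2);
[~, iL] = max(kappa);
lambda_lcc = lams(iL);

% cross-check: explicit influence-matrix form G = ||A*RM*b - b||^2 / trace(I - A*RM)^2
lam = lambda_gcv;
RM  = [A; lam*eye(n)] \ [eye(m); zeros(n, m)];      % = (A'*A + lam^2*I) \ A'
G_explicit = norm(A*(RM*b) - b)^2 / trace(eye(m) - A*RM)^2;
relErr = abs(G_explicit - G(iG)) / G(iG);

fprintf('lambda_gcv = %.3e, lambda_lcc = %.3e, relErr = %.1e\n', ...
        lambda_gcv, lambda_lcc, relErr);
```

Restricting the grid to [s_min, s_max] keeps the finite-difference curvature away from the flat ends of the L-curve, where it would otherwise be dominated by rounding noise. The stacked least-squares solve for RM avoids forming the worse-conditioned normal equations.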
```matlab
function lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot)
%% CALC_LAMBDA_REGTOOLS: Find the optimal hyperparameter by the L-curve
% criterion (LCC) or by generalized cross-validation (GCV).
% lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot);
%
% Output:
%   lambdas - "optimal" hyperparameter(s) determined using LCC or GCV
%
% Input:
%   imdl   - inverse model (EIDORS struct)
%   vh     - homogeneous voltage matrix (of size nVtg x 1)
%   vi     - inhomogeneous voltage matrix (of size nVtg x nFrames), including noise(!)
%   type   - type of approach used, either:
%              'LCC' (default), the L-curve criterion
%              'GCV', generalized cross-validation
%   doPlot - enables plotting if set to true (default = false)
%
% Example:
%   calc_lambda_regtools('unit_test');
%
% NOTE
%   If vi contains multiple frames, the returned values contain an
%   "optimal" hyperparameter for each frame. An appropriate lambda can then
%   be determined from a summary statistic (e.g. the median) of these values.
%
% See also: RTv4manual.pdf (all page numbers listed refer to the numbers
% printed in the upper right corner; the corresponding PDF page number is
% 2 higher).
%
% Nomenclature: Jacobian J is A; Prior R (not RtR) is L; Voltage v is b
%
% Fabian Braun, December 2016
%
% CITATION_REQUEST:
% AUTHOR: P C Hansen
% TITLE: Regularization tools version 4.0 for Matlab 7.3.
% JOURNAL: Numerical algorithms
% YEAR: 2007
% VOL: 46
% NUM: 2
% PAGE: S189-194
% DOI: 10.1007/s11075-007-9136-9
%

% (C) 2016 Fabian Braun. License: GPL version 2 or version 3
% $Id: calc_lambda_regtools.m 5540 2017-06-15 11:13:24Z aadler $

citeme(mfilename);

%% unit testing?
if ischar(imdl) && strcmpi(imdl, 'unit_test')
    doUnitTest();
    return;
end


%% set default inputs
if ~exist('type', 'var') || isempty(type)
    type = 'LCC';
end
if ~exist('doPlot', 'var') || isempty(doPlot)
    doPlot = false;
end

%% check for existence of the regtools package
if exist('regudemo.m', 'file') == 2 % file is already on path
    % Do nothing. We're OK
elseif exist('./regtools', 'dir') % check if in current folder
    addpath('./regtools');
    %%% What should this do?
elseif exist([fileparts(mfilename('fullpath')), filesep, 'regtools'], 'dir')
    addpath([fileparts(mfilename('fullpath')), filesep, 'regtools'])
else
    error(['Regtools are required but not available. Please download them from ', ...
        '<a href="matlab: web http://www.mathworks.com/matlabcentral/fileexchange/52-regtools -browser">File Exchange</a> or ', ...
        '<a href="matlab: web http://www2.compute.dtu.dk/~pcha/Regutools/ -browser">P.C. Hansen''s website</a> ', ...
        'and store them in the subfolder called ''regtools''. For fast execution it is recommended ', ...
        'to disable (comment out) all plotting functions in l_curve.m and gcv.m.']);
end

% AA: 3feb2017: Please make changes so
% 1. we don't call get_RM
% 2. we call calc_R_prior
%    fix calc_R_prior so it does what you want
% 3. rename to calc_lambda_regtools
% 4. change tutorial to call new name
% 5. Make changes to mk_GREIT_model
% 6. Merge these changes into mainline (if it works)
%    OR: delete mainline and svn mv

%% prepare imdl
imdlTmp = imdl;
imdlTmp.prior_use_fwd_not_rec = 0;
% if isfield(imdl.fwd_model,'coarse2fine')
%     imdlTmp.fwd_model = rmfield(imdlTmp.fwd_model,'coarse2fine');
% end
% if isfield(imdl, 'rec_model') && isfield(imdl.rec_model,'coarse2fine')
%     imdlTmp.rec_model = rmfield(imdlTmp.rec_model,'coarse2fine');
% end
img_bkgnd = calc_jacobian_bkgnd(imdlTmp);
A = calc_jacobian(img_bkgnd);
W = calc_meas_icov(imdlTmp);
L = calc_R_prior(imdlTmp);

LtL = calc_RtR_prior(imdlTmp);
LtL_ = L'*L;
% assert(all(abs(LtL_(:) - LtL(:)) < 100*eps), 'Prior differs too much!');

% check that the inverse measurement covariance W is the identity, as the
% regtools routines below work with the unweighted residual ||A*x - b||
assert(isequal(W, speye(size(W))), 'calc_lambda_regtools: W must be the identity');


%% (IMPORTANT!) transform the generalized-form problem to standard form
% (section 2.6 p.21 of RTv4manual.pdf). The L-curves of the generalized and
% the standard form are identical because both forms have identical norms
% (see section 2.6.3 p.24 of RTv4manual.pdf).
% NOTE: We need the standard form, as the l_curve and gcv routines only accept it.
[A_s, ~, ~] = std_form(A, L, nan(size(vh,1),1)); % as L is square, b is not affected, only A
% [A_s,b_s,L_p,K,M] = std_form(A,L,b);
[U_s, s_s] = csvd(A_s);


%% Iterate through all frames to get a range of lambdas
nFrames = size(vi,2);
lambdas = nan(nFrames,1);

progress_msg('Calculating lambda for each frame:', 0, nFrames);

if doPlot
    figure();
end

for iFrame = 1:nFrames

    progress_msg(iFrame, nFrames);

    %% prepare differential data of current frame
    b = vh - vi(:,iFrame);

    switch(lower(type))
        case 'lcc'
            %% L-curve (see section 2.5 p.20 of RTv4manual.pdf)
            % calculate and plot the continuous L-curve (documentation on p.83 of RTv4manual.pdf)
            lambdas(iFrame) = l_curve(U_s,s_s,b);

            % add my own L-curve points to the plot for validation purposes
            if doPlot && (iFrame == nFrames)
                lInit = imdl.hyperparameter.value;
                lams = flip(logspace(log10(lInit*1E-3), log10(lInit*1E3), 10));
                lams = [lams lInit];

                myrho = zeros(size(lams));
                myeta = zeros(size(lams));
                for i = 1:length(lams)
                    imdl.hyperparameter.value = lams(i);
                    RM = get_RM( imdl );
                    myrho(i) = norm(A*(RM*b) - b);   % residual norm
                    myeta(i) = norm(L*(RM*b));       % (semi)norm of the solution
                end

                % plot it (the last point, in green, is the initial hyperparameter)
                hold on;
                loglog(myrho(1:end-1), myeta(1:end-1), 'ob');
                loglog(myrho(end), myeta(end), 'og');
            end
        case 'gcv'
            %% GCV (see p.37 of RTv4manual.pdf)
            % documentation on p.65 of RTv4manual.pdf
            lambdas(iFrame) = gcv(U_s,s_s,b);

            % plot my own GCV function for validation purposes
            if doPlot && (iFrame == nFrames)
                lInit = imdl.hyperparameter.value;
                lams = flip(logspace(log10(lInit*1E-3), log10(lInit*1E3), 10));
                lams = [lams lInit];

                myG = zeros(size(lams));
                for i = 1:length(lams)
                    imdl.hyperparameter.value = lams(i);
                    RM = get_RM( imdl );
                    rho = norm(A*(RM*b) - b)^2;
                    % G(lambda) = ||A*RM*b - b||^2 / trace(I - A*RM)^2
                    myG(i) = rho / (trace(eye(size(RM,2)) - A*RM)^2);
                end

                % plot it (the last point, in green, is the initial hyperparameter)
                hold on;
                loglog(lams(1:end-1), myG(1:end-1), 'ob');
                loglog(lams(end), myG(end), 'og');
            end
        otherwise
            error('calc_lambda_regtools: type ''%s'' not supported (use ''LCC'' or ''GCV'')', type);
    end
end

progress_msg('Calculating lambda for each frame:', inf);

end


function doUnitTest()
% inspired by the tutorial mentioned below:
% http://eidors3d.sourceforge.net/tutorial/EIDORS_basics/tutorial110.shtml
%

% Load some data
load iirc_data_2006

stim = mk_stim_patterns(16,1,[0,1],[0,1],{'meas_current'},1);

for iRun = 1
% for iRun = [0 1]
    % Get an image reconstruction model
    if iRun
        % more advanced 3D model, which includes a coarse2fine mapping that
        % makes everything crash
        fmdl = mk_library_model('adult_male_16el');
        fmdl.stimulation = stim;
        opts = [];
        opts.noise_figure = 0.5;
        imdl = mk_GN_model(fmdl, opts, []);
    else
        % simple 2D model
        imdl = mk_common_model('c2c');
        imdl.fwd_model.stimulation = stim;
        imdl.fwd_model = rmfield( imdl.fwd_model, 'meas_select');
        imdl.RtR_prior = @prior_tikhonov;
    end

    % load the real data
    vi = real(v_rotate)/1e4; vh = real(v_reference)/1e4;
    % convert to double precision; otherwise we run into (unexplained) problems
    vi = double(vi); vh = double(vh);

    % get the hyperparameter value via L-curve
    figure
    lambdas_lcc = calc_lambda_regtools(imdl,vh,vi,'LCC',true);

    % get the hyperparameter value via GCV
    lambdas_gcv = calc_lambda_regtools(imdl,vh,vi,'GCV',true);

    % visualize
    FramesOfInterest = [10 35 60 85];
    fig = figure(1 + iRun);
    subplot(121);
    imdl.hyperparameter.value = median(lambdas_lcc);
    imgs_lcc = inv_solve(imdl, vh, vi(:,FramesOfInterest));
    imgs_lcc.show_slices.img_cols = 1;
    show_slices(imgs_lcc);
    title('L-curve');

    subplot(122);
    imdl.hyperparameter.value = median(lambdas_gcv);
    imgs_gcv = inv_solve(imdl, vh, vi(:,FramesOfInterest));
    imgs_gcv.show_slices.img_cols = 1;
    show_slices(imgs_gcv);
    title('GCV');

end

end
```
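The std_form step in the source relies on the generalized-form problem min ||A*x - b||^2 + lambda^2*||L*x||^2 having the same solution, residual norm and (semi)norm as its standard-form counterpart. For a square, invertible L this reduces to A_s = A*inv(L) and x = L\x_s, with b unchanged. The following minimal sketch checks this numerically on hypothetical random toy data; it does not call regtools' std_form, which also covers non-square L, and the choice of an upper bidiagonal L is only an assumption for illustration.

```matlab
% Sketch (hypothetical toy data): generalized vs. standard form Tikhonov
% for a square, invertible prior L.
rng(1);                                   % reproducible toy problem
m = 25; n = 20;
A = randn(m, n);
L = eye(n) + 0.5*diag(ones(n-1, 1), 1);   % square, invertible (upper bidiagonal)
b = randn(m, 1);
lambda = 0.3;

% generalized form via the regularized normal equations
x_gen = (A'*A + lambda^2*(L'*L)) \ (A'*b);

% standard form: A_s = A*inv(L); b is unchanged because L is square
A_s   = A / L;
x_s   = (A_s'*A_s + lambda^2*eye(n)) \ (A_s'*b);
x_std = L \ x_s;                          % map back: x = inv(L)*x_s

% identical residual norms and (semi)norms => identical L-curves
res_gen = norm(A*x_gen - b);   res_std = norm(A_s*x_s - b);
sn_gen  = norm(L*x_gen);       sn_std  = norm(x_s);
fprintf('max |x_gen - x_std| = %.2e\n', max(abs(x_gen - x_std)));
```

Because both the residual norm and the (semi)norm coincide for every lambda, any criterion built only from these two quantities (L-curve corner, GCV with the same trace term) selects the same lambda in either form.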