function [fitps, best_fitps, best_pM] = grtfit_st(oM,inits,fp,nfits,dbtype,mod_id)
% GRTFIT_ST  Random-walk search for bivariate-Gaussian perceptual
% distributions and decision bounds that fit a confusion matrix.
%
% [fitps, best_fitps, best_pM] = grtfit_st(oM,inits,fp,nfits,dbtype,mod_id)
%
% Each pass perturbs the current parameters with Gaussian jitter (randn),
% checks them against fixed constraints, predicts response probabilities
% (via bivar_norm and dec_reg), and scores the prediction by BIC. A
% proposal is accepted when its BIC is lower than that of the last accepted
% fit, or unconditionally after 20 consecutive rejections (a forced jump out
% of a local minimum). Jump sizes shrink after 10 consecutive rejections.
%
% Outputs
%   fitps      = 1 x nfits structure array of accepted fits (parameters plus
%                g2, LL, bic). Forced jumps are stored too, so entries are not
%                monotone in bic; a forced jump onto a constraint-violating
%                proposal is stored with LL = -1e15 and bic = g2 = 1e15.
%   best_fitps = inits, overwritten with the lowest-BIC parameters seen
%   best_pM    = predicted response probabilities (nstim x nresp) for
%                best_fitps
%
% Inputs
%   oM     = observed confusion matrix (stimuli x responses).
%            NOTE(review): the original header described a 4x4x5 array, but
%            the likelihood code below treats oM as one nstim x nresp table.
%   inits  = structure containing initial parameter values
%   fp     = structure of logicals indicating which parameters are free
%   nfits  = number of accepted fits wanted
%   dbtype = 'linear' or 'piecewise' or 'maxpost'
%   mod_id = name of model being fit (e.g., ns_m, ns_PI, etc...); displayed
%            only
%
% fields in inits must be identical to those in fitps, namely
%
% fitps = struct('maa',{},'mab',{},'mba',{},'mbb',{},'caa',{},'cab',{},'cba',{},'cbb',{},...
%     'cx',{},'dx',{},'cy',{},'dy',{},'LL',{},'g2',{},'bic',{});
% (for 'maxpost', 'waa','wab','wba','wbb' take the place of 'cx','dx','cy','dy')
%
% fp fields are as follows
%
% fp = struct('cx',{},'dx',{},'cy',{},'dy',{},'vxa',{},'vxb',{},'vya',{},'vyb',{},...
%     'raa',{},'rab',{},'rba',{},'rbb',{},...
%     'mxa',{},'mxb',{},'mya',{},'myb',{},'mab',{},'mba',{},'mbb',{},...
%     'vaa',{},'vab',{},'vba',{},'vbb',{});
% NOTE(review): 'maxpost' additionally reads fp.ax, fp.bx, fp.ay, fp.by.
%
% 1 = it can vary; 0 = it can't:
%   bounds:  'cx','dx','cy','dy'
%   (co)var: 'raa','rab','rba','rbb','vaa','vab','vba','vbb'
%   means:   'mab','mba','mbb'
%
% 1 = fixed across levels; 0 = not fixed across levels:
%   variance: 'vxa','vxb','vya','vyb'
%   means:    'mxa','mxb','mya','myb'
%
% if strcmp(dbtype,'linear'), the c and d parameters are slope and
% intercept, respectively, when free
% if strcmp(dbtype,'piecewise'), the c and d parameters are intercepts,
% when free.
% if strcmp(dbtype,'maxpost'), the a, b, c, and d parameters are baserate
% 'scaling' parameters, when free.
%
% change log
% most recent changes made on 9.14.08 - cleaning up base rate artifacts...
% review fixes: randn seeding, fit statistics on constraint violations,
% variance-violation flag, piecewise output struct, feat_params restore.

% Fail fast on an unknown bound type (it used to surface later as an
% undefined-variable error).
if ~ischar(dbtype) || ~any(strcmp(dbtype,{'linear','piecewise','maxpost'}))
    error('dbtype must be ''linear'', ''piecewise'' or ''maxpost''.')
end
is_maxpost = strcmp(dbtype,'maxpost');
is_linear = strcmp(dbtype,'linear');

% Seed BOTH legacy generators: every proposal below is drawn with randn, so
% seeding rand alone left the search itself unseeded.
timeseed = 1;
if timeseed
    seed = sum(100*clock);
else
    seed = 5; % for debugging, checking variations in zpro (see below)
end
rand('state',seed);
randn('state',seed);

% Output template; 'linear' and 'piecewise' share the c/d parameterisation.
if is_maxpost
    fitps = struct('maa',{},'mab',{},'mba',{},'mbb',{},'caa',{},'cab',{},'cba',{},'cbb',{},...
        'waa',{},'wab',{},'wba',{},'wbb',{},'g2',{},'LL',{},'bic',{});
    npar = 13; % 6 m, 4 c, 3 'db'
else
    fitps = struct('maa',{},'mab',{},'mba',{},'mbb',{},'caa',{},'cab',{},'cba',{},'cbb',{},...
        'cx',{},'dx',{},'cy',{},'dy',{},'g2',{},'LL',{},'bic',{});
    npar = 14; % 6 m, 4 c, 4 db
end

% ---- initial parameter settings ----------------------------------------
% means: one column per stimulus (aa, ab, ba, bb); rows are x and y
maa = [0; 0]; % not free
m = [maa inits.mab inits.mba inits.mbb];
if fp.mxa, m(1,2) = m(1,1); npar = npar - 1; end
if fp.mxb, m(1,4) = m(1,3); npar = npar - 1; end
if fp.mya, m(2,3) = m(2,1); npar = npar - 1; end
if fp.myb, m(2,4) = m(2,2); npar = npar - 1; end

% covariance matrices: unit variances, off-diagonals from inits
c = zeros(2,2,4);
c_init = {inits.caa, inits.cab, inits.cba, inits.cbb};
for k = 1:4
    c(:,:,k) = eye(2);
    c(1,2,k) = c_init{k}(1,2);
    c(2,1,k) = c_init{k}(2,1);
end
r_free = [fp.raa fp.rab fp.rba fp.rbb]; % which correlations may vary
for k = 1:4
    if ~r_free(k)
        c(1,2,k) = 0;
        c(2,1,k) = 0;
        npar = npar - 1;
    end
end

% decision-bound / weight parameters
if is_maxpost
    dbp = [inits.waa inits.wab inits.wba inits.wbb];
else
    % 9.14.08 the baserate algorithm - grtfit.m - can implement cubic and
    % quadratic terms in decision bounds; these terms are zero here.
    % Row 1 is the 'horizontal' bound (x), row 2 the 'vertical' bound (y).
    dbp = [0 0 inits.cx inits.dx 0; ...
           0 0 inits.cy inits.dy 0];
end

[nstim,nresp,nbr] = size(oM);

% maxpost mode flags; defined for every dbtype so the loop can test them
mp_free = 0; mp_feat = 0; mp_br = 0;
feat_params = [];
if is_linear
    if ~fp.cx, dbp(1,3) = 0; npar = npar - 1; end
    if ~fp.cy, dbp(2,3) = 0; npar = npar - 1; end
    if ~fp.dx, dbp(1,4) = mean(m(2,:)); npar = npar - 1; end
    if ~fp.dy, dbp(2,4) = mean(m(1,:)); npar = npar - 1; end
elseif is_maxpost
    % $$ the Ns in npar = npar - N; expressions below will need to be changed
    if ~fp.ay && ~fp.by && ~fp.ax && ~fp.bx
        if fp.cy && fp.dy && fp.cx && fp.dx
            % feat = feature; g = general; 'upper middle' model
            % other dimension's 'neutral' parameters can vary from .5
            mp_feat = 1;
            npar = npar - 1;
            % $$ maybe need to fix the size of feat_params...
            feat_params = zeros(2,2);
        else
            mp_br = 1; % br = base rate = distribution weight; restricted model
            npar = npar - 3;
        end
    else
        mp_free = 1; % each distribution gets its own weight
    end
end
% NOTE(review): for 'piecewise' the bound parameters are neither removed
% from npar when fixed nor perturbed in the loop -- confirm this is intended.

if npar > 12
    error(['Too many free parameters: ' num2str(npar)])
elseif npar == 12
    disp('12 free parameters makes this whole endeavor questionable, you know.')
else
    disp(['N free parameters = ' num2str(npar)])
end

% implementing BIC calculation; added 5/30/07 (counts taken before the
% zero-cell correction below)
big_N = sum(sum(oM));
complexity = npar*log(big_N);

bin = .05;   % bin width of the density grid
cf = -1e15;  % zero likelihood, used to avoid unnecessary calculations with obviously bad parameter settings

best_fitps = inits;
n = 0;              % index of the next fit to store (0 = reference pass)
pv = zeros(4,1);    % parameter-violation flags
n_rejected = 0;     % consecutive rejected proposals

% zero-cell correction added to every cell of oM
zerofix = '1in4';
switch zerofix
    case '1in4'
        zpro = .25;
    case 'loglinear'
        zpro = .5;
    case '1in100'
        zpro = .01;
    case '1/N'
        zpro = 1/sum(sum(sum(oM)));
    otherwise
        error('A zero-cell correction must be specified.') % maybe add some other options?
end
oM = oM + zpro;
Nr_oM = zeros(nstim,1); % row (stimulus) totals
for i = 1:nstim
    Nr_oM(i) = sum(oM(i,:));
end

% initialize dbtype maxpost baserate scaling parameters
if is_maxpost
    if sum(dbp(1,1:4)) > 1
        % weights start at the observed stimulus proportions
        c_w = Nr_oM(:)/sum(Nr_oM(:));
        dbp(1,1:4) = c_w(1:4)';
    end
    % $$ this will certainly need to be changed, but I don't really
    % remember how it works right now, so I can't rightly fix it...
    % NOTE(review): if the inits.w* weights are scalars (as this function
    % stores them), dbp is 1x4 and dbp(2,1)/dbp(2,2) below index past its
    % end -- confirm the intended formula.
    if mp_feat
        feat_params(1,1) = dbp(1,1)/( dbp(1,1) + dbp(1,2) );
        feat_params(1,2) = dbp(2,1)/( dbp(2,1) + dbp(2,2) );
        feat_params(2,:) = 1-feat_params(1,:);
    end
end

% ---- search loop ----------------------------------------------------------
js_fact = 1;     % scales jumps in parameter space
big_jump = 3;    % jump multiplier after 20+ consecutive rejections
wee_jump = .25;  % jump multiplier after 10-19 consecutive rejections
while n <= nfits
    disp(mod_id)
    disp(['working on fit ' num2str(n) ' out of ' num2str(nfits)])

    % Jump sizes. js_a only matters for maxpost weights, js_c/js_d only for
    % linear bounds. With 1-9 consecutive rejections the previous sizes carry
    % over.
    if n_rejected == 0
        js_m = .01*js_fact;  % means
        js_r = .01*js_fact;  % correlations
        js_a = .01*js_fact;  % maxpost weights
        js_c = .01*js_fact;  % linear slopes
        js_d = .01*js_fact;  % linear intercepts
        accept = 0;          % logical that determines whether or not to accept the current fit
    elseif n_rejected >= 20
        js_m = big_jump*.02*js_fact;
        js_r = big_jump*.02*js_fact;
        js_a = big_jump*.01*js_fact;
        js_c = big_jump*.02*js_fact;
        js_d = big_jump*.02*js_fact;
        accept = 1;          % stuck: force acceptance of this big jump
    elseif n_rejected >= 10
        js_m = wee_jump*.01*js_fact;
        js_r = wee_jump*.01*js_fact;
        js_a = wee_jump*.01*js_fact;
        js_c = wee_jump*.01*js_fact;
        js_d = wee_jump*.01*js_fact;
    end

    % keep the current state so a rejected proposal can be undone
    m_old = m;
    c_old = c;
    dbp_old = dbp;
    feat_old = feat_params;

    % means
    if fp.mab && ~fp.mxa
        m(1,2) = m(1,2) + js_m*randn(1);      % m_xab free
    end
    if fp.mba && fp.mbb
        m(1,3) = m(1,3) + js_m*randn(1);      % new m_xba (or shared m_xb)
        if fp.mxb
            m(1,4) = m(1,3);                  % only one m_xb
        else
            m(1,4) = m(1,4) + js_m*randn(1);  % new m_xbb
        end
    end
    if fp.mba && ~fp.mya
        m(2,3) = m(2,3) + js_m*randn(1);      % m_yba free
    end
    if fp.mab && fp.mbb
        m(2,2) = m(2,2) + js_m*randn(1);      % new m_yab (or shared m_yb)
        if fp.myb
            m(2,4) = m(2,2);                  % only one m_yb
        else
            m(2,4) = m(2,4) + js_m*randn(1);  % new m_ybb
        end
    end

    % correlations: jitter r, clamp to (-1,1), rebuild the covariance
    for k = find(r_free)
        sd_prod = sqrt(c(1,1,k)*c(2,2,k));
        rho = c(1,2,k)/sd_prod + js_r*randn(1);
        if rho <= -1
            rho = -.9999;
        elseif rho >= 1
            rho = .9999;
        end
        c(1,2,k) = rho*sd_prod;
        c(2,1,k) = c(1,2,k);
    end

    % decision-bound / weight parameters
    if is_maxpost
        if mp_free
            dbp = abs(dbp + js_a*randn(1,4));
            dbp = dbp/sum(dbp);
        elseif mp_feat
            feat_params = abs(feat_params + js_a*randn(2,2));
            normalizer = sum(feat_params,1);
            for i = 1:2
                feat_params(i,:) = feat_params(i,:)./normalizer;
            end
            % feat_params structure
            % (1,1) = stim.1 weight
            % (1,2) = stim1. weight
            % (2,1) = stim.2 weight
            % (2,2) = stim2. weight
            dbp(1,1) = feat_params(1,1)*feat_params(1,2);
            dbp(1,2) = feat_params(1,2)*feat_params(2,1);
            dbp(1,3) = feat_params(1,1)*feat_params(2,2);
            dbp(1,4) = feat_params(2,1)*feat_params(2,2);
            dbp = dbp/sum(dbp);
        elseif mp_br
            % weights stay fixed
        end
    elseif is_linear
        % 'horizontal' bound
        if fp.cx, dbp(1,3) = dbp(1,3) + js_c*randn(1); end
        if fp.dx, dbp(1,4) = dbp(1,4) + js_d*randn(1); end
        % 'vertical' bound
        if fp.cy, dbp(2,3) = dbp(2,3) + js_c*randn(1); end
        if fp.dy, dbp(2,4) = dbp(2,4) + js_d*randn(1); end
    end

    % ---- constraints on the proposal ----
    % means: ordered across levels and within a fixed range
    means_disordered = m(1,3) < m(1,1) || m(1,4) < m(1,2) || ...
        m(2,2) < m(2,1) || m(2,4) < m(2,3);
    means_out_of_range = m(1,2) < -3 || m(1,3) > 5 || m(1,4) > 5 || ...
        m(2,2) > 4 || m(2,3) < -3 || m(2,4) > 5;
    pv(1) = means_disordered || means_out_of_range;
    % covariance must not exceed the product of s.d.s; s.d.s at most 3
    sd_x = sqrt(squeeze(c(1,1,:)));
    sd_y = sqrt(squeeze(c(2,2,:)));
    pv(2) = any(abs(squeeze(c(1,2,:))) > sd_x.*sd_y);
    pv(3) = any(sd_x > 3 | sd_y > 3);
    % bound offset terms must not be positive
    pv(4) = ~is_maxpost && (dbp(1,5) > 0 || dbp(2,5) > 0);

    if pv(1)
        disp('parameter violation: means in disarray')
    elseif pv(2)
        disp('parameter violation: covariance too big')
    elseif pv(3)
        disp('parameter violation: variance too big')
    elseif pv(4)
        disp('parameter violation: bound offset > 0')
    end

    if any(pv)
        % Invalid proposal: give every fit statistic the zero-likelihood
        % sentinel. Previously only LL_new was reset, so bic_new/g2_new kept
        % the last proposal's values and pM was undefined on the first pass.
        LL_new = cf;
        bic_new = -cf;
        g2_new = -cf;
        pM = nan(nstim,nresp);
    else
        % define limits on space, densities situated within
        [dist,x,y] = bivar_norm(m,c,bin);
        % define decision regions
        dec = logical(dec_reg(dbtype,dist,x,y,dbp,nbr));
        % predicted P(response | stimulus): density mass in each region
        pM = zeros(nstim,nresp);
        binsize = bin^2;
        for s = 1:nstim
            pdf = dist(:,:,s);
            for r = 1:nresp
                pM(s,r) = max(binsize*sum( sum( pdf( dec(:,:,r) ) ) ),eps);
            end
        end
        for s = 1:nstim
            pM(s,:) = pM(s,:)/sum(pM(s,:));
        end
        % expected frequencies for the G^2 statistic
        efM = zeros(nstim,nresp);
        for i = 1:nstim
            efM(i,:) = pM(i,:)*Nr_oM(i);
        end
        LL_new = sum(sum(oM.*log(pM)));
        bic_new = -2*LL_new + complexity;
        g2_new = 2*sum(sum( oM.*log(oM./efM) )); % 2*(LL_sat - LL_new)
    end

    % ---- bookkeeping ----
    if n == 0
        % first evaluated proposal becomes the reference; it is not stored in
        % fitps (accept is 0 here, so the state also reverts below)
        bic_min = bic_new;
        bic_old = bic_new;
        n = n + 1;
        best_pM = pM;
        best_fitps = local_store_fit(best_fitps,1,m,c,dbp,is_maxpost,g2_new,bic_new,LL_new);
    else
        if bic_new < bic_old
            accept = 1;
        end
        if bic_new < bic_min
            bic_min = bic_new;
            best_pM = pM;
            best_fitps = local_store_fit(best_fitps,1,m,c,dbp,is_maxpost,g2_new,bic_new,LL_new);
        end
    end

    if accept
        % put the parameters in the fitps struct
        fitps = local_store_fit(fitps,n,m,c,dbp,is_maxpost,g2_new,bic_new,LL_new);
        bic_old = bic_new;
        n = n + 1; % increment the index for the current fit
        n_rejected = 0;
    else
        m = m_old;
        c = c_old;
        dbp = dbp_old;
        feat_params = feat_old;
        n_rejected = n_rejected + 1;
    end

    if n==1 || n==50 || n==100 || n==150
        disp(['complexity = ' num2str(complexity)])
    end
    disp(['bic_min = ' num2str(bic_min)])
    disp(['bic_old = ' num2str(bic_old)])
    disp(['bic_new = ' num2str(bic_new)])
    disp(['n_rejected = ' num2str(n_rejected)])
    disp(' ')
end


function s = local_store_fit(s,k,m,c,dbp,is_maxpost,g2,bic,LL)
% LOCAL_STORE_FIT  Write means, covariances, bound (or maxpost weight)
% parameters and fit statistics into element k of struct (array) s and
% return the updated struct.
s(k).maa = m(:,1);
s(k).mab = m(:,2);
s(k).mba = m(:,3);
s(k).mbb = m(:,4);
s(k).caa = c(:,:,1);
s(k).cab = c(:,:,2);
s(k).cba = c(:,:,3);
s(k).cbb = c(:,:,4);
if is_maxpost
    s(k).waa = dbp(1,1);
    s(k).wab = dbp(1,2);
    s(k).wba = dbp(1,3);
    s(k).wbb = dbp(1,4);
else
    s(k).cx = dbp(1,3);
    s(k).dx = dbp(1,4);
    s(k).cy = dbp(2,3);
    s(k).dy = dbp(2,4);
end
s(k).g2 = g2;
s(k).bic = bic;
s(k).LL = LL;