function x = spm_mireg(VG,VF,params)
% Between modality coregistration using Mutual Information
% FORMAT x = spm_mireg(VG,VF,params)
% VG - handle for first image (see spm_vol).
% VF - handle for second image.
% x - the parameters describing the rigid body transformation (as passed
%     to spm_matrix), such that a mapping from voxels in G to voxels in F
%     is attained by:  VF.mat\spm_matrix(x(:)')*VG.mat
% params - a cell array.
%          params{1} - optimisation sampling steps (mm).  One optimisation
%                      pass is run for each element, in order, each
%                      starting from the result of the previous pass.
%          params{2}(1) - smoothing for VG (FWHM mm)
%          params{2}(2) - smoothing for VF (FWHM mm)
%          params{3} - starting estimates (6 elements), in the same units
%                      as the returned x.
%
% The registration method used here is based on the work described in:
% A Collignon, F Maes, D Delaere, D Vandermeulen, P Suetens & G Marchal
% (1995) "Automated Multi-modality Image Registration Based On
% Information Theory". In the proceedings of Information Processing in
% Medical Imaging (1995).  Y. Bizais et al. (eds.).  Kluwer Academic
% Publishers.
%
% The mutual information is essentially given by:
% H  = H/(sum(H(:))+eps);
% s1 = sum(H,1);
% s2 = sum(H,2);
% H  = H.*log2((H+eps)./(s2*s1+eps));
% mi = sum(H(:));
%
% where H is a 256x256 histogram, and mi is the mutual information.
% As an attempt to improve the convergence properties, the algorithm
% minimises exp(-mi).  This is what gets plotted for each line
% minimisation of the optimisation.
%
% The optimisation has been taken from "Numerical Recipes in C"
% (1992, 2nd Ed.), by WH Press, SA Teukolsky, WT Vetterling &
% BP Flannery.
%
% At the end, the voxel-to-voxel affine transformation matrix is
% displayed, along with the histograms for the images in the original
% orientations, and the final orientations.  The registered images are
% displayed at the bottom.
%_______________________________________________________________________
% @(#)spm_mireg.m	2.1 John Ashburner 99/08/18

if nargin < 1, VG = spm_vol(spm_get(1,'*.img','Select reference image')); end;
if nargin < 2, VF = spm_vol(spm_get(1,'*.img','Select moved image')); end;

if nargin <3, params = cell(0); end;
if length(params)<1 | length(params{1}) < 1,
	params{1} = [3];
	if isglobal('sptl_MIStps'),
		global sptl_MIStps;
		params{1} = sptl_MIStps;
	end;
end;
if length(params)<2 | length(params{2}) ~= 2,
	params{2} = [0 0]';
end;
if length(params)<3 | length(params{3}) ~= 6,
	params{3} = [0 0 0  0 0 0]';
end;

% Load (and optionally smooth) the images as bytes, unless the caller
% already supplied them in a uint8 field.
if ~isfield(VG, 'uint8'),
	VG.uint8 = loaduint8(VG);
	if params{2}(1) ~= 0, VG=smooth_uint8(VG,params{2}(1)); end;
end;
if ~isfield(VF, 'uint8'),
	VF.uint8 = loaduint8(VF);
	if params{2}(2) ~= 0, VF=smooth_uint8(VF,params{2}(2)); end;
end;

% The optimiser works on scaled parameters xs = x./sc (optfun multiplies
% them back by sc), so a unit step changes a translation by 1 and a
% rotation by 0.01.  The state is kept in scaled units across passes and
% only converted to spm_matrix units for display and for the result;
% converting it in place would rescale the rotations again on every
% subsequent pass.
sc = [1 1 1 0.01 0.01 0.01]';
xs = params{3}(:)./sc;
xi = eye(6);

for samp=params{1}(:)',
	% Sampling step in voxels of VG (at least one voxel) for this pass.
	s  = max([1 1 1],round(samp*[1 1 1]./sqrt(sum(VG.mat(1:3,1:3).^2))));
	[xs,fval,xi] = powell(xs(:), xi,1e-4,'optfun',VG,VF,sc,s);
	% [xs,fval] = neldermead(xs,'optfun',VG,VF,sc,s);
	display_results(VG,VF,(xs(:).*sc)');
end;
x = (xs(:).*sc)';
return;
%_______________________________________________________________________

%_______________________________________________________________________
function o = optfun(x,VG,VF,sc,s)
% The function that is minimised.  exp( - mutual_information)
% FORMAT o = optfun(x,VG,VF,sc,s)
% x      - scaled rigid-body parameters; the parameters passed to
%          spm_matrix are x.*sc.
% VG, VF - image handles carrying the byte data in a 'uint8' field.
% sc     - per-parameter scaling (default: ones, i.e. no scaling).
% s      - sampling step in voxels along each of the three dimensions
%          of VG, as passed to spm_hist2 (default [1 1 1]).
% o      - exp(-mi), mi being the mutual information of the joint
%          histogram.  The parameters and mi are also printed.
%
% The defaults must match the roles of the arguments: sc scales every
% element of x, whereas s is a three-element sampling step.
if nargin<4, sc=ones(size(x)); end;
if nargin<5, s=[1 1 1]; end;
x   = x.*sc;
H   = spm_hist2(VG.uint8,VF.uint8,VF.mat\spm_matrix(x(:)')*VG.mat,s);
% krn = exp(-([-4:4].^2)/4); krn = krn/sum(krn);
% H   = conv2(H,krn); H   = conv2(H,krn');
mi  = mifromhist(H);
o   = exp(-mi); % Try to make cost function more quadratic
fprintf('%-8.4g%-8.4g%-8.4g | %-8.4f%-8.4f%-8.4f || %.5g\n',[x(:)' mi]);
return;
%_______________________________________________________________________

%_______________________________________________________________________
function mi = mifromhist(H)
% Mutual information, in bits, of the joint histogram (scatterplot) H.
% H is normalised to a joint probability table P.  With the row marginal
% prow = sum(P,2) and the column marginal pcol = sum(P,1),
%     mi = sum over all cells of  P .* log2( P ./ (prow*pcol) )
% eps is added to the normaliser and inside the logarithm so that empty
% histograms and empty cells do not give 0/0 or log(0).
P     = H/(sum(H(:))+eps);
pcol  = sum(P,1);
prow  = sum(P,2);
terms = P.*log2((P+eps)./(prow*pcol+eps));
mi    = sum(terms(:));
return;
%_______________________________________________________________________

%_______________________________________________________________________
function udat = loaduint8(V)
% Load data from file indicated by V into an array of unsigned bytes.
% The intensity range [mn,mx] of the image is mapped linearly onto the
% byte range 0..255 (rounded to the nearest integer).
% If the volume is stored as uint8 with a single scale factor, the range
% follows directly from V.pinfo.  Otherwise every plane is read once to
% find it.
% NOTE(review): assumes the SPM99-style V.dim(4) datatype code, where 2
% denotes uint8 (see spm_type).  The check is skipped if V.dim has fewer
% than four elements.

% Decide whether the range can be taken from the header.  (The previous
% test, V.pinfo(1)==2, checked the scale factor rather than the datatype.)
fromhdr = 0;
if size(V.pinfo,2)==1 & length(V.dim)>=4,
	fromhdr = (V.dim(4) == 2);
end;

if fromhdr,
	% Stored bytes 0..255 represent intensities pinfo(1)*byte + pinfo(2).
	% Use min/max so that a negative scale factor is handled too.
	lims = [0 255]*V.pinfo(1) + V.pinfo(2);
	mn   = min(lims);
	mx   = max(lims);
else,
	spm_progress_bar('Init',V.dim(3),...
		['Computing max/min of ' spm_str_manip(V.fname,'t')],...
		'Planes complete');
	mx = -Inf; mn =  Inf;
	for p=1:V.dim(3),
		img = spm_slice_vol(V,spm_matrix([0 0 p]),V.dim(1:2),1);
		mx  = max([max(img(:)) mx]);
		mn  = min([min(img(:)) mn]);
		spm_progress_bar('Set',p);
	end;
end;

% Scale factor onto 0..255.  A constant image (mx==mn) maps to 0 instead
% of producing NaN through a division by zero.
if mx > mn, scal = 255/(mx-mn); else, scal = 0; end;

spm_progress_bar('Init',V.dim(3),...
	['Loading ' spm_str_manip(V.fname,'t')],...
	'Planes loaded');

udat = uint8(0);
udat(V.dim(1),V.dim(2),V.dim(3))=0;
for p=1:V.dim(3),
	img = spm_slice_vol(V,spm_matrix([0 0 p]),V.dim(1:2),1);
	% Map mn->0 and mx->255.  The former "+1" offset mapped mx to 256,
	% which saturated into the 255 bin.
	udat(:,:,p) = uint8(round((img-mn)*scal));
	spm_progress_bar('Set',p);
end;
spm_progress_bar('Clear');
return;
%_______________________________________________________________________
%_______________________________________________________________________
function V=smooth_uint8(V,fwhm)
% Smooth the in-memory byte volume V.uint8 with a separable Gaussian.
% fwhm is in mm.  It is converted to a standard deviation in voxels using
% the voxel sizes given by the column norms of V.mat(1:3,1:3).  Each 1-D
% kernel extends to +/- round(6 sd) samples and is normalised to sum to one.
vxsd = sqrt(sum(V.mat(1:3,1:3).^2)).^(-1)*(fwhm/sqrt(8*log(2)));
krn  = cell(1,3);
off  = zeros(1,3);
for d=1:3,
	hw     = round(6*vxsd(d));
	k      = exp(-([-hw:hw]).^2/(2*(vxsd(d)).^2));
	krn{d} = k/sum(k);
	off(d) = (length(krn{d}) - 1)/2;   % kernel half-width (centring offset)
end;
% NOTE(review): V.uint8 is passed as both input and output, so the
% spm_conv_vol MEX presumably overwrites the array in place.
spm_conv_vol(V.uint8,V.uint8,krn{1},krn{2},krn{3},-off);
return;
%_______________________________________________________________________

%_______________________________________________________________________
function [x,fval] = neldermead(x,func,varargin)
% Unused code for doing Nelder-Mead Simplex Optimization.  This was
% based on the Matlab routine "fmins".  I have kept it in because it
% may be useful some time.
% FORMAT [x,fval] = neldermead(x,func,varargin)
% x    - starting estimate.  On return it holds the best vertex found,
%        in the same shape as on input.
% func - function to minimise, called as feval(func,x,varargin{:}).
% fval - function value at the returned x.
% The coefficients are rho (reflection), chi (expansion), psi
% (contraction) and sigma (shrink).  The initial simplex is x plus a unit
% step (delta) along each axis.  Iteration stops when every vertex lies
% within 1e-4 of the best one (max-norm) and all function values are
% within 1e-4 of the best.
% NOTE(review): there is no iteration limit.

n       = length(x);
rho     = 1; chi = 2; psi = 0.5; sigma = 0.5;
onesn   = ones(1,n);
two2np1 = 2:n+1;
one2n   = 1:n;
xin     = x(:);
% Vertices are the columns of v, with their function values in fv.  The
% preallocations are immediately replaced, and both arrays grow as the
% initial simplex is built.
v       = zeros(n,n+1); fv = zeros(1,n+1);
v       = xin;
x(:)    = xin; 
fv      = feval(func,x,varargin{:});

% Initial simplex: the starting point plus a step of delta along each axis.
delta   = 1;
for j = 1:n,
	y    = xin;
	y(j) = y(j)+delta;
	v(:,j+1) = y;
	x(:) = y; fv(1,j+1) = feval(func,x,varargin{:});
end;

% Order the vertices from best (lowest fv) to worst.
[fv,j] = sort(fv);
v      = v(:,j);

while 1,
	% Converged when the simplex and its function values have collapsed
	% onto the best vertex.
	if max(max(abs(v(:,two2np1)-v(:,onesn)))) <= 1e-4 & ...
		max(abs(fv(1)-fv(two2np1))) <= 1e-4,
		break;
	end;
	how = '';
	% Centroid of all vertices except the worst, and the reflected point.
   	xbar   = sum(v(:,one2n), 2)/n;
	xr     = (1 + rho)*xbar - rho*v(:,end);
	x(:)   = xr; fxr = feval(func,x,varargin{:});
   
	if fxr < fv(:,1),
		% Reflection beat the best vertex: try expanding further.
		xe = (1 + rho*chi)*xbar - rho*chi*v(:,end);
		x(:) = xe; fxe = feval(func,x,varargin{:});
		if fxe < fxr,
			v(:,end) = xe;
			fv(:,end) = fxe;
			how = 'expand';
		else,
			v(:,end) = xr; 
			fv(:,end) = fxr;
			how = 'reflect';
		end;
	else,
		if fxr < fv(:,n),
			% Better than the second-worst vertex: accept the reflection.
			v(:,end) = xr; 
			fv(:,end) = fxr;
			how = 'reflect';
		else,
			if fxr < fv(:,end),
				% Contract on the reflected (outside) side.
				xc = (1 + psi*rho)*xbar - psi*rho*v(:,end);
				x(:) = xc; fxc = feval(func,x,varargin{:});
				if fxc <= fxr,
					v(:,end) = xc; 
					fv(:,end) = fxc;
					how = 'contract outside';
				else,
					how = 'shrink';
				end;
			else,
				% Contract towards the worst vertex (inside).
				xcc = (1-psi)*xbar + psi*v(:,end);
				x(:) = xcc; fxcc = feval(func,x,varargin{:});
				if fxcc < fv(:,end),
					v(:,end) = xcc;
					fv(:,end) = fxcc;
					how = 'contract inside';
				else,
					% perform a shrink
					how = 'shrink';
				end;
			end;
			if strcmp(how,'shrink'),
				% Move every vertex but the best towards the best one.
				for j=two2np1
					v(:,j)=v(:,1)+sigma*(v(:,j) - v(:,1));
					x(:) = v(:,j); fv(:,j) = feval(func,x,varargin{:});
				end;
			end;
		end;
	end;
	[fv,j] = sort(fv);
	v = v(:,j);
end;
x(:) = v(:,1);
fval = min(fv);
return;
%_______________________________________________________________________

%_______________________________________________________________________
function [p,fret,xi] = powell(p,xi,ftol,func,varargin)
% Powell optimisation method - taken from Numerical Recipes (p. 417) and
% modified slightly.
% FORMAT [p,fret,xi] = powell(p,xi,ftol,func,varargin)
% p     - starting point.  The point found is returned as a column vector.
% xi    - matrix whose columns are the initial search directions.  The
%         updated direction set is returned.
% ftol  - fractional tolerance.  Stops when one sweep of line minimisations
%         reduces the function by less than ftol relative to its magnitude.
% func  - function to minimise, called as feval(func,p,varargin{:}).
% fret  - function value at the returned p.
% At most ITMAX sweeps are made.  A warning is issued if they are used up.
p=p(:);
ITMAX = 32;
fret  = feval(func,p,varargin{:});
pt    = p;
for iter=1:ITMAX,
	fp   = fret;
	ibig = 0;    % index of the direction giving the largest decrease
	del  = 0.0;  % size of that largest decrease
	for i=1:length(p),
		fptt = fret;
		[p,xit,fret] = linmin(p,xi(:,i),func,varargin{:});
		if abs(fptt-fret) > del,
			del  = abs(fptt-fret);
			ibig = i;
		end;
	end;
	% Fractional-decrease convergence test on the whole sweep.
	if 2.0*abs(fp-fret) <= ftol*(abs(fp)+abs(fret)),
		return;
	end;
	% Extrapolated point and the average direction moved during the sweep.
	ptt  = 2.0*p-pt;
	xit  = p-pt;
	pt   = p;
	fptt = feval(func,ptt,varargin{:});
	if fptt < fp,
		% Powell's criterion for replacing the direction of largest
		% decrease with the average direction.
		t = 2.0*(fp-2.0*fret+fptt)*(fp-fret-del).^2-del*(fp-fptt).^2;
		if t < 0.0,
			[p,xit,fret] = linmin(p,xit,func,varargin{:});
			xi(:,ibig)   = xi(:,end);
			xi(:,end)    = xit;
		end;
	end;
end;
warning('Too many iterations in routine POWELL');
return;
%_______________________________________________________________________

%_______________________________________________________________________
function [p,xi,fret] = linmin(p,xi,func,varargin)
% Code based on Numerical Recipes in C (p. 419)
% Minimise func along the line through p in direction xi.
% On return, xi is the displacement actually taken (direction times the
% optimal step), p has been moved by it, and fret is the function value
% there.  The line search state is published in the global lnm, which
% f1dim reads.
global lnm
lnm = struct('pcom',p,'xicom',xi,'func',func,'args',{varargin});
linmin_plot('Init', 'Line Minimisation','exp(-MutualInformation)','Parameter Value');
% Bracket a minimum starting from steps 0 and 1, then refine it.
[lo,mid,hi,flo,fmid,fhi] = mnbrak(0.0,1.0);
[fret,step] = brent(lo,mid,hi,fmid,2.0e-3);
xi = xi * step;
p  = p + xi;
linmin_plot('Clear');
return;
%_______________________________________________________________________

%_______________________________________________________________________
function f = f1dim(x)
% Code based on Numerical Recipes in C (p. 419)
% One-dimensional objective for the current line search.  It evaluates
% lnm.func at lnm.pcom + x.*lnm.xicom, using the state stored in the
% global lnm by linmin, and adds the point (x,f) to the progress plot.
global lnm
f = feval(lnm.func, lnm.pcom+x.*lnm.xicom, lnm.args{:});
linmin_plot('Set',x,f);
return;
%_______________________________________________________________________

%_______________________________________________________________________
function [ax,bx,cx,fa,fb,fc] = mnbrak(ax,bx)
% Code based on Numerical Recipes in C (p. 400)
% FORMAT [ax,bx,cx,fa,fb,fc] = mnbrak(ax,bx)
% Starting from the points ax and bx, search downhill in f1dim using
% golden-ratio steps and parabolic extrapolation.  Extrapolation is
% limited to GLIMIT times the current interval.  The search ends once
% the value at bx is no larger than at cx.  Returns the triplet
% ax,bx,cx and the f1dim values fa,fb,fc at those points.
GOLD   = 1.618034;
GLIMIT = 100.0;
TINY   = 1.0e-20;

fa=f1dim(ax);
fb=f1dim(bx);

% Ensure the search runs downhill from ax towards bx.
if fb > fa
	dum = ax; ax = bx; bx = dum;
	dum = fb; fb = fa; fa = dum;
end;
cx = bx+GOLD*(bx-ax);
fc = f1dim(cx);
while fb > fc,
	% Parabolic extrapolation through (ax,fa), (bx,fb), (cx,fc).
	r    = (bx-ax)*(fb-fc);
	q    = (bx-cx)*(fb-fa);
	% Denominator 2*SIGN(|q-r|+TINY, q-r).  NR's SIGN treats a zero second
	% argument as positive, whereas MATLAB's sign(0) is 0 and would make
	% the denominator vanish (u = +/-Inf or NaN) when q==r.
	den  = 2.0*(abs(q-r)+TINY);
	if q-r < 0.0, den = -den; end;
	u    = bx-((bx-cx)*q-(bx-ax)*r)/den;
	ulim = bx+GLIMIT*(cx-bx);
	if (bx-u)*(u-cx) > 0.0,
		% Parabolic u lies between bx and cx.
		fu=f1dim(u);
		if fu < fc,
			ax = bx; bx =  u;
			fa = fb; fb = fu;
			return;
		elseif fu > fb,
			cx = u;
			fc = fu;
			return;
		end;
		u  = cx+GOLD*(cx-bx);
		fu = f1dim(u);
	elseif (cx-u)*(u-ulim) > 0.0
		% Parabolic u lies between cx and the allowed limit.
		fu=f1dim(u);
		if fu < fc,
			bx = cx; cx = u; u = cx+GOLD*(cx-bx);
			fb = fc; fc = fu; fu = f1dim(u);
		end;
	elseif (u-ulim)*(ulim-cx) >= 0.0,
		% Limit u to its maximum allowed value.
		u  = ulim;
		fu = f1dim(u);
	else,
		% Reject the parabolic step and use the default golden step.
		u  = cx+GOLD*(cx-bx);
		fu = f1dim(u);
	end;
	ax = bx; bx = cx; cx = u;
	fa = fb; fb = fc; fc = fu;
end;
return;
%_______________________________________________________________________

%_______________________________________________________________________
function [fx, x] = brent(ax,bx,cx,fx, tol)
% Code based on Numerical Recipes in C (p. 404)
% FORMAT [fx, x] = brent(ax,bx,cx,fx,tol)
% Brent's method for minimising f1dim within a bracket.
% ax,bx,cx - bracketing triplet as returned by mnbrak.  The search
%            interval is [min(ax,cx), max(ax,cx)], starting from bx.
% fx       - f1dim(bx), already evaluated by the caller.
% tol      - absolute tolerance on the position of the minimum.
% Returns x, the best abscissa found, and fx = f1dim(x).  A warning is
% issued if ITMAX iterations pass without convergence.
ITMAX = 100;
CGOLD = 0.3819660; % (3-sqrt(5))/2, the golden-section fraction
d = 0.0;
e = 0.0;
a = min(ax,cx);
b = max(ax,cx);
x = bx; w = bx; v = bx;
fw = fx;
fv = fx;
for iter=1:ITMAX,
	xm   = 0.5*(a+b);
	% Minimum step size: a fraction 2e-4 of |x|, plus eps near zero.
	tol1 = 2e-4*abs(x)+eps;
	tol2 = 2.0*tol1;
	% Converged once the bracket [a,b] lies within max(tol,tol2) of x on
	% both sides (the NR test, i.e. max(x-a,b-x) <= tolerance).  Testing
	% only abs(x-xm) <= tol stops as soon as x happens to be near the
	% middle of a still-wide bracket.
	if abs(x-xm) <= max(tol,tol2)-0.5*(b-a),
		return;
	end;
	if abs(e) > tol1,
		% Trial parabolic fit through x, v and w.
		r     = (x-w)*(fx-fv);
		q     = (x-v)*(fx-fw);
		p     = (x-v)*q-(x-w)*r;
		q     = 2.0*(q-r);
		if q > 0.0, p = -p; end;
		q     = abs(q);
		etemp = e;
		e     = d;
		if abs(p) >= abs(0.5*q*etemp) | p <= q*(a-x) | p >= q*(b-x),
			% Parabolic step unacceptable: take a golden-section step.
			if x >= xm, e = a-x; else, e = b-x; end;
			d = CGOLD*(e);
		else,
			d = p/q;
			u = x+d;
			if u-a < tol2 | b-u < tol2,
				d = tol1*sign(xm-x);
			end;
		end;
	else,
		% Golden-section step into the larger segment.
		if x>=xm, e = a-x; else, e = b-x; end;
		d = CGOLD*e;
	end;
	% Never evaluate closer to x than tol1.
	if abs(d) >= tol1, u = x+d; else, u = x+tol1*sign(d); end;
	fu=f1dim(u);
	% Shrink the bracket and update the best (x), second best (w) and
	% previous second best (v) points.
	if fu <= fx,
		if u >= x, a=x; else, b=x; end;
		 v =  w;  w =  x;  x =  u;
		fv = fw; fw = fx; fx = fu;
	else,
		if u < x, a=u; else, b=u; end;
		if fu <= fw | w == x,
			 v  = w;  w =  u;
			fv = fw; fw = fu;
		elseif fu <= fv | v == x | v == w,
			 v =  u;
			fv = fu;
		end;
	end;
end;
warning('Too many iterations in BRENT');
return;
%_______________________________________________________________________

%_______________________________________________________________________
function linmin_plot(action,arg1,arg2,arg3,arg4)
% Visual output for line minimisation, drawn in SPM's 'Interactive'
% window (nothing is drawn if that window cannot be found).
%   linmin_plot('Init',title,ylabel,xlabel) - create axes holding an empty,
%                                             marker-only line
%   linmin_plot('Set',x,f)                  - append the point (x,f)
%   linmin_plot('Clear')                    - delete the axes, restore the
%                                             window pointer and name, and
%                                             clear the window
% With no arguments it behaves as linmin_plot('Init').
global linminplot
%-----------------------------------------------------------------------
if nargin == 0,
	linmin_plot('Init');
	return;
end;

switch lower(action),

% initialize
%-----------------------------------------------------------------------
case 'init',
	if nargin<2, arg1 = 'Line minimisation'; end;
	if nargin<3, arg2 = 'Value';             end;
	if nargin<4, arg3 = 'Function';          end;
	fg = spm_figure('FindWin','Interactive');
	if isempty(fg), return; end;
	% Remember the window state so that 'Clear' can restore it.
	linminplot = struct('pointer',get(fg,'Pointer'),'name',get(fg,'Name'),'ax',[]);
	linmin_plot('Clear');
	set(fg,'Pointer','watch');
	% set(fg,'Name',arg1);
	linminplot.ax = axes('Position', [0.15 0.1 0.8 0.75],...
		'Box', 'on','Parent',fg);
	set(get(linminplot.ax,'Xlabel'),'string',arg3,'FontSize',10);
	set(get(linminplot.ax,'Ylabel'),'string',arg2,'FontSize',10);
	set(get(linminplot.ax,'Title'),'string',arg1);
	line('Xdata',[], 'Ydata',[],...
		'LineWidth',2,'Tag','LinMinPlot','Parent',linminplot.ax,...
		'LineStyle','none','Marker','o');
	drawnow;

% add a point
%-----------------------------------------------------------------------
case 'set',
	hl = findobj(spm_figure('FindWin','Interactive'),'Tag','LinMinPlot');
	if ~isempty(hl),
		newx = [get(hl,'Xdata') arg1];
		newy = [get(hl,'Ydata') arg2];
		set(hl,'Ydata',newy,'Xdata',newx);
		drawnow;
	end;

% clear
%-----------------------------------------------------------------------
case 'clear',
	fg = spm_figure('FindWin','Interactive');
	if isstruct(linminplot),
		if ishandle(linminplot.ax), delete(linminplot.ax); end;
		set(fg,'Pointer',linminplot.pointer);
		set(fg,'Name',linminplot.name);
	end;
	spm_figure('Clear',fg);
	drawnow;
end;
%_______________________________________________________________________

%_______________________________________________________________________
function display_results(VG,VF,x)
% Show the coregistration result in SPM's 'Graphics' window, if it is
% open.  The display contains:
%   - the voxel-to-voxel mapping as text;
%   - log-scaled joint histograms before and after applying x;
%   - orthogonal views of VG and VF, with VF displayed through the
%     estimated transformation.
% x - rigid-body parameters as passed to spm_matrix.
fig = spm_figure('FindWin','Graphics');
if isempty(fig), return; end;
set(0,'CurrentFigure',fig);
spm_figure('Clear','Graphics');

% Display text
%-----------------------------------------------------------------------
ax = axes('Position',[0.1 0.8 0.8 0.15],'Visible','off','Parent',fig);
text(0.5,0.7, 'Mutual Information Coregistration','FontSize',16,...
	'FontWeight','Bold','HorizontalAlignment','center','Parent',ax);

% Q is the inverse of the G-to-F voxel mapping
% VF.mat\spm_matrix(x)*VG.mat, i.e. it maps voxels of F to voxels of G.
Q = inv(VF.mat\spm_matrix(x(:)')*VG.mat);
text(0,0.5, sprintf('X1 = %0.3f*X %+0.3f*Y %+0.3f*Z %+0.3f',Q(1,:)),'Parent',ax);
text(0,0.3, sprintf('Y1 = %0.3f*X %+0.3f*Y %+0.3f*Z %+0.3f',Q(2,:)),'Parent',ax);
text(0,0.1, sprintf('Z1 = %0.3f*X %+0.3f*Y %+0.3f*Z %+0.3f',Q(3,:)),'Parent',ax);

% Display scatter-plots (joint histograms) before and after registration
%-----------------------------------------------------------------------
ax  = axes('Position',[0.1 0.5 0.35 0.3],'Visible','off','Parent',fig);
H   = spm_hist2(VG.uint8,VF.uint8,VF.mat\VG.mat,[1 1 1]);
show_histogram(ax,H,'Original Histogram',VG,VF);

H   = spm_hist2(VG.uint8,VF.uint8,VF.mat\spm_matrix(x(:)')*VG.mat,[1 1 1]);
ax  = axes('Position',[0.6 0.5 0.35 0.3],'Visible','off','Parent',fig);
show_histogram(ax,H,'Final Histogram',VG,VF);

% Display ortho-views
%-----------------------------------------------------------------------
spm_orthviews('Reset');
h1 = spm_orthviews('Image',VG.fname,[0.01 0.01 .48 .49]);
h2 = spm_orthviews('Image',VF.fname,[.51 0.01 .48 .49]);
global st
% Display VF through the inverse of the estimated rigid-body matrix.
st.vols{h2}.premul = inv(spm_matrix(x(:)'));
spm_orthviews('Space',h1);

return;
%_______________________________________________________________________

%_______________________________________________________________________
function show_histogram(ax,H,ttl,VG,VF)
% Draw log(H+1) in axes ax as an image scaled to 64 colour levels, with
% title ttl and axis labels taken from the image file names.
tmp = log(H+1);
mx  = max(tmp(:));
if mx > 0, tmp = tmp*(64/mx); end; % an all-zero histogram stays 0, not NaN
image(tmp,'Parent',ax);
set(ax,'DataAspectRatio',[1 1 1],...
	'PlotBoxAspectRatioMode','auto','XDir','normal','YDir','normal',...
	'XTick',[],'YTick',[]);
title(ttl,'Parent',ax);
xlabel(spm_str_manip(VG.fname,'k22'),'Parent',ax);
ylabel(spm_str_manip(VF.fname,'k22'),'Parent',ax);
return;
