Source code for omni.register

"""Functions for image registration.
"""
import os
import signal
import warnings
import tempfile
import multiprocessing as mp
from typing import Tuple
import numpy as np
import nibabel as nib
from scipy.stats import multivariate_normal as mvn
from scipy.stats import beta
from scipy.optimize import least_squares
from scipy.signal.windows import hamming
from omni import io, preprocessing
from omni.resample import resample
from omni.interface import OmniInterface
from omni.interfaces.common import normalize

# set multiprocessing to spawn due to problematic i/o with fork
try:
    mp.set_start_method("spawn")
except Exception:  # pylint: disable=broad-except
    pass


[docs]def grab_slice_encoding_dir(img: nib.Nifti1Image): """Grabs the slice encoding direction axis from the Nifti image.""" try: return int(bin(img.header.get("dim_info"))[-6:-4], 2) - 1 except ValueError: print("Could not parse the dim info of the NIFTI file. " "Assuming SED is 2...") return 2
[docs]def realign4d( img: nib.Nifti1Image, # pylint: disable=dangerous-default-value ref_idx: int, TR: float, slice_times: list, slice_encode_dir: int, loops: list = [1, 1, 1], speedup: list = [5, 3, 1], borders: list = [1, 1, 1], ) -> Tuple[np.ndarray, nib.Nifti1Image]: """Wrapper for SpaceTimeRealign.""" from nipy.algorithms.registration.groupwise_registration import SpaceTimeRealign # initialize algorithm algo = SpaceTimeRealign(img, TR, slice_times, slice_encode_dir) # estimates parameters print("Initializing estimator...") algo.estimate(loops=loops, align_runs=False, speedup=speedup, refscan=ref_idx, borders=borders) # get rigid body params (undoes the precond applied to param) params = np.array( [i.param * i.precond[i.param_inds] for i in algo._transforms[0]] ) # pylint: disable=protected-access # get resampled image (convert to Nifti1Image) # (leave out 4th time axis in affine) print("Resampling image...") resampled = algo.resample() output = nib.Nifti1Image(resampled[0].get_data(), resampled[0].affine[[0, 1, 2, 4]][:, [0, 1, 2, 4]]) # return rigid body parameters and image output return (params, output)
# apply tone curve to outputs def tone_curve_adj(target: nib.Nifti1Image, synth: nib.Nifti1Image, mask: nib.Nifti1Image = None) -> nib.Nifti1Image: print("Running tone curve adjustment...") # load target target_nm = normalize(target.get_fdata()) # load params on synth def adj_outputs(params): return beta.cdf(normalize(synth.get_fdata()), params[0], params[1]) # mask data if mask: mask_data = (mask.get_fdata().ravel() > 0) * mask.get_fdata().ravel() else: mask_data = np.ones(target_nm.ravel().shape) # do least squares with bounds target_data = target_nm.ravel() * mask_data # create optimizable parameter def opt_adj_outputs(x): return adj_outputs(x).ravel() * mask_data # run least squares result = least_squares( lambda x: target_data - opt_adj_outputs(x), (1, 1), bounds=([0, 0], [np.inf, np.inf]), method="dogbox", tr_solver="exact", verbose=2, ) # return tone adjusted image return nib.Nifti1Image(adj_outputs(result.x), synth.affine, synth.header)
[docs]def register_main( target, interaction, source, regress_mask=None, target_mask=None, source_mask=None, initial_affine=None, output=None, regressed_output=None, blurred_regressed_output=None, aligned_output=None, err_tol=1e-4, step_size=1e-3, max_iterations=200, num_params=12, optimizer="gd", saturate=False, equalize=False, fixed_regress=False, no_register=False, output_intmat=False, sigma=0, bandwidth=0, tone_curve=False, tone_curve_mask=None, ): """Main SynthTarget register routine. TODO: This is a mess of a function... Need to clean it up. """ # check if at least one output argument has been defined if not any([output, regressed_output, blurred_regressed_output, aligned_output]): warnings.warn("No output arguments defined... " "this run will not output any files!") # check number of source images (we can only handle up to 10 at the moment) if len(source) > 10: raise ValueError("Source argument cannot handle more than 10 arguments.") # start up interface oi = OmniInterface() # load/set images target_img = nib.load(target) oi.set_img("target", target_img) # we need to keep alive references to source images for C++ backend s_imgs = list() s_names = list() # keep list of source names # create tempdir for preprocessing images temppath = tempfile.TemporaryDirectory() for n, s in enumerate(source): # read in source images print("Loading Image: {}".format(os.path.basename(s))) s_img = nib.load(s) # load image if saturate: # check normalize option print("Saturating top/bottom 1% of image voxels...") s_img2 = preprocessing.saturate(s_img) # we have to save/load image to file, else C++ backend screws up s_img2.to_filename(os.path.join(temppath.name, "saturate{}.nii.gz".format(n))) s_img = nib.load(os.path.join(temppath.name, "saturate{}.nii.gz".format(n))) if equalize: print("Equalizing image histogram...") s_img2 = preprocessing.equalize(s_img) # we have to save/load image to file, else C++ backend screws up s_img2.to_filename(os.path.join(temppath.name, "equalize{}.nii.gz".format(n))) s_img = nib.load(os.path.join(temppath.name, "equalize{}.nii.gz".format(n))) s_imgs.append(s_img) # store image in array s_names.append("source{}".format(n)) # save image name oi.set_img("source{}".format(n), s_imgs[-1]) # set image in interface # load set masks if availiable if target_mask: target_mask_img = nib.load(target_mask) target_mask_name = "target_mask" abs_target_mask_data = np.abs(target_mask_img.get_fdata()) abs_target_mask_img = nib.Nifti1Image(abs_target_mask_data, target_mask_img.affine, target_mask_img.header) oi.set_img(target_mask_name, abs_target_mask_img) else: target_mask_name = "_NO_MASK_" if source_mask: source_mask_img = nib.load(source_mask) source_mask_name = "source_mask" oi.set_img(source_mask_name, source_mask_img) else: source_mask_name = "_NO_MASK_" if regress_mask: regress_mask_img = nib.load(regress_mask) regress_mask_name = "regress_mask" oi.set_img(regress_mask_name, regress_mask_img) else: regress_mask_name = "_NO_MASK_" # load the interaction model _, ext = os.path.splitext(interaction) if ext == ".model": spec = io.load_interaction_spec(interaction) else: spec = interaction # load initial affine if exists, otherwise initialize to identity if initial_affine: initial_affine = io.read_affine_file(initial_affine)[0] else: initial_affine = np.eye(4) # load fixed regress if exists if fixed_regress: fixed_regress = np.loadtxt(fixed_regress) else: fixed_regress = np.array([0]) # create blurring kernel if specified x = np.mgrid[-10:11, -10:11, -10:11].T if sigma != 0: use_kernel = True cov = np.eye(3) * sigma kernel = mvn.pdf(x, mean=[0, 0, 0], cov=cov) elif bandwidth != 0: use_kernel = True # define epanechnikov kernel def epanechnikov(x, h): e = (1 - (1 / h) * (x[:, :, :, 0] ** 2 + x[:, :, :, 1] ** 2 + x[:, :, :, 2] ** 2) > 0) * ( 1 - (1 / h) * (x[:, :, :, 0] ** 2 + x[:, :, :, 1] ** 2 + x[:, :, :, 2] ** 2) ) return e / np.sum(e.ravel()) # get kernel kernel = epanechnikov(x, bandwidth) else: use_kernel = False kernel = None # run registration routine affine_mat, regress_params = oi.link.registrate( "target", s_names, "regressed", "blurred_regressed", "aligned", spec, initial_affine, fixed_regress, optimizer, transcode=num_params, target_mask=target_mask_name, source_mask=source_mask_name, regress_mask=regress_mask_name, err_tol=err_tol, step_size=step_size, iterations=max_iterations, no_register=no_register, output_intmat=output_intmat, use_kernel=use_kernel, blur_kernel=kernel, ) # save parameters outputs if output: print("Writing parameters...") # save affine parameters io.write_omni_affine(output, affine_mat) # save regression parameters io.write_regress_params(output, regress_params) # get the tone curve mask img if tone_curve_mask: tone_curve_mask_img = nib.load(tone_curve_mask) else: tone_curve_mask_img = None # get the output if regressed_output: print("Writing regressed image...") regressed_img = oi.get_img("regressed", ref=s_imgs[0]) regressed_img.to_filename(regressed_output) # get the blurred output if blurred_regressed_output: print("Writing synthetic image...") blurred_regressed_img = oi.get_img("blurred_regressed", ref=s_imgs[0]) blurred_regressed_img.to_filename(blurred_regressed_output) # get the aligned image if aligned_output: print("Writing synthetic target image...") aligned_img = oi.get_img("aligned", ref=target_img) if tone_curve: # do tone curve adj if enabled if tone_curve_mask_img: aligned_tone_curve_mask_img = resample(aligned_img, tone_curve_mask_img, affine_mat) else: aligned_tone_curve_mask_img = None aligned_img = tone_curve_adj(target_img, aligned_img, aligned_tone_curve_mask_img) aligned_img.to_filename(aligned_output) # clean up temp path temppath.cleanup()
[docs]def register(*args, **kwargs): """There is some kind of memory leak issue in the C++ backend that I can not seem to find... As a workaround, this just starts the registration routine in a separate python process so that it gets cleaned up after it's done. This also sends SIGKILL to the job if the user inputs Ctrl+C to the main process. TODO: Gracefully handle SIGINT calls """ p = mp.Process(target=register_main, args=args, kwargs=kwargs) try: p.start() p.join() except KeyboardInterrupt as e: os.kill(p.pid, signal.SIGKILL) p.join() raise KeyboardInterrupt from e
[docs]def phase_corr_translate(target: nib.Nifti1Image, source: nib.Nifti1Image) -> np.ndarray: """Does phase correlation to register images (Translation only).""" # get data arrays target_data = target.get_fdata() source_data = source.get_fdata() # get resolution res = np.abs(target.affine[0, 0]) # make sure they are the same shape assert ( target_data.shape == source_data.shape ), "Target/Source \ must be same shape!" # create hamming window on data win_x = hamming(target_data.shape[0]) win_y = hamming(target_data.shape[1]) win_z = hamming(target_data.shape[2]) wins = np.meshgrid(win_x, win_y, win_z, indexing="ij") win = np.multiply.reduce(wins) # fft them target_fft = np.fft.fftn(target_data * win) source_fft = np.fft.fftn(source_data * win) # to prevent divide by zero errors, set 0 to a very small value target_fft[target_fft == 0] = 1e-12 source_fft[source_fft == 0] = 1e-12 # take conjugate of the source source_fft_conj = np.conj(source_fft) # now find the normalized cross correlation R = target_fft * source_fft_conj / np.abs(target_fft * source_fft_conj) # inverse R r = np.real(np.fft.ifftn(R)) # get translation t = np.unravel_index(np.argmax(r), target_data.shape) translate = [res * t[0], res * t[1], res * t[2]] # get translation in correct FOV bound_a = np.dot(target.affine, np.array([0, 0, 0, 1])) bound_b = np.dot(target.affine, np.array([target.shape[0] - 1, target.shape[1] - 1, target.shape[2] - 1, 1])) ordered_bounds = [np.sort([bound_a[i], bound_b[i]]) for i in range(3)] for i in range(3): while translate[i] < ordered_bounds[i][0]: translate[i] += res * target.shape[i] while translate[i] > ordered_bounds[i][1]: translate[i] -= res * target.shape[i] # return as omni affine return np.array([[1, 0, 0, translate[0]], [0, 1, 0, translate[1]], [0, 0, 1, translate[2]], [0, 0, 0, 1]])