import os
from pathlib import Path
import logging
import shutil
from threading import local
from typing import List
import nibabel as nib
from scipy.ndimage import gaussian_filter
from memori.pathman import append_suffix, repath
from memori.helpers import create_output_path, create_symlink_to_path
from omni.interfaces.common import create_fov_masks
from omni.interfaces.afni import Allineate, NwarpApply, NwarpCat, resample
from omni.interfaces.fsl import flirt
from omni.interfaces.ants import antsRegistration, antsApplyTransform
from omni.preprocessing import clahe, localized_contrast_enhance
from omni.io import convert_affine_file
from omni.masks import make_regression_mask
from omni.register import register
# Directory this file lives in
THISDIR = os.path.dirname(os.path.abspath(__file__))
# Get atlas directory
ATLASDIR = os.path.join(os.path.dirname(THISDIR), "atlas")
[docs]@create_output_path
def align_affine_epi_to_anat(
output_path: str,
ref_epi_bet: str,
initial_synth_model: str = "rbf(0;4)+rbf(1;4)+rbf(0;4)*rbf(1;4)",
t1_bet: str = None,
t2_bet: str = None,
program: str = "fsl",
bandwidth: int = 16,
initial_affine: str = None,
skip_affine: bool = False,
skip_synthtarget_affine: bool = False,
resolution_pyramid: List[float] = [4, 2, 1],
max_iterations: List[int] = [2000, 500, 100],
err_tol: List[float] = [1e-4, 1e-4, 5e-4],
step_size: List[float] = [1e-3, 1e-3, 1e-3],
):
"""Affine align func to anat
Parameters
----------
output_path: str
Output path to write out files to.
ref_epi_bet: str
Reference EPI skullstripped.
initial_synth_model: str
Synth model.
t1_bet: str
T1 skullstripped.
t2_bet: str
T2 skullstripped.
bandwidth: int
Bandwidth for blurring kernel.
program: str
Program to use for affine alignment.
initial_affine: str
Initial affine to use instead of using internal flirt alignment.
skip_affine: bool
Skip affine alignment step.
skip_synthtarget_affine: bool
Skip synthtarget affine alignment step.
resolution_pyramid: List[float]
Resampling pyramid to use for affine alignment (mm).
max_iterations: List[int]
Max iterations for each SynthTarget call.
err_tol: List[float]
Error tolerance level for each SynthTarget call.
step_size: List[float]
Step size for gradient descent.
Returns
-------
str
Final func to anat affine (afni).
str
Final anat to func affine (afni).
"""
# resample ref_func_bet images to specified image pyramid
ref_epi_bet_res = list()
for res in resolution_pyramid:
res_str = ("%s" % res).replace(".", "d")
ref_epi_bet_res.append(append_suffix(repath(output_path, ref_epi_bet), "_%smm" % res_str))
resample(ref_epi_bet_res[-1], ref_epi_bet, ref_epi_bet, dxyz=res)
# resample anat images to specified image pyramid
t1_bet_res = list()
t2_bet_res = list()
for res in resolution_pyramid:
res_str = ("%s" % res).replace(".", "d")
t1_bet_res.append(append_suffix(repath(output_path, t1_bet), "_%smm" % res_str))
t2_bet_res.append(append_suffix(repath(output_path, t2_bet), "_%smm" % res_str))
resample(t1_bet_res[-1], t1_bet, t1_bet, dxyz=res)
resample(t2_bet_res[-1], t2_bet, t2_bet, dxyz=res)
# do not do initial affine alignment call if skip enabled
# or if initial affine is already provided
if not skip_affine and initial_affine is None:
# create a symlinks to func and T2 for reference
t2_bet_link = create_symlink_to_path(t2_bet, output_path)
ref_bet_link = create_symlink_to_path(ref_epi_bet, output_path)
# get an initial affine transform guess
initial_func_to_t2_xfm = os.path.join(output_path, "initial_func_to_t2.aff12.1D")
initial_func_to_t2 = os.path.join(output_path, "initial_func_to_t2.nii.gz")
if program == "afni":
Allineate(
initial_func_to_t2,
t2_bet_link,
ref_bet_link,
cost="mi",
warp="shift_rotate",
matrix_save=initial_func_to_t2_xfm,
twopass=True,
)
elif program == "fsl": # Use flirt
initial_func_to_t2_flirt = os.path.join(output_path, "initial_func_to_t2.mat")
flirt(
initial_func_to_t2,
t2_bet_link,
ref_bet_link,
out_matrix=initial_func_to_t2_flirt,
dof=6,
cost="corratio",
)
convert_affine_file(
initial_func_to_t2_xfm, initial_func_to_t2_flirt, "afni", target=t2_bet_link, source=ref_bet_link
)
else:
raise ValueError("Invalid parameter set for program. " "Must be either 'afni' or 'fsl'.")
# convert affine file to omni format and invert
initial_t2_to_func_xfm = os.path.join(output_path, "initial_t2_to_func.affine")
convert_affine_file(initial_t2_to_func_xfm, initial_func_to_t2_xfm, "omni", invert=True)
# initialize the initial affine to use
initial_affine = initial_t2_to_func_xfm
elif skip_affine and initial_affine is None: # if initial affine was not defined, but skip enabled
initial_affine = None
# get the resolution of the final resolution pyramid
anat_res = resolution_pyramid[-1]
# run SynthTarget
output_pathobj = Path(output_path)
synthtarget_params_res = list()
synthtarget_affine_res = list()
synthtarget_synthetic_res = list()
for i, res in enumerate(resolution_pyramid):
res_str = ("%s" % res).replace(".", "d")
synthtarget_params_res.append(str(output_pathobj / ("synthtarget_params_%smm" % res_str)))
synthtarget_affine_res.append(str(output_pathobj / ("synthtarget_params_%smm.affine" % res_str)))
synthtarget_synthetic_res.append(str(output_pathobj / ("synthtarget_%smm.nii.gz" % res_str)))
register( # SynthTarget
ref_epi_bet_res[i],
initial_synth_model,
[t1_bet_res[i], t2_bet_res[i]],
output=synthtarget_params_res[i],
aligned_output=synthtarget_synthetic_res[i],
initial_affine=initial_affine,
max_iterations=max_iterations[i],
err_tol=err_tol[i],
step_size=step_size[i],
bandwidth=bandwidth / (res / anat_res),
no_register=skip_affine or skip_synthtarget_affine,
)
# set next initial affine
initial_affine = synthtarget_affine_res[i]
# convert affine to afni
final_epi_to_anat_affine = str(output_pathobj / "final_epi_to_anat.aff12.1D")
final_anat_to_epi_affine = str(output_pathobj / "final_anat_to_epi.aff12.1D")
convert_affine_file(final_epi_to_anat_affine, synthtarget_affine_res[-1], "afni", invert=True)
convert_affine_file(final_anat_to_epi_affine, synthtarget_affine_res[-1], "afni", invert=False)
# return affine
return (final_epi_to_anat_affine, final_anat_to_epi_affine)
[docs]@create_output_path
def distortion_correction(
output_path: str,
epi_moco: str,
ref_epi: str,
t1: str,
t2: str,
anat_bet_mask: str,
anat_weight_mask: str,
final_anat_to_epi_affine: str,
final_epi_to_anat_affine: str,
final_synth_model: str = "rbf(0;12)+rbf(1;12)+rbf(0;12)*rbf(1;12)",
bandwidth: int = 16,
resample_resolution: float = 1,
sigma_t2: float = 0.5,
use_initializing_warp: bool = False,
skip_synth_warp: bool = False,
contrast_enhance_method: str = "lce",
initial_warp_field: str = None,
SyN_step_size: List[float] = [3, 1, 0.1],
SyN_smoothing: str = "2x1x0x0",
SyN_shrink_factors: str = "4x3x2x1",
SyN_field_variance: int = 0,
noise_mask_dilation_size: int = 2,
noise_mask_iterations: int = 20,
noise_mask_sigma: float = 2,
warp_direction: str = "none",
autobox_mask: str = None,
):
"""Distortion correction.
Parameters
----------
output_path: str
Output path to write out files to.
epi_moco : str
EPI image that has been motion corrected/framewise aligned.
ref_epi: str
Reference EPI image.
t1: str
T1.
t2: str
T2.
anat_bet_mask: str
Anatomical brain mask.
anat_weight_mask: str
Anatomical weight mask.
final_anat_to_epi_affine: str
Final anat to epi affine (afni).
final_epi_to_anat_affine: str
Final epi to anat affine (afni).
final_synth_model: str
Synth model.
bandwidth: int
Bandwidth for blurring kernel.
resample_resolution: float
Resample resolution space to do warps on (mm).
sigma_t2: float
Parameter to smooth T2 for initial warp.
use_initializing_warp: bool
Uses initializing warp from the T2 initial warp.
skip_synth_warp: bool
Skip the synth warp step, will only use initial T2 warp solution.
contrast_enhance_method: str
Contrast enhancement method to use, by default "lce".
initial_warp_field: str
Uses this file as an initial warp field instead of computing it, this should be from ref_epi -> T2.
SyN_step_size : List[float]
Set the gradient descent step size for each iteration of warp.
SyN_smoothing: str
Smoothing kernel size for each level of optimization.
SyN_shrink_factors: str
Resampling factor for each level of optimization.
SyN_field_variance: int
Field variance for SyN.
noise_mask_dilation_size : int
Dilation size for noise mask.
noise_mask_iterations: int
Number of iterations to run noise mask LDA.
noise_mask_sigma: float
Size of gaussian smoothing kernel for noise mask.
warp_direction: str
Warp direction
autobox_mask: int
mask defining the autobox
Returns
-------
str
Final synthetic to EPI warp (Distort synthetic to match EPI)
str
Final EPI to synthetic warp (Undistort EPI to match synthetic)
"""
# get warp direction
restrict_deform = {"x": "1x0x0", "y": "0x1x0", "z": "0x0x1", "none": "1x1x1"}[warp_direction]
# create resample resolution string
res_str = ("%s" % resample_resolution).replace(".", "d")
# resample images
ref_epi_res = append_suffix(repath(output_path, ref_epi), "_%smm" % res_str)
t1_res = append_suffix(repath(output_path, t1), "_%smm" % res_str)
t2_res = append_suffix(repath(output_path, t2), "_%smm" % res_str)
resample(ref_epi_res, ref_epi, ref_epi, dxyz=resample_resolution)
resample(t1_res, t1, t1, dxyz=resample_resolution)
resample(t2_res, t2, t2, dxyz=resample_resolution)
# run lce on ref epi
ref_epi_res_img = nib.load(ref_epi_res)
if contrast_enhance_method == "clahe":
ref_epi_lce_res_img = clahe(ref_epi_res_img)
ref_epi_lce_res = append_suffix(repath(output_path, ref_epi_res), "_clahe")
elif contrast_enhance_method == "lce":
anat_bet_mask_epispace = append_suffix(repath(output_path, anat_bet_mask), "_epispace")
Allineate(anat_bet_mask_epispace, ref_epi_res, anat_bet_mask, matrix_apply=final_anat_to_epi_affine)
anat_bet_mask_epispace_img = nib.load(anat_bet_mask_epispace)
ref_epi_lce_res_img = localized_contrast_enhance(ref_epi_res_img, anat_bet_mask_epispace_img, nfrac=0.005)
ref_epi_lce_res = append_suffix(repath(output_path, ref_epi_res), "_lce")
ref_epi_lce_res_img.to_filename(ref_epi_lce_res)
# get t2 in space of epi
t2_res_epispace = append_suffix(t2_res, "_epispace")
Allineate(t2_res_epispace, ref_epi_lce_res, t2_res, matrix_apply=final_anat_to_epi_affine)
# smooth the t2
t2_res_epispace_img = nib.load(t2_res_epispace)
t2_res_smooth_data = gaussian_filter(t2_res_epispace_img.get_fdata(), sigma_t2)
t2_res_epispace_smooth = append_suffix(t2_res_epispace, "_smooth")
nib.Nifti1Image(t2_res_smooth_data, t2_res_epispace_img.affine, t2_res_epispace_img.header).to_filename(
t2_res_epispace_smooth
)
# create_fov_masks
ref_fov_mask = os.path.join(output_path, "ref_fov_mask.nii.gz")
source_fov_mask = os.path.join(output_path, "source_fov_mask.nii.gz")
create_fov_masks(ref_epi_lce_res, t2_res, final_anat_to_epi_affine, ref_fov_mask, source_fov_mask)
# if an autobox was used on functional, then we need to account for it
if autobox_mask:
# resample to the ref_fov_mask
resample(ref_fov_mask, ref_epi_lce_res, autobox_mask, rmode="NN")
# setup initial warp names
initial_warp_prefix = os.path.join(output_path, "initial_")
initial_0warp = initial_warp_prefix + "0Warp.nii.gz"
initial_0iwarp = initial_warp_prefix + "0InverseWarp.nii.gz"
# if an initial warp field is provided, use it
if initial_warp_field:
# resample the initial warp field to the working resolution
resample(initial_0iwarp, ref_epi_lce_res, initial_warp_field)
# get the inverse of the iwarp to get the warp
NwarpCat(initial_0warp, initial_0iwarp, True)
else: # get the initial warps running a crude nonlinear SyN registration
# run initial ants warp
antsRegistration(
initial_warp_prefix,
ref_epi_lce_res,
t2_res_epispace_smooth,
grad_step=SyN_step_size[0], # always use initial step size of parameter selection
rad_bin=2, # radius of CC metric TODO: make this a parameter on the front-end
update_field_var=SyN_field_variance,
total_field_var=SyN_field_variance,
convergence="".join(["500x" for i in SyN_shrink_factors.split("x")[:-1]]) + "0",
smoothing=SyN_smoothing,
shrink_factors=SyN_shrink_factors,
threshold_window=20,
restrict_deform=restrict_deform,
reference_mask=ref_fov_mask,
source_mask=source_fov_mask,
)
# apply warps to images for reference
t2_res_epispace_warped = append_suffix(t2_res_epispace, "_initialwarped")
NwarpApply(t2_res_epispace_warped, ref_epi_lce_res, t2_res_epispace_smooth, initial_0warp)
ref_epi_lce_res_unwarped = append_suffix(ref_epi_lce_res, "_initialunwarped")
NwarpApply(ref_epi_lce_res_unwarped, t2_res_epispace_smooth, ref_epi_lce_res, initial_0iwarp)
# rename the warps from their temp name
initial_warp = initial_warp_prefix + "Warp.nii.gz"
initial_iwarp = initial_warp_prefix + "InverseWarp.nii.gz"
shutil.move(initial_0warp, initial_warp)
shutil.move(initial_0iwarp, initial_iwarp)
# ensure warps have vector intent code for ANTs
warp_img = nib.load(initial_warp)
warp_img.header.set_intent("vector")
warp_img.to_filename(initial_warp)
iwarp_img = nib.load(initial_iwarp)
iwarp_img.header.set_intent("vector")
iwarp_img.to_filename(initial_iwarp)
# skip synth warp refinement if flag set
if skip_synth_warp:
# resample the forward warp to the EPI resolution
ref_epi_img = nib.load(ref_epi) # resolution of EPI
# assumes voxels are isotropic
data_resolution = ref_epi_img.header.get_zooms()[0]
final_synth_to_epi_warp = os.path.join(output_path, "final_synth_to_epi_warp.nii.gz")
resample(final_synth_to_epi_warp, initial_warp, initial_warp, dxyz=data_resolution)
# get inverted final warp
final_epi_to_synth_warp = os.path.join(output_path, "final_epi_to_synth_warp.nii.gz")
NwarpCat(final_epi_to_synth_warp, final_synth_to_epi_warp, True)
# return warps
return (final_synth_to_epi_warp, final_epi_to_synth_warp)
# get initial regression mask
initial_prefix = os.path.join(output_path, "iteration0_")
regression_mask = make_regression_mask(
initial_prefix,
epi_moco,
anat_bet_mask,
anat_weight_mask,
final_anat_to_epi_affine,
final_epi_to_anat_affine,
initial_warp,
initial_iwarp,
noise_mask_dilation_size,
noise_mask_iterations,
noise_mask_sigma,
)
# resample regression mask
regression_mask_res = append_suffix(regression_mask, "_%smm" % res_str)
resample(regression_mask_res, regression_mask, regression_mask, dxyz=resample_resolution)
# convert affine to omni type
final_anat_to_epi_omni_affine = os.path.join(output_path, "final_anat_to_epi.affine")
convert_affine_file(final_anat_to_epi_omni_affine, final_anat_to_epi_affine, "omni")
# setup synth warp output name function
# setting suffix to "" gives only the prefix
def synth_warp(i, suffix):
return os.path.join(output_path, "iteration%d_synth_%s" % (i, suffix))
# set initial warp, (set to None if use_initializing_warp is NOT set)
if not use_initializing_warp:
initial_warp = None
# Run iterative warp
iterations = len(SyN_step_size)
for n in range(iterations):
logging.info("Running warp (Iteration %d)...", n)
# generate synthetic image
synthetic_epispace = os.path.join(output_path, "iteration%d_synthetic_epispace.nii.gz" % n)
synthetic_anatspace = os.path.join(output_path, "iteration%d_synthetic_anatspace.nii.gz" % n)
register(
ref_epi_lce_res_unwarped,
final_synth_model,
[t1_res, t2_res],
aligned_output=synthetic_epispace,
blurred_regressed_output=synthetic_anatspace,
initial_affine=final_anat_to_epi_omni_affine,
regress_mask=regression_mask_res,
no_register=True,
bandwidth=bandwidth,
tone_curve=True,
tone_curve_mask=regression_mask_res,
)
# warp synthetic to functional
antsRegistration(
synth_warp(n, ""),
ref_epi_lce_res,
synthetic_epispace,
grad_step=SyN_step_size[n],
rad_bin=2,
update_field_var=SyN_field_variance,
total_field_var=SyN_field_variance,
convergence="".join(["500x" for i in SyN_shrink_factors.split("x")[:-1]]) + "0",
threshold_window=20,
smoothing=SyN_smoothing,
shrink_factors=SyN_shrink_factors,
restrict_deform=restrict_deform,
reference_mask=ref_fov_mask,
source_mask=source_fov_mask,
initial_warp=initial_warp if n == 0 else synth_warp(n - 1, "Warp.nii.gz"),
)
# if initial warp was not used
if n == 0 and not use_initializing_warp:
os.rename(synth_warp(n, "0Warp.nii.gz"), synth_warp(n, "Warp.nii.gz"))
os.rename(synth_warp(n, "0InverseWarp.nii.gz"), synth_warp(n, "InverseWarp.nii.gz"))
# concatenate with last warp
else:
antsApplyTransform(
synth_warp(n, "Warp.nii.gz"),
ref_epi_lce_res,
synthetic_epispace,
[initial_warp if n == 0 else synth_warp(n - 1, "Warp.nii.gz"), synth_warp(n, "1Warp.nii.gz")],
composite_warp=True,
)
antsApplyTransform(
synth_warp(n, "InverseWarp.nii.gz"),
ref_epi_lce_res,
synthetic_epispace,
[
initial_iwarp if n == 0 else synth_warp(n - 1, "InverseWarp.nii.gz"),
synth_warp(n, "1InverseWarp.nii.gz"),
],
composite_warp=True,
)
# delete temp warps
os.remove(synth_warp(n, "0Warp.nii.gz"))
os.remove(synth_warp(n, "1Warp.nii.gz"))
os.remove(synth_warp(n, "1InverseWarp.nii.gz"))
# get the warp applied to synthetic
synthetic_warped = append_suffix(synthetic_epispace, "_warped")
NwarpApply(synthetic_warped, ref_epi_lce_res, synthetic_epispace, synth_warp(n, "Warp.nii.gz"))
# get the inverse warp applied to epi
ref_epi_lce_res_unwarped = append_suffix(ref_epi_lce_res, "%d_unwarped" % n)
NwarpApply(ref_epi_lce_res_unwarped, synthetic_epispace, ref_epi_lce_res, synth_warp(n, "InverseWarp.nii.gz"))
# make symlink to unwarped epi for easy reference
epi_unwarped = os.path.join(output_path, "iteration%d_epi_unwarped.nii.gz" % n)
if os.path.islink(epi_unwarped):
os.remove(epi_unwarped)
os.symlink(os.path.basename(ref_epi_lce_res_unwarped), epi_unwarped)
if n < iterations - 1: # skip the last iteration for regression mask
# get new regression mask from new warp
mask_prefix = os.path.join(output_path, "iteration%d_" % (n + 1))
regression_mask = make_regression_mask(
mask_prefix,
epi_moco,
anat_bet_mask,
anat_weight_mask,
final_anat_to_epi_affine,
final_epi_to_anat_affine,
synth_warp(n, "Warp.nii.gz"),
synth_warp(n, "InverseWarp.nii.gz"),
noise_mask_dilation_size,
noise_mask_iterations,
noise_mask_sigma,
)
# resample regression mask
regression_mask_res = append_suffix(regression_mask, "_%smm" % res_str)
resample(regression_mask_res, regression_mask, regression_mask, dxyz=resample_resolution)
# resample the forward warp to the EPI resolution
ref_epi_img = nib.load(ref_epi) # resolution of EPI
# assumes voxels are isotropic
data_resolution = ref_epi_img.header.get_zooms()[0]
final_synth_to_epi_warp = os.path.join(output_path, "final_synth_to_epi_warp.nii.gz")
resample(final_synth_to_epi_warp, synth_warp(n, "Warp.nii.gz"), synth_warp(n, "Warp.nii.gz"), dxyz=data_resolution)
# get inverted final warp
final_epi_to_synth_warp = os.path.join(output_path, "final_epi_to_synth_warp.nii.gz")
NwarpCat(final_epi_to_synth_warp, final_synth_to_epi_warp, True)
# return warps
return (final_synth_to_epi_warp, final_epi_to_synth_warp)