Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: ME | Revise SDC compute graph #2610

Merged
merged 10 commits into from
Oct 22, 2021
4 changes: 2 additions & 2 deletions fmriprep/interfaces/multiecho.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ class T2SMapInputSpec(CommandLineInputSpec):
mandatory=True,
minlen=3,
desc='echo times')
mask_file = File(argstr='--mask',
mask_file = File(argstr='--mask %s',
position=3,
mandatory=False,
desc='mask file',
exists=True)
fittype = traits.Enum('curvefit', 'loglin',
Expand Down Expand Up @@ -95,6 +94,7 @@ class T2SMap(CommandLine):
>>> t2smap.cmdline # doctest: +ELLIPSIS
't2smap -d sub-01_run-01_echo-1_bold.nii.gz sub-01_run-01_echo-2_bold.nii.gz \
sub-01_run-01_echo-3_bold.nii.gz -e 13.0 27.0 43.0 --fittype curvefit'

"""
_cmd = 't2smap'
input_spec = T2SMapInputSpec
Expand Down
165 changes: 109 additions & 56 deletions fmriprep/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,10 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):

# BOLD buffer: an identity used as a pointer to either the original BOLD
# or the STC'ed one for further use.
boldbuffer = pe.Node(niu.IdentityInterface(fields=["bold_file"]), name="boldbuffer")
boldbuffer = pe.Node(niu.IdentityInterface(fields=["bold_file", "name_source"]),
name="boldbuffer")
if multiecho:
boldbuffer.synchronize = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forget what this does. Can look it up when there's time, or if you want to say a quick word?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. Because boldbuffer now has two fields, and both are iterable, synchronize = True preempts the dot product of the values of the two iterable inputs.


summary = pe.Node(
FunctionalSummary(
Expand Down Expand Up @@ -482,22 +485,30 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
# fmt:on
else: # for meepi, iterate through stc_wf for all workflows
meepi_echos = boldbuffer.clone(name="meepi_echos")
meepi_echos.iterables = ("bold_file", bold_file)
meepi_echos.iterables = [
("bold_file", bold_file),
("name_source", bold_file),
]
# fmt:off
workflow.connect([
(meepi_echos, bold_stc_wf, [("bold_file", "inputnode.bold_file")]),
(meepi_echos, boldbuffer, [("name_source", "name_source")]),
])
# fmt:on
elif not multiecho: # STC is too short or False

# bypass STC from original BOLD in both SE and ME cases
elif not multiecho: # SE and skip-STC
# fmt:off
# bypass STC from original BOLD to the splitter through boldbuffer
workflow.connect([
(initial_boldref_wf, boldbuffer, [("outputnode.bold_file", "bold_file")]),
])
# fmt:on
else:
else: # ME and skip-STC
# for meepi, iterate over all meepi echos to boldbuffer
boldbuffer.iterables = ("bold_file", bold_file)
boldbuffer.iterables = [
("bold_file", bold_file),
("name_source", bold_file),
]

# MULTI-ECHO EPI DATA #############################################
if multiecho: # instantiate relevant interfaces, imports
Expand Down Expand Up @@ -637,12 +648,12 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
(inputnode, func_derivatives_wf, [
(("bold_file", combine_meepi_source), "inputnode.source_file"),
]),
(join_echos, bold_t2s_wf, [("bold_files", "inputnode.bold_file")]),
(join_echos, bold_t2s_wf, [
("bold_files", "inputnode.bold_file"),
]),
(bold_t2s_wf, split_opt_comb, [("outputnode.bold", "in_file")]),
(split_opt_comb, bold_t1_trans_wf, [("out_files", "inputnode.bold_split")]),
(bold_t2s_wf, bold_final, [
("outputnode.bold", "bold"),
]),
(bold_t2s_wf, bold_final, [("outputnode.bold", "bold")]),
])
# fmt:on

Expand Down Expand Up @@ -975,9 +986,11 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
use_fieldwarp=False,
name="bold_bold_trans_wf",
)
bold_bold_trans_wf.inputs.inputnode.name_source = ref_file
bold_bold_trans_wf.inputs.inputnode.fieldwarp = "identity"

if not multiecho:
bold_bold_trans_wf.inputs.inputnode.name_source = ref_file

# fmt:off
workflow.connect([
# Connect bold_bold_trans_wf
Expand All @@ -986,37 +999,33 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
("outputnode.xforms", "inputnode.hmc_xforms"),
]),
])
# fmt:on

# fmt:off
workflow.connect(
[
(bold_bold_trans_wf, final_boldref_wf, [
("outputnode.bold", "inputnode.bold_file"),
]),
(bold_bold_trans_wf, bold_final, [
("outputnode.bold", "bold"),
]),
] if not multiecho
else [
(bold_bold_trans_wf, join_echos, [
("outputnode.bold", "bold_files"),
]),
(join_echos, final_boldref_wf, [("bold_files", "inputnode.bold_file")]),
# use reference image mask used by bold_bold_trans_wf
(bold_bold_trans_wf, bold_t2s_wf, [
(("outputnode.bold_mask", pop_file), "inputnode.bold_mask"),
]),
]
)
workflow.connect([
(bold_bold_trans_wf, bold_final, [("outputnode.bold", "bold")]),
(bold_bold_trans_wf, final_boldref_wf, [
("outputnode.bold", "inputnode.bold_file"),
]),
] if not multiecho else [
(initial_boldref_wf, bold_t2s_wf, [
("outputnode.bold_mask", "inputnode.bold_mask"),
]),
(boldbuffer, bold_bold_trans_wf, [
("name_source", "inputnode.name_source"),
]),
(bold_bold_trans_wf, join_echos, [
("outputnode.bold", "bold_files"),
]),
(join_echos, final_boldref_wf, [
("bold_files", "inputnode.bold_file"),
]),
])
# fmt:on
return workflow

from niworkflows.interfaces.reportlets.registration import (
SimpleBeforeAfterRPT as SimpleBeforeAfter,
)
from niworkflows.interfaces.utility import KeySelect
from sdcflows.utils.misc import front as _pop
from sdcflows.workflows.apply.registration import init_coeff2epi_wf
from sdcflows.workflows.apply.correction import init_unwarp_wf

Expand Down Expand Up @@ -1087,41 +1096,85 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
("outputnode.xforms", "inputnode.hmc_xforms")]),
(initial_boldref_wf, sdc_report, [
("outputnode.ref_image", "before")]),
(unwarp_wf, final_boldref_wf, [
("outputnode.corrected_ref", "inputnode.bold_file"),
]),
(bold_split, unwarp_wf, [
("out_files", "inputnode.distorted")]),
(final_boldref_wf, sdc_report, [
("outputnode.ref_image", "after"),
("outputnode.bold_mask", "wm_seg")]),
(inputnode, ds_report_sdc, [("bold_file", "source_file")]),
(sdc_report, ds_report_sdc, [("out_report", "in_file")]),
# remaining workflow connections
(unwarp_wf, bold_std_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("outputnode.fieldwarp", _pop), "inputnode.fieldwarp"),
]),
(unwarp_wf, bold_final, [("outputnode.corrected", "bold")]),
(unwarp_wf, bold_t1_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("outputnode.fieldwarp", _pop), "inputnode.fieldwarp"),
]),

])
# fmt:on

if not multiecho:
# fmt:off
workflow.connect([
(bold_split, unwarp_wf, [
("out_files", "inputnode.distorted")]),
])
# fmt:on
else:
# fmt:off
workflow.connect([
(split_opt_comb, unwarp_wf, [
("out_files", "inputnode.distorted")])
(unwarp_wf, bold_final, [("outputnode.corrected", "bold")]),
# remaining workflow connections
(unwarp_wf, final_boldref_wf, [
("outputnode.corrected", "inputnode.bold_file"),
]),
(unwarp_wf, bold_t1_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("outputnode.fieldwarp", pop_file), "inputnode.fieldwarp"),
]),
(unwarp_wf, bold_std_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("outputnode.fieldwarp", pop_file), "inputnode.fieldwarp"),
]),
])
# fmt:on
return workflow

# Finalize connections if ME-EPI
join_sdc_echos = pe.JoinNode(
niu.IdentityInterface(
fields=[
"fieldmap",
"fieldwarp",
"corrected",
"corrected_ref",
"corrected_mask",
]
),
joinsource=("meepi_echos" if run_stc is True else "boldbuffer"),
joinfield=["bold_files"],
name="join_sdc_echos",
)

def _dpop(list_of_lists):
return list_of_lists[0][0]

# fmt:off
workflow.connect([
(unwarp_wf, join_echos, [
("outputnode.corrected", "bold_files"),
]),
(unwarp_wf, join_sdc_echos, [
("outputnode.fieldmap", "fieldmap"),
("outputnode.fieldwarp", "fieldwarp"),
("outputnode.corrected", "corrected"),
("outputnode.corrected_ref", "corrected_ref"),
("outputnode.corrected_mask", "corrected_mask"),
]),
# remaining workflow connections
(join_sdc_echos, final_boldref_wf, [
("corrected", "inputnode.bold_file"),
]),
(join_sdc_echos, bold_t2s_wf, [
("corrected_mask", "inputnode.bold_mask"),
]),
(join_sdc_echos, bold_t1_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("fieldwarp", _dpop), "inputnode.fieldwarp"),
]),
(join_sdc_echos, bold_std_trans_wf, [
# TEMPORARY: For the moment we can't use frame-wise fieldmaps
(("fieldwarp", _dpop), "inputnode.fieldwarp"),
]),
])
# fmt:on

return workflow

Expand Down