Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions AFQ/api/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,6 @@ def __init__(
if len(self.sessions) * len(self.subjects) < 2:
self.parallel_params["engine"] = "serial"

# do not parallelize within subject if parallelizing across
# subject-sessions
if self.parallel_params["engine"] != "serial":
if "ray_n_cpus" not in kwargs:
kwargs["ray_n_cpus"] = 1
if "numba_n_threads" not in kwargs:
kwargs["numba_n_threads"] = 1

self.valid_sub_list = []
self.valid_ses_list = []
self.pAFQ_list = []
Expand Down
4 changes: 2 additions & 2 deletions AFQ/data/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,7 @@ def read_slf_templates(as_img=True, resample_to=False):
"26831642",
"26831645",
"26831648",
"26831651",
"63998848",
Comment thread
36000 marked this conversation as resolved.
"26831654",
"26831657",
"26831660",
Expand All @@ -1537,7 +1537,7 @@ def read_slf_templates(as_img=True, resample_to=False):
"9cff03af586d9dd880750cef3e0bf63f",
"ff728ba3ffa5d1600bcd19fdef8182c4",
"4f1978e418a3169609375c28b3eba0fd",
"fd163893081b520f4594171aeea04f39",
"ebdfe9d26fc4d7b018a26d7e38895055",
"bf795d197912b5e074d248d2763c6930",
"13efde1efe0de52683cbf352ecba457e",
"75c7bd2092950578e599a2dcb218909f",
Expand Down
7 changes: 3 additions & 4 deletions AFQ/models/asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def unified_filtering(
sh_data,
sphere,
sh_basis="descoteaux07",
is_legacy=False,
is_legacy=True,
sigma_spatial=1.0,
sigma_align=0.8,
sigma_angle=None,
Expand Down Expand Up @@ -232,7 +232,6 @@ def _unified_filter_build_nx(
# the direction controls the align weight
if i == j == k == 0 or disable_align:
# hack for main direction to have maximal weight
# w_align = np.ones((1, len(directions)), dtype=np.float32)
w_align = np.zeros((1, len(directions)), dtype=np.float32)
else:
dxy /= len_xy
Expand All @@ -252,7 +251,7 @@ def _unified_filter_build_nx(

for ui in range(len(directions)):
w_sum = np.sum(nx_weights[..., ui])
nx_weights /= w_sum
nx_weights[..., ui] /= w_sum

return nx_weights

Expand Down Expand Up @@ -668,7 +667,7 @@ def compute_nufid_asym(sh_coeffs, sphere, csf, mask):
sh_order_max=sh_order,
basis_type="descoteaux07",
full_basis=full_basis,
legacy=False,
legacy=True,
)

# Guess at threshold from 2.0 * mean of ODF maxes in CSF
Expand Down
91 changes: 19 additions & 72 deletions AFQ/models/msmt.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import multiprocessing

import numpy as np
import osqp
import ray
from dipy.reconst.mcsd import MSDeconvFit, MultiShellDeconvModel
from scipy.sparse import csr_matrix
from tqdm import tqdm

from AFQ.utils.stats import chunk_indices

__all__ = ["MultiShellDeconvModel"]


def _fit(self, data, mask=None, n_cpus=None):
def _fit(self, data, mask=None):
"""
Use OSQP to fit the multi-shell spherical deconvolution model.
"""
if n_cpus is None:
n_cpus = max(multiprocessing.cpu_count() - 1, 1)

og_data_shape = data.shape
if len(data.shape) < 4:
data = data.reshape((1,) * (4 - data.ndim) + data.shape)
Expand Down Expand Up @@ -47,69 +39,24 @@ def _fit(self, data, mask=None, n_cpus=None):
A = csr_matrix(A)
Q = csr_matrix(Q)

if n_cpus > 1:
ray.init(ignore_reinit_error=True)

data_id = ray.put(data)
mask_id = ray.put(mask)
Q_id = ray.put(Q)
A_id = ray.put(A)
b_id = ray.put(b)
R_id = ray.put(R)

@ray.remote(num_cpus=n_cpus)
def process_batch_remote(batch_indices, data, mask, Q, A, b, R):
import numpy as np
import osqp

m = osqp.OSQP()
m.setup(P=Q, A=A, l=b, u=None, q=None, verbose=False)
return_values = np.zeros(
(len(batch_indices),) + data.shape[1:3] + (A.shape[1],),
dtype=np.float64,
)
for i, ii in enumerate(batch_indices):
for jj in range(data.shape[1]):
for kk in range(data.shape[2]):
if mask[ii, jj, kk]:
c = np.dot(-R.T, data[ii, jj, kk])
m.update(q=c)
results = m.solve()
return_values[i, jj, kk] = results.x
return return_values

# Launch tasks in chunks
all_indices = list(range(data.shape[0]))
indices_chunked = list(chunk_indices(all_indices, n_cpus * 2))
futures = [
process_batch_remote.remote(batch, data_id, mask_id, Q_id, A_id, b_id, R_id)
for batch in indices_chunked
]

# Collect and assign results
for batch, future in zip(indices_chunked, tqdm(futures)):
results = ray.get(future)
for i, ii in enumerate(batch):
coeff[ii] = results[i]
else:
m = osqp.OSQP()
m.setup(
P=Q,
A=A,
l=b,
u=None,
q=None,
verbose=False,
adaptive_rho=True,
)
for ii in tqdm(range(data.shape[0])):
for jj in range(data.shape[1]):
for kk in range(data.shape[2]):
if mask[ii, jj, kk]:
c = np.dot(-R.T, data[ii, jj, kk])
m.update(q=c)
results = m.solve()
coeff[ii, jj, kk] = results.x
m = osqp.OSQP()
m.setup(
P=Q,
A=A,
l=b,
u=None,
q=None,
verbose=False,
adaptive_rho=True,
)
for ii in tqdm(range(data.shape[0])):
for jj in range(data.shape[1]):
for kk in range(data.shape[2]):
if mask[ii, jj, kk]:
c = np.dot(-R.T, data[ii, jj, kk])
m.update(q=c)
results = m.solve()
coeff[ii, jj, kk] = results.x
coeff = coeff.reshape(og_data_shape[:-1] + (n,))
return MSDeconvFit(self, coeff, None)

Expand Down
2 changes: 1 addition & 1 deletion AFQ/models/wmgm_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fit_wm_gm_interface(PVE_img, dwiref_img):
static_affine=dwiref_img.affine,
).get_fdata()

wm_boundary = find_boundaries(wm > 0.5, mode="inner")
wm_boundary = find_boundaries(wm > 0.9, mode="inner")
gm_smoothed = gaussian_filter(gm, 1)
csf_smoothed = gaussian_filter(csf, 1)

Expand Down
2 changes: 1 addition & 1 deletion AFQ/recognition/cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def clean_by_orientation_mahalanobis(
length_threshold = np.inf
fgarray = abu.resample_tg(streamlines, n_points)

assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, 100))
_, assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, n_points))
assignment_idxs = assignment_idxs.reshape((len(fgarray), n_points))
fgarray = np.asarray(fgarray)

Expand Down
43 changes: 5 additions & 38 deletions AFQ/recognition/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dipy.tracking.streamline as dts
import nibabel as nib
import numpy as np
import ray
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram
from dipy.segment.bundles import RecoBundles
Expand All @@ -20,7 +19,6 @@
import AFQ.recognition.utils as abu
from AFQ.api.bundle_dict import apply_to_roi_dict
from AFQ.recognition.clustering import subcluster_by_atlas
from AFQ.utils.stats import chunk_indices
from AFQ.utils.streamlines import move_streamlines

criteria_order_pre_other_bundles = [
Expand Down Expand Up @@ -178,7 +176,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs):
b_sls.select(accept_idx, "orientation")


def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs):
def include(b_sls, bundle_def, preproc_imap, **kwargs):
accept_idx = b_sls.initiate_selection("include")
flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet

Expand All @@ -191,39 +189,9 @@ def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs):
else:
include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"])

# For now I am turning ray parallelization here off.
# It is never worthwhile considering other changes we
# have made to speed up this step,
# so spinning up ray and transferring data back
# and forth is not worth it.
# In the future, I think we should redo this with numba and
# use multithreading
n_cpus = 1

# with parallel segmentation, the first for loop will
# only collect streamlines and does not need tqdm
if n_cpus > 1:
inc_results = np.zeros(len(b_sls), dtype=tuple)

inc_rois_id = ray.put(bundle_def["include"])
inc_roi_tols_id = ray.put(include_roi_tols)

_check_inc_parallel = ray.remote(num_cpus=n_cpus)(abr.check_sls_with_inclusion)

sls_chunks = list(chunk_indices(np.arange(len(b_sls)), n_cpus))
futures = [
_check_inc_parallel.remote(
b_sls.get_selected_sls()[sls_chunk], inc_rois_id, inc_roi_tols_id
)
for sls_chunk in sls_chunks
]

for ii, future in enumerate(futures):
inc_results[sls_chunks[ii]] = ray.get(future)
else:
inc_results = abr.check_sls_with_inclusion(
b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols
)
inc_results = abr.check_sls_with_inclusion(
b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols
)

n_inc = len(bundle_def["include"])
roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32)
Expand Down Expand Up @@ -405,13 +373,12 @@ def orient_mahal(b_sls, bundle_def, **kwargs):
b_sls.select(accept_idx, "orient_mahal")


def isolation_forest(b_sls, bundle_def, n_cpus, rng, **kwargs):
def isolation_forest(b_sls, bundle_def, rng, **kwargs):
b_sls.initiate_selection("isolation_forest")
accept_idx = abc.clean_by_isolation_forest(
b_sls.get_selected_sls(),
distance_threshold=bundle_def["isolation_forest"].get("distance_threshold", 3),
n_rounds=bundle_def["isolation_forest"].get("n_rounds", 5),
n_jobs=n_cpus,
random_state=rng,
)
b_sls.select(accept_idx, "isolation_forest")
Expand Down
8 changes: 3 additions & 5 deletions AFQ/recognition/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def recognize(
mapping,
bundle_dict,
reg_template,
n_cpus,
nb_points=False,
nb_streamlines=False,
clip_edges=False,
Expand Down Expand Up @@ -53,8 +52,6 @@ def recognize(
Dictionary of bundles to segment.
reg_template : str, nib.Nifti1Image
Template image for registration.
n_cpus : int
Number of CPUs to use for parallelization.
nb_points : int, boolean
Resample streamlines to nb_points number of points.
If False, no resampling is done. Can only be done
Expand Down Expand Up @@ -155,7 +152,9 @@ def recognize(
if isinstance(tg, StatefulTractogram):
if nb_streamlines and len(tg) > nb_streamlines:
tg = StatefulTractogram(
select_random_set_of_streamlines(tg.streamlines, nb_streamlines),
select_random_set_of_streamlines(
tg.streamlines, nb_streamlines, rng=rng
),
tg,
tg.space,
)
Expand Down Expand Up @@ -188,7 +187,6 @@ def recognize(
bundle_name,
recognized_bundles_dict,
clip_edges=clip_edges,
n_cpus=n_cpus,
rb_recognize_params=rb_recognize_params,
prob_threshold=prob_threshold,
refine_reco=refine_reco,
Expand Down
Loading
Loading