From a0d02c67b6cae81b920045f1a25ca4c24e9bbfb1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 16:12:30 -0700 Subject: [PATCH 01/26] [ENH] use DIPY fast track by default --- AFQ/tasks/structural.py | 7 +- AFQ/tasks/tractography.py | 25 ++++- AFQ/tests/test_tractography.py | 35 ++++--- AFQ/tractography/tractography.py | 157 ++++++++++++----------------- AFQ/tractography/utils.py | 4 + examples/howto_examples/pyafq_2.py | 2 +- 6 files changed, 111 insertions(+), 119 deletions(-) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 2422a8cf..9b600f04 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -30,10 +30,9 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F Tractography and MSMT use Ray. Default: None numba_n_threads : int, optional - The number of threads to use for Numba. + The number of threads to use for Numba and DIPY tracking. If None, uses the number of available CPUs minus one, - but with a maximum of 16. - ASYM fit uses Numba. + but with a maximum of 32. Default: None low_memory : bool, optional Whether to use low-memory versions of algorithms @@ -43,7 +42,7 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F if ray_n_cpus is None: ray_n_cpus = 1 if numba_n_threads is None: - numba_n_threads = min(max(get_num_threads() - 1, 1), 16) + numba_n_threads = min(max(get_num_threads() - 1, 1), 32) return ray_n_cpus, numba_n_threads, low_memory diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 13d4bc93..7d844134 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -99,7 +99,6 @@ def streamlines( # get masks this_tracking_params["seed_mask"] = nib.load(seed).get_fdata() - this_tracking_params["pve"] = tissue_imap["pve_internal"] is_trx = this_tracking_params.get("trx", False) @@ -116,6 +115,9 @@ def streamlines( " 'num_chunks' arg" ) + this_tracking_params["pve"] = tissue_imap["pve_internal"] + this_tracking_params["n_threads"] = structural_imap["n_threads"] + @ray.remote class TractActor: def __init__(self): @@ -207,17 +209,30 @@ def delete_lazyt(self, id): sft = trx_concatenate(sfts) else: - lazyt = aft.track(fodf, **this_tracking_params) + lazyt = aft.track( + fodf, + tissue_imap["pve_internal"], + structural_imap["n_threads"], + **this_tracking_params, + ) # Chunk size is number of streamlines tracked before saving to disk. sft = TrxFile.from_lazy_tractogram( - lazyt, seed, dtype_dict=dtype_dict, chunk_size=1e5 + lazyt, + seed, + dtype_dict=dtype_dict, + chunk_size=1e5, + extra_buffer=int(1e6), ) n_streamlines = len(sft) else: start_time = time() - sft = aft.track(fodf, **this_tracking_params) - sft.to_vox() + sft = aft.track( + fodf, + tissue_imap["pve_internal"], + structural_imap["n_threads"], + **this_tracking_params, + ) n_streamlines = len(sft.streamlines) if len(sft) == 0: diff --git a/AFQ/tests/test_tractography.py b/AFQ/tests/test_tractography.py index 160eb031..f3631f85 100644 --- a/AFQ/tests/test_tractography.py +++ b/AFQ/tests/test_tractography.py @@ -50,6 +50,7 @@ def test_csd_local_tracking(): sls = track( fname, fpve, + 1, directions, odf_model="CSD", max_angle=30.0, @@ -58,7 +59,6 @@ def test_csd_local_tracking(): n_seeds=seeds, step_size=step_size, minlen=minlen, - tracker="local", ).streamlines for sl in sls: @@ -71,6 +71,7 @@ def test_dti_local_tracking(): sls = track( fdict["params"], fpve, + 1, directions, max_angle=30.0, sphere="repulsion724", @@ -79,7 +80,6 @@ def test_dti_local_tracking(): step_size=step_size, minlen=minlen, odf_model="DTI", - tracker="local", ).streamlines for sl in sls: npt.assert_(len(sl) >= minlen / step_size) @@ -103,20 +103,19 @@ def test_pft_tracking(): ], ["DTI", "CSD"], ): - for directions in ["det", "prob"]: - sls = track( - fname, - fpve, - directions, - max_angle=30.0, - sphere="repulsion724", - seed_mask=None, - n_seeds=1, - step_size=step_size, - minlen=minlen, - odf_model=odf, - tracker="pft", - ).streamlines + sls = track( + fname, + fpve, + 1, + "pft", + max_angle=30.0, + sphere="repulsion724", + seed_mask=None, + n_seeds=1, + step_size=step_size, + minlen=minlen, + odf_model=odf, + ).streamlines - for sl in sls: - npt.assert_(len(sl) >= minlen / step_size) + for sl in sls: + npt.assert_(len(sl) >= minlen / step_size) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 7f66784f..8e261c54 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -1,18 +1,19 @@ import logging +from time import time import dipy.data as dpd import nibabel as nib import numpy as np from dipy.align import resample -from dipy.direction import ( - DeterministicMaximumDirectionGetter, - ProbabilisticDirectionGetter, -) from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.reconst import shm from dipy.reconst.dti import decompose_tensor, from_lower_triangular -from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking from dipy.tracking.stopping_criterion import ActStoppingCriterion +from dipy.tracking.tracker import ( + deterministic_tracking, + pft_tracking, + probabilistic_tracking, +) from nibabel.streamlines.tractogram import LazyTractogram from skimage.segmentation import find_boundaries from tqdm import tqdm @@ -24,13 +25,14 @@ def track( params_file, pve, + n_threads, directions="prob", max_angle=30.0, sphere="repulsion724", seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=1e7, + n_seeds=2e7, random_seeds=True, rng_seed=None, step_size=0.5, @@ -39,7 +41,6 @@ def track( odf_model="CSD_AODF", basis_type="descoteaux07", legacy=True, - tracker="pft", trx=True, ): """ @@ -54,9 +55,13 @@ def track( Full path to a nifti file containing tissue probability maps, or nibabel img with tissue probability maps. This should be of the order (pve_csf, pve_gm, pve_wm). + n_threads : int + The number of threads to use in tracking. + If 0 or -1, uses all available threads. directions : str How tracking directions are determined. - One of: {"det" | "prob"} + One of: {"det" | "prob" | "pft"} + pft refers to Particle Filtering Tracking ([Girard2014]_). Default: "prob" max_angle : float, optional. The maximum turning angle in each step. Default: 30 @@ -105,12 +110,8 @@ def track( The spherical harmonic basis type used to represent the coefficients. One of {"descoteaux07", "tournier07"}. Default: "descoteaux07" legacy : bool, optional - Whether to use the legacy implementation of the direction getter. + Whether the legacy SH basis definition should be used. See Dipy documentation for more details. Default: True - tracker : str, optional - Which strategy to use in tracking. This can be the standard local - tracking ("local") or Particle Filtering Tracking ([Girard2014]_). - One of {"local", "pft"}. Default: "pft" trx : bool, optional Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). @@ -148,9 +149,8 @@ def track( odf_model = odf_model.upper() directions = directions.lower() - # transform from mm to step size units - minlen = int(minlen / step_size) - maxlen = int(maxlen / step_size) + if n_threads == -1: + n_threads = 0 if seed_mask is None: seed_mask = np.ones(params_img.shape[:3]) @@ -168,37 +168,6 @@ def track( if isinstance(sphere, str): sphere = dpd.get_sphere(name=sphere) - logger.info("Getting Directions...") - if directions == "det": - dg = DeterministicMaximumDirectionGetter - elif directions == "prob": - dg = ProbabilisticDirectionGetter - else: - raise ValueError(f"Unrecognized direction '{directions}'.") - - logger.debug(f"Using basis type: {basis_type}") - logger.debug(f"Using legacy DG: {legacy}") - - if odf_model == "DTI" or odf_model == "DKI": - evals, evecs = decompose_tensor(from_lower_triangular(model_params)) - odf = tensor_odf(evals, evecs, sphere) - dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere) - elif (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): - sh_order = shm.order_from_ncoef(model_params.shape[3], full_basis=True) - pmf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) - pmf[pmf < 0] = 0 - dg = dg.from_pmf( - np.asarray(pmf, dtype=float), max_angle=max_angle, sphere=sphere - ) - else: - dg = dg.from_shcoeff( - model_params, - max_angle=max_angle, - sphere=sphere, - basis_type=basis_type, - legacy=legacy, - ) - if not len(pve_data.shape) == 4 or pve_data.shape[3] != 3: raise RuntimeError( "For pve, expected pve_data with shape [x, y, z, 3]. " @@ -228,77 +197,83 @@ def track( static_affine=params_img.affine, ).get_fdata() - # here we treat edges as gm + # here we treat wm that borders the edge of the brain mask as gm # this is so that streamlines that hit the end of the # (presumably masked) fodf are treated as valid + # (think brain stem) brain_mask = np.any(model_params != 0, axis=-1).astype(np.uint8) edge = find_boundaries(brain_mask, mode="inner") pve_gm_data[edge] = 1.0 pve_wm_data[edge] = 0.0 pve_csf_data[edge] = 0.0 + # Here we adjust the stopping criterion to be slightly more permissive + pve_gm_data = pve_gm_data.astype(float) * 0.51 + pve_csf_data = pve_csf_data.astype(float) * 0.51 + stopping_criterion = ActStoppingCriterion.from_pve( pve_wm_data, pve_gm_data, pve_csf_data ) - if tracker == "local": - my_tracker = LocalTracking - elif tracker == "pft": - my_tracker = ParticleFilteringTracking + if odf_model == "DTI" or odf_model == "DKI": + evals, evecs = decompose_tensor(from_lower_triangular(model_params)) + odf = tensor_odf(evals, evecs, sphere) + elif (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): + sh_order = shm.order_from_ncoef(model_params.shape[3], full_basis=True) + odf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) + odf[odf < 0] = 0 else: - raise ValueError( - f"Unrecognized tracker '{tracker}'. Must be one of {{'local', 'pft'}}." - ) + odf = None - logger.info( - f"Tracking with {len(seeds)} seeds, average of 1-3 directions per seed..." - ) + if directions == "det": # /todo check if works with nonsymmetric + tracker = deterministic_tracking + elif directions == "prob": + tracker = probabilistic_tracking + elif directions == "pft": + tracker = pft_tracking + else: + raise ValueError(f"Unrecognized direction '{directions}'.") + tracking_kwargs = {} - return _tracking( - my_tracker, - seeds, - dg, - stopping_criterion, - params_img, - step_size=step_size, - minlen=minlen, - maxlen=maxlen, - random_seed=rng_seed, - trx=trx, - ) + if ( + (odf_model == "DTI") + or (odf_model == "DKI") + or (odf_model == "GQ") + or (odf_model == "RUMBA") + or ("AODF" in odf_model) + ): + tracking_kwargs["sf"] = odf + else: + tracking_kwargs["sh"] = model_params + logger.info(f"Tracking with {len(seeds)} seeds...") -def _tracking( - tracker, - seeds, - dg, - stopping_criterion, - params_img, - step_size=0.5, - minlen=40, - maxlen=200, - random_seed=None, - trx=False, -): - """ - Helper function - """ if len(seeds.shape) == 1: seeds = seeds[None, ...] + logger.info("Note there will be a long initial delay as seeds are initialized") + start_time = time() tracker = tqdm( tracker( - dg, - stopping_criterion, seeds, + stopping_criterion, params_img.affine, + max_angle=max_angle, + sphere=sphere, + basis_type=basis_type, + legacy=legacy, step_size=step_size, - minlen=minlen, - maxlen=maxlen, + min_len=minlen, + max_len=maxlen, return_all=False, - random_seed=random_seed, - ) + random_seed=rng_seed, + nbr_threads=n_threads, + **tracking_kwargs, + ), + total=len(seeds) * 0.7, + desc="Tracking...", ) + logger.info(f"Tracking took {time() - start_time:.2f} seconds.") if trx: return LazyTractogram(lambda: tracker, affine_to_rasmm=params_img.affine) diff --git a/AFQ/tractography/utils.py b/AFQ/tractography/utils.py index ba0ed482..bae5e1d2 100644 --- a/AFQ/tractography/utils.py +++ b/AFQ/tractography/utils.py @@ -1,4 +1,5 @@ import logging +from time import time import dipy.tracking.utils as dtu import numpy as np @@ -40,6 +41,7 @@ def gen_seeds( :func:`AFQ.tractography.tractography.track`. """ logger.info("Generating Seeds...") + start_time = time() if _is_int(n_seeds): # If it's float type, cast to integer: if isinstance(n_seeds, float): @@ -61,4 +63,6 @@ def gen_seeds( else: # If user provided an array, we'll use n_seeds as the seeds: seeds = n_seeds + + logger.info(f"Generated {len(seeds)} seeds in {time() - start_time:.2f} seconds.") return seeds diff --git a/examples/howto_examples/pyafq_2.py b/examples/howto_examples/pyafq_2.py index 3280a83b..06358b21 100644 --- a/examples/howto_examples/pyafq_2.py +++ b/examples/howto_examples/pyafq_2.py @@ -26,7 +26,7 @@ n_seeds=1, random_seeds=False, minlen=50, - tracker="local", + directions="prob", seed_mask=afm.ScalarImage("dti_fa"), seed_threshold=0.2 ) From 68d2235a112eaca51b2cc7846f7021537434c108 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 16:13:52 -0700 Subject: [PATCH 02/26] unlimit default thread count --- AFQ/tasks/structural.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 9b600f04..71146d27 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -42,7 +42,7 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F if ray_n_cpus is None: ray_n_cpus = 1 if numba_n_threads is None: - numba_n_threads = min(max(get_num_threads() - 1, 1), 32) + numba_n_threads = max(get_num_threads() - 1, 1) return ray_n_cpus, numba_n_threads, low_memory From 98cf5ee13ce3c82839dec570a1460ad176af07f1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 16:17:29 -0700 Subject: [PATCH 03/26] also update the docs --- AFQ/tasks/structural.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 71146d27..097d79c8 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -31,8 +31,7 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F Default: None numba_n_threads : int, optional The number of threads to use for Numba and DIPY tracking. - If None, uses the number of available CPUs minus one, - but with a maximum of 32. + If None, uses the number of available CPUs minus one. Default: None low_memory : bool, optional Whether to use low-memory versions of algorithms From 6fe2c8dcfa28b2a01c765793740d1d8698339fb4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 16:31:54 -0700 Subject: [PATCH 04/26] return to 1e7 default n seeds, better logger info --- AFQ/tractography/tractography.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 8e261c54..74e281c7 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -32,7 +32,7 @@ def track( seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=2e7, + n_seeds=1e7, random_seeds=True, rng_seed=None, step_size=0.5, @@ -273,7 +273,7 @@ def track( total=len(seeds) * 0.7, desc="Tracking...", ) - logger.info(f"Tracking took {time() - start_time:.2f} seconds.") + logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) if trx: return LazyTractogram(lambda: tracker, affine_to_rasmm=params_img.affine) From 623b5d5c1f1eddb30e80c09bbecb614170feb66a Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 17:02:09 -0700 Subject: [PATCH 05/26] copilot comments --- AFQ/tractography/tractography.py | 2 ++ AFQ/tractography/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 74e281c7..da46f1cf 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -208,6 +208,8 @@ def track( pve_csf_data[edge] = 0.0 # Here we adjust the stopping criterion to be slightly more permissive + # DIPY stops ACT at 0.5, so this will cause streamlines to continue + # further into the WM-GM interface pve_gm_data = pve_gm_data.astype(float) * 0.51 pve_csf_data = pve_csf_data.astype(float) * 0.51 diff --git a/AFQ/tractography/utils.py b/AFQ/tractography/utils.py index bae5e1d2..50a97745 100644 --- a/AFQ/tractography/utils.py +++ b/AFQ/tractography/utils.py @@ -62,7 +62,7 @@ def gen_seeds( seeds = dtu.seeds_from_mask(seed_mask, density=n_seeds, affine=affine) else: # If user provided an array, we'll use n_seeds as the seeds: - seeds = n_seeds + seeds = np.asarray(n_seeds) logger.info(f"Generated {len(seeds)} seeds in {time() - start_time:.2f} seconds.") return seeds From d39b615f31b23fa8bc337c19f776e3ad8ad078ab Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 17:14:55 -0700 Subject: [PATCH 06/26] fix rng seed bug --- AFQ/tests/test_tractography.py | 4 ++-- AFQ/tractography/tractography.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/AFQ/tests/test_tractography.py b/AFQ/tests/test_tractography.py index f3631f85..06748ce5 100644 --- a/AFQ/tests/test_tractography.py +++ b/AFQ/tests/test_tractography.py @@ -51,7 +51,7 @@ def test_csd_local_tracking(): fname, fpve, 1, - directions, + directions=directions, odf_model="CSD", max_angle=30.0, sphere="repulsion724", @@ -72,7 +72,7 @@ def test_dti_local_tracking(): fdict["params"], fpve, 1, - directions, + directions=directions, max_angle=30.0, sphere="repulsion724", seed_mask=None, diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index da46f1cf..52265ca7 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -227,7 +227,7 @@ def track( else: odf = None - if directions == "det": # /todo check if works with nonsymmetric + if directions == "det": tracker = deterministic_tracking elif directions == "prob": tracker = probabilistic_tracking @@ -248,6 +248,9 @@ def track( else: tracking_kwargs["sh"] = model_params + if rng_seed is not None: + tracking_kwargs["random_seed"] = int(rng_seed) + logger.info(f"Tracking with {len(seeds)} seeds...") if len(seeds.shape) == 1: @@ -268,8 +271,7 @@ def track( min_len=minlen, max_len=maxlen, return_all=False, - random_seed=rng_seed, - nbr_threads=n_threads, + nbr_threads=int(n_threads), **tracking_kwargs, ), total=len(seeds) * 0.7, From 5bf1e1eb79123a80983f8faba34589f63a0ddb5c Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 24 Apr 2026 17:26:15 -0700 Subject: [PATCH 07/26] better log message --- AFQ/tractography/tractography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 52265ca7..d2b859a9 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -275,7 +275,7 @@ def track( **tracking_kwargs, ), total=len(seeds) * 0.7, - desc="Tracking...", + desc="Tracking, note that the total is only an estimate...", ) logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) From 86d468131b89756910b06679decc5553390197d0 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 25 Apr 2026 17:54:04 -0700 Subject: [PATCH 08/26] upgrade dipy, refine new tracking --- AFQ/models/wmgm_interface.py | 2 +- AFQ/recognition/cleaning.py | 2 +- AFQ/tests/test_tractography.py | 1 + AFQ/tractography/tractography.py | 41 ++++++++----------- .../plot_001_group_afq_api.py | 4 +- setup.cfg | 2 +- 6 files changed, 23 insertions(+), 29 deletions(-) diff --git a/AFQ/models/wmgm_interface.py b/AFQ/models/wmgm_interface.py index c7af160c..6072809d 100644 --- a/AFQ/models/wmgm_interface.py +++ b/AFQ/models/wmgm_interface.py @@ -43,7 +43,7 @@ def fit_wm_gm_interface(PVE_img, dwiref_img): static_affine=dwiref_img.affine, ).get_fdata() - wm_boundary = find_boundaries(wm > 0.5, mode="inner") + wm_boundary = find_boundaries(wm > 0.9, mode="inner") gm_smoothed = gaussian_filter(gm, 1) csf_smoothed = gaussian_filter(csf, 1) diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 3d37b52c..67f05fbd 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -86,7 +86,7 @@ def clean_by_orientation_mahalanobis( length_threshold = np.inf fgarray = abu.resample_tg(streamlines, n_points) - assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, 100)) + _, assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, n_points)) assignment_idxs = assignment_idxs.reshape((len(fgarray), n_points)) fgarray = np.asarray(fgarray) diff --git a/AFQ/tests/test_tractography.py b/AFQ/tests/test_tractography.py index 06748ce5..4a5c38ca 100644 --- a/AFQ/tests/test_tractography.py +++ b/AFQ/tests/test_tractography.py @@ -77,6 +77,7 @@ def test_dti_local_tracking(): sphere="repulsion724", seed_mask=None, n_seeds=1, + random_seeds=False, step_size=step_size, minlen=minlen, odf_model="DTI", diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index d2b859a9..6d35612d 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -37,7 +37,7 @@ def track( rng_seed=None, step_size=0.5, minlen=20, - maxlen=250, + maxlen=500, odf_model="CSD_AODF", basis_type="descoteaux07", legacy=True, @@ -207,11 +207,10 @@ def track( pve_wm_data[edge] = 0.0 pve_csf_data[edge] = 0.0 - # Here we adjust the stopping criterion to be slightly more permissive - # DIPY stops ACT at 0.5, so this will cause streamlines to continue - # further into the WM-GM interface - pve_gm_data = pve_gm_data.astype(float) * 0.51 - pve_csf_data = pve_csf_data.astype(float) * 0.51 + # We relax ACT stopping criterion here to allow streamlines closer + # to the WM/GM boundary. + pve_gm_data *= 0.8 + pve_csf_data *= 0.8 stopping_criterion = ActStoppingCriterion.from_pve( pve_wm_data, pve_gm_data, pve_csf_data @@ -220,12 +219,18 @@ def track( if odf_model == "DTI" or odf_model == "DKI": evals, evecs = decompose_tensor(from_lower_triangular(model_params)) odf = tensor_odf(evals, evecs, sphere) - elif (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): + model_params = shm.sf_to_sh( + odf, sphere, basis_type=basis_type, legacy=legacy, full_basis=True + ) + + tracking_kwargs = {} + if (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): sh_order = shm.order_from_ncoef(model_params.shape[3], full_basis=True) - odf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) - odf[odf < 0] = 0 + pmf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) + pmf[pmf < 0] = 0 + tracking_kwargs["sf"] = pmf else: - odf = None + tracking_kwargs["sh"] = model_params if directions == "det": tracker = deterministic_tracking @@ -235,18 +240,6 @@ def track( tracker = pft_tracking else: raise ValueError(f"Unrecognized direction '{directions}'.") - tracking_kwargs = {} - - if ( - (odf_model == "DTI") - or (odf_model == "DKI") - or (odf_model == "GQ") - or (odf_model == "RUMBA") - or ("AODF" in odf_model) - ): - tracking_kwargs["sf"] = odf - else: - tracking_kwargs["sh"] = model_params if rng_seed is not None: tracking_kwargs["random_seed"] = int(rng_seed) @@ -274,8 +267,8 @@ def track( nbr_threads=int(n_threads), **tracking_kwargs, ), - total=len(seeds) * 0.7, - desc="Tracking, note that the total is only an estimate...", + total=len(seeds), + desc="Tracking, note that the total is an overestimate...", ) logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 3fb73028..19721fb9 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="all")[1] + clear_previous_afq="track")[1] ########################################################################## # Set tractography parameters (optional) @@ -56,7 +56,7 @@ # distributed in the white matter. We only do this to make this example faster # and consume less space; normally, we use more seeds. -tracking_params = dict(n_seeds=200000, +tracking_params = dict(n_seeds=500000, random_seeds=True, rng_seed=2025, trx=True) diff --git a/setup.cfg b/setup.cfg index 4d6baa6c..9fef175b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ python_requires = >=3.11, <3.14 install_requires = # core packages scikit_image>=0.14.2 - dipy>=1.11.0,<1.12.0 + dipy>=1.12.0,<1.13.0 scikit-learn pandas>=2.2.3 pybids>=0.16.2 From cc8befda25c5af12a1d0d1824bc5c3de57f160af Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 25 Apr 2026 17:55:23 -0700 Subject: [PATCH 09/26] put this back --- examples/tutorial_examples/plot_001_group_afq_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 19721fb9..ce2d1ee9 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="track")[1] + clear_previous_afq="all")[1] ########################################################################## # Set tractography parameters (optional) From d2478e3c3d3317095637229c891720e2257f08e6 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 25 Apr 2026 18:33:46 -0700 Subject: [PATCH 10/26] BFs, test fixes --- AFQ/tests/test_csd.py | 2 +- AFQ/tests/test_dki.py | 2 +- AFQ/tractography/tractography.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/AFQ/tests/test_csd.py b/AFQ/tests/test_csd.py index f2648857..a07f65b6 100644 --- a/AFQ/tests/test_csd.py +++ b/AFQ/tests/test_csd.py @@ -20,7 +20,7 @@ def test_fit_csd(): np.savetxt(op.join(tmpdir, "bvecs.txt"), bvecs) for sh_order_max in [4, 6]: fname = csd.fit_csd( - fdata, + str(fdata), op.join(tmpdir, "bvals.txt"), op.join(tmpdir, "bvecs.txt"), out_dir=tmpdir, diff --git a/AFQ/tests/test_dki.py b/AFQ/tests/test_dki.py index b188cda2..f40bcec1 100644 --- a/AFQ/tests/test_dki.py +++ b/AFQ/tests/test_dki.py @@ -21,7 +21,7 @@ def test_fit_dki_inputs(): def test_fit_dki(): fdata, fbval, fbvec = dpd.get_fnames("small_101D") with nbtmp.InTemporaryDirectory() as tmpdir: - file_dict = dki.fit_dki(fdata, fbval, fbvec, out_dir=tmpdir) + file_dict = dki.fit_dki(str(fdata), str(fbval), str(fbvec), out_dir=tmpdir) for f in file_dict.values(): op.exists(f) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 6d35612d..9cd40f21 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -224,7 +224,9 @@ def track( ) tracking_kwargs = {} - if (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): + if directions == "pft" and (odf_model == "DTI" or odf_model == "DKI"): + tracking_kwargs["sf"] = odf + elif (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): sh_order = shm.order_from_ncoef(model_params.shape[3], full_basis=True) pmf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) pmf[pmf < 0] = 0 From bfee52ca4204924a3fa8f8151132d02c38098714 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 25 Apr 2026 20:55:02 -0700 Subject: [PATCH 11/26] update test --- AFQ/tests/test_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index df510534..169d61c0 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -385,13 +385,14 @@ def test_AFQ_seed_array(): LabelledImageFile(path=seg_file, inclusive_labels=[1, 2]), ) - seed_mask = nib.load(seg_file).get_fdata() == 1 + seg_img = nib.load(seg_file) + seed_mask = seg_img.get_fdata() == 1 seeds = dtu.random_seeds_from_mask( seed_mask, seeds_count=20, seed_count_per_voxel=False, - affine=np.eye(4), + affine=seg_img.affine, random_seed=20, ) From c4057ba5401d4b94d599aa5fbe3da92e0a0b97f4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 27 Apr 2026 10:55:06 -0700 Subject: [PATCH 12/26] remove 0.8*csf --- AFQ/tractography/tractography.py | 1 - 1 file changed, 1 deletion(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 9cd40f21..ceb200c1 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -210,7 +210,6 @@ def track( # We relax ACT stopping criterion here to allow streamlines closer # to the WM/GM boundary. pve_gm_data *= 0.8 - pve_csf_data *= 0.8 stopping_criterion = ActStoppingCriterion.from_pve( pve_wm_data, pve_gm_data, pve_csf_data From 5053ede811ef2e8b13a43c06f115868d5ab3007b Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 27 Apr 2026 11:59:17 -0700 Subject: [PATCH 13/26] allow test to set random seed --- AFQ/recognition/recognize.py | 4 +++- AFQ/recognition/tests/test_recognition.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index a080a881..6136948f 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -155,7 +155,9 @@ def recognize( if isinstance(tg, StatefulTractogram): if nb_streamlines and len(tg) > nb_streamlines: tg = StatefulTractogram( - select_random_set_of_streamlines(tg.streamlines, nb_streamlines), + select_random_set_of_streamlines( + tg.streamlines, nb_streamlines, rng=rng + ), tg, tg.space, ) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 5b24d470..3e8d74be 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -244,6 +244,7 @@ def test_segment_sampled_streamlines(): reg_template, 1, nb_streamlines=nb_streamlines, + rng=2026, ) # sampled streamlines should be subset of the original streamlines From f4ac89bc033dc1da12845fc6f6e69b92df2726c5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 27 Apr 2026 15:53:05 -0700 Subject: [PATCH 14/26] adjustable gm threshold --- AFQ/tractography/tractography.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index ceb200c1..c2c72ea6 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -31,6 +31,7 @@ def track( sphere="repulsion724", seed_mask=None, seed_threshold=0.5, + gm_threshold=0.4, thresholds_as_percentages=False, n_seeds=1e7, random_seeds=True, @@ -76,6 +77,9 @@ def track( seed_threshold : float, optional. A value of the seed_mask above which tracking is seeded. Default to 0. + gm_threshold : float, optional. + A value of the pve_gm_data above which we consider a voxel to be GM + for the purposes of ACT stopping criterion. Default: 0.4. n_seeds : int or 2D array, optional. The seeding density: if this is an int, it is is how many seeds in each voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D @@ -209,7 +213,7 @@ def track( # We relax ACT stopping criterion here to allow streamlines closer # to the WM/GM boundary. - pve_gm_data *= 0.8 + pve_gm_data *= 0.5 / gm_threshold stopping_criterion = ActStoppingCriterion.from_pve( pve_wm_data, pve_gm_data, pve_csf_data From 60fc5e8a30ca3b0e05035fef303f901c20e138b2 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 27 Apr 2026 15:59:14 -0700 Subject: [PATCH 15/26] update left OR exclusion ROI to be symmetric --- AFQ/data/fetch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 735d4625..a5ad6bbf 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1520,7 +1520,7 @@ def read_slf_templates(as_img=True, resample_to=False): "26831642", "26831645", "26831648", - "26831651", + "63998848", "26831654", "26831657", "26831660", @@ -1537,7 +1537,7 @@ def read_slf_templates(as_img=True, resample_to=False): "9cff03af586d9dd880750cef3e0bf63f", "ff728ba3ffa5d1600bcd19fdef8182c4", "4f1978e418a3169609375c28b3eba0fd", - "fd163893081b520f4594171aeea04f39", + "ebdfe9d26fc4d7b018a26d7e38895055", "bf795d197912b5e074d248d2763c6930", "13efde1efe0de52683cbf352ecba457e", "75c7bd2092950578e599a2dcb218909f", From 8cb62088f502639c52dbfcdb42e7d9e40a74cc92 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 28 Apr 2026 12:31:26 -0700 Subject: [PATCH 16/26] update to new GPUStreamlines AODF support --- AFQ/tasks/tractography.py | 1 + AFQ/tractography/gputractography.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 7d844134..bac163c0 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -322,6 +322,7 @@ def gpu_tractography( sphere, tracking_params["directions"], tracking_params["seed_threshold"], + tracking_params["gm_threshold"], tracking_params["thresholds_as_percentages"], tracking_params["max_angle"], tracking_params["step_size"], diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index 6d68c77e..d29f4036 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -21,6 +21,7 @@ def gpu_track( sphere, directions, seed_threshold, + stop_threshold, thresholds_as_percentages, max_angle, step_size, @@ -54,6 +55,8 @@ def gpu_track( The discretization of the ODF. seed_threshold : float The value of the seed_path above which tracking is seeded. + stop_threshold : float + A value of the WM data below which we stop tracking. thresholds_as_percentages : bool Interpret seed_threshold as percentages of the total non-nan voxels in the seed mask to include @@ -148,9 +151,6 @@ def gpu_track( vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) step_size = step_size / vox_dim - # Roughly handle ACT/CMC for now - wm_threshold = 0.5 - pve_img = nib.load(pve_path) wm_img = resample( @@ -170,14 +170,15 @@ def gpu_track( dg = BootDirectionGetter.from_dipy_csa(gtab, sphere) else: raise ValueError(f"odf_model must be 'opdt' or 'csa', not {odf_model}") + full_basis = False else: # Convert SH coefficients to ODFs sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * data.shape[3])) / 2.0 if sym_order.is_integer(): sh_order_max = sym_order full_basis = False - full_order = np.sqrt(data.shape[3]) - 1.0 - if full_order.is_integer(): + else: + full_order = np.sqrt(data.shape[3]) - 1.0 sh_order_max = full_order full_basis = True @@ -194,7 +195,7 @@ def gpu_track( if directions == "ptt": # Set FOD to 0 outside mask for probing - data[wm_data < wm_threshold, :] = 0 + data[wm_data < stop_threshold, :] = 0 dg = PttDirectionGetter() elif directions == "prob": dg = ProbDirectionGetter() @@ -220,9 +221,10 @@ def gpu_track( dg, data, wm_data, - wm_threshold, + stop_threshold, sphere.vertices, sphere.edges, + full_basis=full_basis, max_angle=radians(max_angle), step_size=step_size, min_pts=minlen, From 3c04d403a1ff82f56b01addd046b83803bfd4173 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 4 May 2026 20:43:37 -0700 Subject: [PATCH 17/26] update to latest gpustreamlines --- AFQ/tractography/gputractography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index d29f4036..5162d47d 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -224,7 +224,7 @@ def gpu_track( stop_threshold, sphere.vertices, sphere.edges, - full_basis=full_basis, + sphere_symm=not full_basis, max_angle=radians(max_angle), step_size=step_size, min_pts=minlen, From e32d0605bed2f5ffd1a4af10366e1a771982c5e4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 6 May 2026 12:41:19 -0700 Subject: [PATCH 18/26] GPUTracker->Tracker --- AFQ/tractography/gputractography.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index 5162d47d..021e4c48 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -98,23 +98,23 @@ def gpu_track( if gpu_backend == "auto": from cuslines import ( BootDirectionGetter, - GPUTracker, ProbDirectionGetter, PttDirectionGetter, + Tracker, ) elif gpu_backend == "cuda": from cuslines.cuda_python import ( BootDirectionGetter, - GPUTracker, ProbDirectionGetter, PttDirectionGetter, + Tracker, ) elif gpu_backend == "metal": from cuslines.metal import ( MetalBootDirectionGetter as BootDirectionGetter, ) from cuslines.metal import ( - MetalGPUTracker as GPUTracker, + MetalGPUTracker as Tracker, ) from cuslines.metal import ( MetalProbDirectionGetter as ProbDirectionGetter, @@ -133,7 +133,7 @@ def gpu_track( WebGPUPttDirectionGetter as PttDirectionGetter, ) from cuslines.webgpu import ( - WebGPUTracker as GPUTracker, + WebGPUTracker as Tracker, ) else: raise ValueError( @@ -217,7 +217,7 @@ def gpu_track( if rng_seed is None: rng_seed = np.random.randint(0, 2**31 - 1) - with GPUTracker( + with Tracker( dg, data, wm_data, From 98fc23c80422d05fbc88e9479adf89c96fe2a20b Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 7 May 2026 14:46:32 -0700 Subject: [PATCH 19/26] use new numba tracking, merge gpu and nongpu functions, remove ray --- AFQ/api/group.py | 8 - AFQ/models/msmt.py | 91 ++----- AFQ/recognition/criteria.py | 43 +--- AFQ/recognition/recognize.py | 4 - AFQ/tasks/segmentation.py | 1 - AFQ/tasks/structural.py | 16 +- AFQ/tasks/tissue.py | 2 +- AFQ/tasks/tractography.py | 239 ++---------------- AFQ/tests/test_api.py | 4 +- AFQ/tractography/gputractography.py | 239 ------------------ AFQ/tractography/tractography.py | 203 +++++++++++---- AFQ/tractography/utils.py | 4 + AFQ/viz/plot.py | 1 - docs/source/reference/kwargs.rst | 3 - docs/source/reference/methods.rst | 5 - examples/howto_examples/optic_tract.py | 1 - .../plot_001_group_afq_api.py | 15 +- .../plot_002_participant_afq_api.py | 8 +- examples/tutorial_examples/plot_004_export.py | 1 - setup.cfg | 3 +- 20 files changed, 213 insertions(+), 678 deletions(-) delete mode 100644 AFQ/tractography/gputractography.py diff --git a/AFQ/api/group.py b/AFQ/api/group.py index c48c0d70..559c9a80 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -283,14 +283,6 @@ def __init__( if len(self.sessions) * len(self.subjects) < 2: self.parallel_params["engine"] = "serial" - # do not parallelize within subject if parallelizing across - # subject-sessions - if self.parallel_params["engine"] != "serial": - if "ray_n_cpus" not in kwargs: - kwargs["ray_n_cpus"] = 1 - if "numba_n_threads" not in kwargs: - kwargs["numba_n_threads"] = 1 - self.valid_sub_list = [] self.valid_ses_list = [] self.pAFQ_list = [] diff --git a/AFQ/models/msmt.py b/AFQ/models/msmt.py index b7ffbad6..60cfdc15 100644 --- a/AFQ/models/msmt.py +++ b/AFQ/models/msmt.py @@ -1,24 +1,16 @@ -import multiprocessing - import numpy as np import osqp -import ray from dipy.reconst.mcsd import MSDeconvFit, MultiShellDeconvModel from scipy.sparse import csr_matrix from tqdm import tqdm -from AFQ.utils.stats import chunk_indices - __all__ = ["MultiShellDeconvModel"] -def _fit(self, data, mask=None, n_cpus=None): +def _fit(self, data, mask=None): """ Use OSQP to fit the multi-shell spherical deconvolution model. """ - if n_cpus is None: - n_cpus = max(multiprocessing.cpu_count() - 1, 1) - og_data_shape = data.shape if len(data.shape) < 4: data = data.reshape((1,) * (4 - data.ndim) + data.shape) @@ -47,69 +39,24 @@ def _fit(self, data, mask=None, n_cpus=None): A = csr_matrix(A) Q = csr_matrix(Q) - if n_cpus > 1: - ray.init(ignore_reinit_error=True) - - data_id = ray.put(data) - mask_id = ray.put(mask) - Q_id = ray.put(Q) - A_id = ray.put(A) - b_id = ray.put(b) - R_id = ray.put(R) - - @ray.remote(num_cpus=n_cpus) - def process_batch_remote(batch_indices, data, mask, Q, A, b, R): - import numpy as np - import osqp - - m = osqp.OSQP() - m.setup(P=Q, A=A, l=b, u=None, q=None, verbose=False) - return_values = np.zeros( - (len(batch_indices),) + data.shape[1:3] + (A.shape[1],), - dtype=np.float64, - ) - for i, ii in enumerate(batch_indices): - for jj in range(data.shape[1]): - for kk in range(data.shape[2]): - if mask[ii, jj, kk]: - c = np.dot(-R.T, data[ii, jj, kk]) - m.update(q=c) - results = m.solve() - return_values[i, jj, kk] = results.x - return return_values - - # Launch tasks in chunks - all_indices = list(range(data.shape[0])) - indices_chunked = list(chunk_indices(all_indices, n_cpus * 2)) - futures = [ - process_batch_remote.remote(batch, data_id, mask_id, Q_id, A_id, b_id, R_id) - for batch in indices_chunked - ] - - # Collect and assign results - for batch, future in zip(indices_chunked, tqdm(futures)): - results = ray.get(future) - for i, ii in enumerate(batch): - coeff[ii] = results[i] - else: - m = osqp.OSQP() - m.setup( - P=Q, - A=A, - l=b, - u=None, - q=None, - verbose=False, - adaptive_rho=True, - ) - for ii in tqdm(range(data.shape[0])): - for jj in range(data.shape[1]): - for kk in range(data.shape[2]): - if mask[ii, jj, kk]: - c = np.dot(-R.T, data[ii, jj, kk]) - m.update(q=c) - results = m.solve() - coeff[ii, jj, kk] = results.x + m = osqp.OSQP() + m.setup( + P=Q, + A=A, + l=b, + u=None, + q=None, + verbose=False, + adaptive_rho=True, + ) + for ii in tqdm(range(data.shape[0])): + for jj in range(data.shape[1]): + for kk in range(data.shape[2]): + if mask[ii, jj, kk]: + c = np.dot(-R.T, data[ii, jj, kk]) + m.update(q=c) + results = m.solve() + coeff[ii, jj, kk] = results.x coeff = coeff.reshape(og_data_shape[:-1] + (n,)) return MSDeconvFit(self, coeff, None) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 709c38a4..c7326d99 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -4,7 +4,6 @@ import dipy.tracking.streamline as dts import nibabel as nib import numpy as np -import ray from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import load_tractogram from dipy.segment.bundles import RecoBundles @@ -20,7 +19,6 @@ import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict from AFQ.recognition.clustering import subcluster_by_atlas -from AFQ.utils.stats import chunk_indices from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ @@ -178,7 +176,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): +def include(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet @@ -191,39 +189,9 @@ def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): else: include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) - # For now I am turning ray parallelization here off. - # It is never worthwhile considering other changes we - # have made to speed up this step, - # so spinning up ray and transferring data back - # and forth is not worth it. - # In the future, I think we should redo this with numba and - # use multithreading - n_cpus = 1 - - # with parallel segmentation, the first for loop will - # only collect streamlines and does not need tqdm - if n_cpus > 1: - inc_results = np.zeros(len(b_sls), dtype=tuple) - - inc_rois_id = ray.put(bundle_def["include"]) - inc_roi_tols_id = ray.put(include_roi_tols) - - _check_inc_parallel = ray.remote(num_cpus=n_cpus)(abr.check_sls_with_inclusion) - - sls_chunks = list(chunk_indices(np.arange(len(b_sls)), n_cpus)) - futures = [ - _check_inc_parallel.remote( - b_sls.get_selected_sls()[sls_chunk], inc_rois_id, inc_roi_tols_id - ) - for sls_chunk in sls_chunks - ] - - for ii, future in enumerate(futures): - inc_results[sls_chunks[ii]] = ray.get(future) - else: - inc_results = abr.check_sls_with_inclusion( - b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols - ) + inc_results = abr.check_sls_with_inclusion( + b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols + ) n_inc = len(bundle_def["include"]) roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32) @@ -405,13 +373,12 @@ def orient_mahal(b_sls, bundle_def, **kwargs): b_sls.select(accept_idx, "orient_mahal") -def isolation_forest(b_sls, bundle_def, n_cpus, rng, **kwargs): +def isolation_forest(b_sls, bundle_def, rng, **kwargs): b_sls.initiate_selection("isolation_forest") accept_idx = abc.clean_by_isolation_forest( b_sls.get_selected_sls(), distance_threshold=bundle_def["isolation_forest"].get("distance_threshold", 3), n_rounds=bundle_def["isolation_forest"].get("n_rounds", 5), - n_jobs=n_cpus, random_state=rng, ) b_sls.select(accept_idx, "isolation_forest") diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 6136948f..d6575dc9 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -23,7 +23,6 @@ def recognize( mapping, bundle_dict, reg_template, - n_cpus, nb_points=False, nb_streamlines=False, clip_edges=False, @@ -53,8 +52,6 @@ def recognize( Dictionary of bundles to segment. reg_template : str, nib.Nifti1Image Template image for registration. - n_cpus : int - Number of CPUs to use for parallelization. nb_points : int, boolean Resample streamlines to nb_points number of points. If False, no resampling is done. Can only be done @@ -190,7 +187,6 @@ def recognize( bundle_name, recognized_bundles_dict, clip_edges=clip_edges, - n_cpus=n_cpus, rb_recognize_params=rb_recognize_params, prob_threshold=prob_threshold, refine_reco=refine_reco, diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index a1e8e412..d40db1d7 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -94,7 +94,6 @@ def segment( mapping_imap["mapping"], bundle_dict, reg_template, - structural_imap["n_cpus"], **segmentation_params, ) diff --git a/AFQ/tasks/structural.py b/AFQ/tasks/structural.py index 097d79c8..1b1c81b0 100644 --- a/AFQ/tasks/structural.py +++ b/AFQ/tasks/structural.py @@ -14,21 +14,15 @@ logger = logging.getLogger("AFQ") -@immlib.calc("n_cpus", "n_threads", "low_mem") -def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=False): +@immlib.calc("n_threads", "low_mem") +def configure_ncpus_nthreads(numba_n_threads=None, low_memory=False): """ - Configure the number of CPUs to use for parallel processing with Ray, - the number of threads to use for Numba, + Configure the number of threads to use for Numba, and whether to use low-memory versions of algorithms where available Parameters ---------- - ray_n_cpus : int, optional - The number of CPUs to use for parallel processing with Ray. - If None, uses the number of available CPUs minus one. - Tractography and MSMT use Ray. - Default: None numba_n_threads : int, optional The number of threads to use for Numba and DIPY tracking. If None, uses the number of available CPUs minus one. @@ -38,12 +32,10 @@ def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None, low_memory=F where available. Default: False """ - if ray_n_cpus is None: - ray_n_cpus = 1 if numba_n_threads is None: numba_n_threads = max(get_num_threads() - 1, 1) - return ray_n_cpus, numba_n_threads, low_memory + return numba_n_threads, low_memory @immlib.calc("onnx_kwargs") diff --git a/AFQ/tasks/tissue.py b/AFQ/tasks/tissue.py index d8f96083..61ebb06e 100644 --- a/AFQ/tasks/tissue.py +++ b/AFQ/tasks/tissue.py @@ -208,7 +208,7 @@ def msmt_params( mcsd_model = MultiShellDeconvModel(data_imap["gtab"], response_mcsd) logger.info("Fitting Multi-Shell CSD model...") - mcsd_fit = mcsd_model.fit(data_imap["data"], mask, n_cpus=structural_imap["n_cpus"]) + mcsd_fit = mcsd_model.fit(data_imap["data"], mask) def _get_meta(desc, sh_order, response): return dict( diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index bac163c0..e5ea36b4 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -1,45 +1,27 @@ import logging from time import time -import dipy.data as dpd import immlib import nibabel as nib import numpy as np from trx.trx_file_memmap import TrxFile -from trx.trx_file_memmap import concatenate as trx_concatenate import AFQ.tractography.tractography as aft from AFQ.definitions.image import ScalarImage from AFQ.definitions.utils import Definition from AFQ.tasks.decorators import as_file from AFQ.tasks.utils import get_default_args, with_name -from AFQ.tractography.utils import gen_seeds - -try: - import ray - - has_ray = True -except ModuleNotFoundError: - has_ray = False - -try: - from AFQ.tractography.gputractography import gpu_track - - has_gputrack = True -except ModuleNotFoundError: - has_gputrack = False logger = logging.getLogger("AFQ") -def _meta_from_tracking_params(tracking_params, start_time, n_streamlines, seed, pve): +def _meta_from_tracking_params(tracking_params, start_time, seed, pve, n_streamlines=0): meta_directions = {"det": "deterministic", "prob": "probabilistic"} meta = dict( TractographyClass="local", TractographyMethod=meta_directions.get( tracking_params["directions"], tracking_params["directions"] ), - Count=n_streamlines, Seeding=dict( ROI=seed, n_seeds=tracking_params["n_seeds"], @@ -55,6 +37,8 @@ def _meta_from_tracking_params(tracking_params, start_time, n_streamlines, seed, ), Timing=time() - start_time, ) + if n_streamlines != 0: + meta["Count"] = n_streamlines return meta @@ -102,119 +86,23 @@ def streamlines( is_trx = this_tracking_params.get("trx", False) - num_chunks = structural_imap["n_cpus"] - if is_trx: start_time = time() dtype_dict = {"positions": np.float32, "offsets": np.uint32} - if num_chunks and num_chunks > 1: - if not has_ray: - raise ImportError( - "Ray is required to perform tractography in" - "parallel, please install ray or remove the" - " 'num_chunks' arg" - ) - - this_tracking_params["pve"] = tissue_imap["pve_internal"] - this_tracking_params["n_threads"] = structural_imap["n_threads"] - - @ray.remote - class TractActor: - def __init__(self): - self.TrxFile = TrxFile - self.aft = aft - self.objects = {} - - def trx_from_lazy_tractogram(self, lazyt_id, seed, dtype_dict): - id = self.objects[lazyt_id] - return self.TrxFile.from_lazy_tractogram( - id, seed, dtype_dict=dtype_dict - ) - - def create_lazyt(self, id, *args, **kwargs): - self.objects[id] = self.aft.track(*args, **kwargs) - return id - - def delete_lazyt(self, id): - if id in self.objects: - del self.objects[id] - - actors = [TractActor.remote() for _ in range(num_chunks)] - object_id = 1 - tracking_params_list = [] - - # random seeds case - if isinstance( - this_tracking_params.get("n_seeds"), int - ) and this_tracking_params.get("random_seeds"): - remainder = this_tracking_params["n_seeds"] % num_chunks - for i in range(num_chunks): - # create copy of tracking params - copy = this_tracking_params.copy() - n_seeds = this_tracking_params["n_seeds"] - copy["n_seeds"] = n_seeds // num_chunks - # add remainder to 1st list - if i == 0: - copy["n_seeds"] += remainder - tracking_params_list.append(copy) - - elif isinstance(this_tracking_params["n_seeds"], (np.ndarray, list)): - n_seeds = np.array(this_tracking_params["n_seeds"]) - seed_chunks = np.array_split(n_seeds, num_chunks) - tracking_params_list = [ - this_tracking_params.copy() for _ in range(num_chunks) - ] - - for i in range(num_chunks): - tracking_params_list[i]["n_seeds"] = seed_chunks[i] - - else: - seeds = gen_seeds( - this_tracking_params["seed_mask"], - this_tracking_params["seed_threshold"], - this_tracking_params["n_seeds"], - this_tracking_params["thresholds_as_percentages"], - this_tracking_params["random_seeds"], - this_tracking_params["rng_seed"], - data_imap["dwi_affine"], - ) - seed_chunks = np.array_split(seeds, num_chunks) - tracking_params_list = [ - this_tracking_params.copy() for _ in range(num_chunks) - ] - for i in range(num_chunks): - tracking_params_list[i]["n_seeds"] = seed_chunks[i] - - # create lazyt inside each actor - tasks = [ - ray_actor.create_lazyt.remote( - object_id, fodf, **tracking_params_list[i] - ) - for i, ray_actor in enumerate(actors) - ] - ray.get(tasks) - - # create trx from lazyt - tasks = [ - ray_actor.trx_from_lazy_tractogram.remote( - object_id, seed, dtype_dict=dtype_dict - ) - for ray_actor in actors - ] - sfts = ray.get(tasks) - - # cleanup objects - tasks = [ray_actor.delete_lazyt.remote(object_id) for ray_actor in actors] - ray.get(tasks) - - sft = trx_concatenate(sfts) + + lazyt = aft.track( + fodf, + tissue_imap["pve_internal"], + structural_imap["n_threads"], + **this_tracking_params, + ) + + if this_tracking_params["directions"] == "prob": + # We do not count these as we go yet, + # this needs to be implemented in GPUStreamlines + n_streamlines = 0 + sft = lazyt else: - lazyt = aft.track( - fodf, - tissue_imap["pve_internal"], - structural_imap["n_threads"], - **this_tracking_params, - ) # Chunk size is number of streamlines tracked before saving to disk. sft = TrxFile.from_lazy_tractogram( lazyt, @@ -223,7 +111,7 @@ def delete_lazyt(self, id): chunk_size=1e5, extra_buffer=int(1e6), ) - n_streamlines = len(sft) + n_streamlines = len(sft) else: start_time = time() @@ -242,7 +130,11 @@ def delete_lazyt(self, id): ) return sft, _meta_from_tracking_params( - tracking_params, start_time, n_streamlines, seed, tissue_imap["pve_internal"] + tracking_params, + start_time, + seed, + tissue_imap["pve_internal"], + n_streamlines, ) @@ -264,99 +156,12 @@ def custom_tractography(import_tract=None): return import_tract -@immlib.calc("streamlines") -@as_file("_tractography", subfolder="tractography") -def gpu_tractography( - data_imap, - tracking_params, - seed, - tissue_imap, - tractography_ngpus=0, - gpu_backend="auto", - chunk_size=25000, -): - """ - full path to the complete, unsegmented tractography file - - Parameters - ---------- - tractography_ngpus : int, optional - Number of GPUs to use in tractography. If non-0, - this algorithm is used for tractography, - https://github.com/dipy/GPUStreamlines - PTT, Prob can be used with any SHM model. - Bootstrapped can be done with CSA/OPDT. - Default: 0 - gpu_backend : str, optional - GPU backend to use for tractography. - One of {"auto", "cuda", "metal", "webgpu"}. - Default: "auto" - chunk_size : int, optional - Chunk size for GPU tracking. - Default: 25000 - """ - start_time = time() - fodf = _fiber_odf(data_imap, tissue_imap, tracking_params) - - if tracking_params["directions"] == "boot": - data = data_imap["data"] - else: - if isinstance(fodf, str): - fodf = nib.load(fodf) - data = fodf.get_fdata() - - pve = tissue_imap["pve_internal"] - - sphere = tracking_params["sphere"] - if sphere is None: - sphere = dpd.get_sphere(name="repulsion724") - else: - sphere = dpd.get_sphere(name=tracking_params["sphere"]) - - sft = gpu_track( - data, - data_imap["gtab"], - seed, - pve, - tracking_params["odf_model"], - sphere, - tracking_params["directions"], - tracking_params["seed_threshold"], - tracking_params["gm_threshold"], - tracking_params["thresholds_as_percentages"], - tracking_params["max_angle"], - tracking_params["step_size"], - tracking_params["minlen"], - tracking_params["maxlen"], - tracking_params["n_seeds"], - tracking_params["random_seeds"], - tracking_params["rng_seed"], - tracking_params["trx"], - tractography_ngpus, - chunk_size, - gpu_backend, - ) - - return sft, _meta_from_tracking_params(tracking_params, start_time, sft, seed, pve) - - def get_tractography_plan(kwargs): if "tracking_params" in kwargs and not isinstance(kwargs["tracking_params"], dict): raise TypeError("tracking_params a dict") tractography_tasks = with_name([streamlines]) - # use GPU accelerated tractography if asked for - if "tractography_ngpus" in kwargs and kwargs["tractography_ngpus"] != 0: - if not has_gputrack: - raise ImportError( - "Please install from ghcr.io/nrdg/pyafq_gpu" - " docker file or from " - "https://github.com/dipy/GPUStreamlines" - " to use gpu-accelerated" - " tractography" - ) - tractography_tasks["streamlines_res"] = gpu_tractography # use imported tractography if given if "import_tract" in kwargs and kwargs["import_tract"] is not None: tractography_tasks["streamlines_res"] = custom_tractography diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 169d61c0..c9fe6123 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -576,7 +576,6 @@ def test_AFQ_slr(): "full_segmented_cleaned_tractography.trk", ), segmentation_params={"dist_to_waypoint": 10}, - n_cpus=1, bundle_info=bd, mapping_definition=SlrMap(slr_kwargs={"rng": np.random.RandomState(seed)}), ) @@ -806,7 +805,6 @@ def test_AFQ_data_waypoint(): pve=pve, brain_mask_definition=bm_def, n_points_profile=50, - ray_n_cpus=1, tracking_params=tracking_params, segmentation_params=segmentation_params, ) @@ -970,7 +968,7 @@ def test_AFQ_data_waypoint(): dwi_preproc_pipeline="vistasoft", t1_preproc_pipeline="freesurfer", ), - DATA=dict(bundle_info=bundle_dict_as_str, ray_n_cpus=1), + DATA=dict(bundle_info=bundle_dict_as_str), TISSUE=dict(pve=pve_as_str), SEGMENTATION=dict( n_points_profile=50, diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py deleted file mode 100644 index 021e4c48..00000000 --- a/AFQ/tractography/gputractography.py +++ /dev/null @@ -1,239 +0,0 @@ -import logging -from math import radians - -import nibabel as nib -import numpy as np -from dipy.align import resample -from dipy.reconst import shm - -from AFQ.tractography.utils import gen_seeds - -logger = logging.getLogger("AFQ") - - -# Modified from https://github.com/dipy/GPUStreamlines/blob/master/run_dipy_gpu.py -def gpu_track( - data, - gtab, - seed_path, - pve_path, - odf_model, - sphere, - directions, - seed_threshold, - stop_threshold, - thresholds_as_percentages, - max_angle, - step_size, - minlen, - maxlen, - n_seeds, - random_seeds, - rng_seed, - use_trx, - ngpus, - chunk_size, - gpu_backend, -): - """ - Perform GPU tractography on DWI data. - - Parameters - ---------- - data : ndarray - DWI data. - gtab : GradientTable - The gradient table. - seed_path : str - Float or binary mask describing the ROI within which we seed for - tracking. - pve_path : str - Estimations of partial volumes of WM, GM, and CSF. - odf_model : str, optional - One of {"OPDT", "CSA"} - sphere : DIPY Sphere - The discretization of the ODF. - seed_threshold : float - The value of the seed_path above which tracking is seeded. - stop_threshold : float - A value of the WM data below which we stop tracking. - thresholds_as_percentages : bool - Interpret seed_threshold as percentages of the - total non-nan voxels in the seed mask to include - (between 0 and 100), instead of as a threshold on the - values themselves. - max_angle : float - The maximum turning angle in each step. - step_size : float - The size of a step (in mm) of tractography. - n_seeds : int - The seeding density: if this is an int, it is is how many seeds in each - voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D - array, these are the coordinates of the seeds. Unless random_seeds is - set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 1 - minlen: int, optional - The minimal length (mm) in a streamline - maxlen: int, optional - The maximum length (mm) in a streamline - random_seeds : bool - If True, n_seeds is total number of random seeds to generate. - If False, n_seeds encodes the density of seeds to generate. - rng_seed : int - random seed used to generate random seeds if random_seeds is - set to True. Default: None - use_trx : bool - Whether to use trx. - ngpus : int - Number of GPUs to use. - chunk_size : int - Chunk size for GPU tracking. - gpu_backend : str, optional - GPU backend to use for tractography. - One of {"auto", "cuda", "metal", "webgpu"}. - Returns - ------- - """ - gpu_backend = gpu_backend.lower() - if gpu_backend == "auto": - from cuslines import ( - BootDirectionGetter, - ProbDirectionGetter, - PttDirectionGetter, - Tracker, - ) - elif gpu_backend == "cuda": - from cuslines.cuda_python import ( - BootDirectionGetter, - ProbDirectionGetter, - PttDirectionGetter, - Tracker, - ) - elif gpu_backend == "metal": - from cuslines.metal import ( - MetalBootDirectionGetter as BootDirectionGetter, - ) - from cuslines.metal import ( - MetalGPUTracker as Tracker, - ) - from cuslines.metal import ( - MetalProbDirectionGetter as ProbDirectionGetter, - ) - from cuslines.metal import ( - MetalPttDirectionGetter as PttDirectionGetter, - ) - elif gpu_backend == "webgpu": - from cuslines.webgpu import ( - WebGPUBootDirectionGetter as BootDirectionGetter, - ) - from cuslines.webgpu import ( - WebGPUProbDirectionGetter as ProbDirectionGetter, - ) - from cuslines.webgpu import ( - WebGPUPttDirectionGetter as PttDirectionGetter, - ) - from cuslines.webgpu import ( - WebGPUTracker as Tracker, - ) - else: - raise ValueError( - "gpu_backend must be one of 'auto', 'cuda', " - f"'metal', or 'webgpu', not {gpu_backend}" - ) - - seed_img = nib.load(seed_path) - directions = directions.lower() - - minlen = int(minlen / step_size) - maxlen = int(maxlen / step_size) - - R = seed_img.affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - step_size = step_size / vox_dim - - pve_img = nib.load(pve_path) - - wm_img = resample( - pve_img.get_fdata()[..., 2], - seed_img.get_fdata(), - moving_affine=pve_img.affine, - static_affine=seed_img.affine, - ) - wm_data = wm_img.get_fdata() - - seed_data = seed_img.get_fdata() - - if directions == "boot": - if odf_model.lower() == "opdt": - dg = BootDirectionGetter.from_dipy_opdt(gtab, sphere) - elif odf_model.lower() == "csa": - dg = BootDirectionGetter.from_dipy_csa(gtab, sphere) - else: - raise ValueError(f"odf_model must be 'opdt' or 'csa', not {odf_model}") - full_basis = False - else: - # Convert SH coefficients to ODFs - sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * data.shape[3])) / 2.0 - if sym_order.is_integer(): - sh_order_max = sym_order - full_basis = False - else: - full_order = np.sqrt(data.shape[3]) - 1.0 - sh_order_max = full_order - full_basis = True - - theta = sphere.theta - phi = sphere.phi - - sampling_matrix, _, _ = shm.real_sh_descoteaux( - sh_order_max, theta, phi, full_basis=full_basis, legacy=False - ) - model = shm.SphHarmModel(gtab) - model.cache_set("sampling_matrix", sphere, sampling_matrix) - model_fit = shm.SphHarmFit(model, data, None) - data = model_fit.odf(sphere).clip(min=0) - - if directions == "ptt": - # Set FOD to 0 outside mask for probing - data[wm_data < stop_threshold, :] = 0 - dg = PttDirectionGetter() - elif directions == "prob": - dg = ProbDirectionGetter() - else: - raise ValueError( - f"directions must be 'boot', 'ptt', or 'prob', not {directions}" - ) - - seeds = gen_seeds( - seed_data, - seed_threshold, - n_seeds, - thresholds_as_percentages, - random_seeds, - rng_seed, - np.eye(4), - ) - - if rng_seed is None: - rng_seed = np.random.randint(0, 2**31 - 1) - - with Tracker( - dg, - data, - wm_data, - stop_threshold, - sphere.vertices, - sphere.edges, - sphere_symm=not full_basis, - max_angle=radians(max_angle), - step_size=step_size, - min_pts=minlen, - max_pts=maxlen, - ngpus=ngpus, - rng_seed=rng_seed, - chunk_size=chunk_size, - ) as gpu_tracker: - if use_trx: - return gpu_tracker.generate_trx(seeds, seed_img) - else: - return gpu_tracker.generate_sft(seeds, seed_img) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index c2c72ea6..7aa5766e 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -1,10 +1,13 @@ import logging +from math import radians from time import time import dipy.data as dpd import nibabel as nib +import numba import numpy as np from dipy.align import resample +from dipy.core.sphere import HemiSphere from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.reconst import shm from dipy.reconst.dti import decompose_tensor, from_lower_triangular @@ -12,7 +15,6 @@ from dipy.tracking.tracker import ( deterministic_tracking, pft_tracking, - probabilistic_tracking, ) from nibabel.streamlines.tractogram import LazyTractogram from skimage.segmentation import find_boundaries @@ -43,6 +45,8 @@ def track( basis_type="descoteaux07", legacy=True, trx=True, + jit_backend="numba", + jit_chunk_size=25000, ): """ Tractography @@ -120,6 +124,12 @@ def track( Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). Default: True + jit_backend : str, optional + If directions is "prob", the JIT backend to use. One of {"auto", + "cuda", "metal", "webgpu", or "numba"}. Default: "numba" + jit_chunk_size : int, optional + If directions is "prob", the chunk size to use for JIT tracking. + Default: 25000 Returns ------- @@ -159,16 +169,6 @@ def track( if seed_mask is None: seed_mask = np.ones(params_img.shape[:3]) - seeds = gen_seeds( - seed_mask, - seed_threshold, - n_seeds, - thresholds_as_percentages, - random_seeds, - rng_seed, - params_img.affine, - ) - if isinstance(sphere, str): sphere = dpd.get_sphere(name=sphere) @@ -229,55 +229,154 @@ def track( tracking_kwargs = {} if directions == "pft" and (odf_model == "DTI" or odf_model == "DKI"): tracking_kwargs["sf"] = odf - elif (odf_model == "GQ") or (odf_model == "RUMBA") or ("AODF" in odf_model): - sh_order = shm.order_from_ncoef(model_params.shape[3], full_basis=True) - pmf = shm.sh_to_sf(model_params, sphere, sh_order_max=sh_order, full_basis=True) + else: + sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * model_params.shape[3])) / 2.0 + if sym_order.is_integer(): + sh_order_max = sym_order + full_basis = False + else: + full_order = np.sqrt(model_params.shape[3]) - 1.0 + sh_order_max = full_order + full_basis = True + pmf = shm.sh_to_sf( + model_params, sphere, sh_order_max=sh_order_max, full_basis=full_basis + ) pmf[pmf < 0] = 0 tracking_kwargs["sf"] = pmf - else: - tracking_kwargs["sh"] = model_params - - if directions == "det": - tracker = deterministic_tracking - elif directions == "prob": - tracker = probabilistic_tracking - elif directions == "pft": - tracker = pft_tracking - else: - raise ValueError(f"Unrecognized direction '{directions}'.") if rng_seed is not None: tracking_kwargs["random_seed"] = int(rng_seed) + else: + tracking_kwargs["random_seed"] = np.random.randint(0, 2**31 - 1) - logger.info(f"Tracking with {len(seeds)} seeds...") + if directions == "prob": + jit_backend = jit_backend.lower() + if jit_backend == "auto": + from cuslines import ( + ProbDirectionGetter, + Tracker, + ) + elif jit_backend == "cuda": + from cuslines.cuda_python import ( + ProbDirectionGetter, + Tracker, + ) + elif jit_backend == "metal": + from cuslines.metal import ( + MetalGPUTracker as Tracker, + ) + from cuslines.metal import ( + MetalProbDirectionGetter as ProbDirectionGetter, + ) + elif jit_backend == "webgpu": + from cuslines.webgpu import ( + WebGPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUTracker as Tracker, + ) + elif jit_backend == "numba": + from cuslines.numba import ( + CPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.numba import ( + CPUTracker as Tracker, + ) + else: + raise ValueError( + "gpu_backend must be one of 'auto', 'cuda', " + f"'metal', 'numba', or 'webgpu', not {jit_backend}" + ) - if len(seeds.shape) == 1: - seeds = seeds[None, ...] + dg = ProbDirectionGetter() - logger.info("Note there will be a long initial delay as seeds are initialized") - start_time = time() - tracker = tqdm( - tracker( - seeds, - stopping_criterion, - params_img.affine, - max_angle=max_angle, - sphere=sphere, - basis_type=basis_type, - legacy=legacy, + seeds = gen_seeds( + seed_mask, + seed_threshold, + n_seeds, + thresholds_as_percentages, + random_seeds, + rng_seed, + np.eye(4), # JIT expects seeds in voxel space + ) + + minlen = int(minlen / step_size) + maxlen = int(maxlen / step_size) + + R = params_img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + step_size = step_size / vox_dim + + if n_threads != 0: + old_numba_n_threads = numba.get_num_threads() + numba.set_num_threads(n_threads) + + with Tracker( + dg, + tracking_kwargs["sf"], + pve_wm_data, + gm_threshold, + sphere.vertices, + sphere.edges, + sphere_symm=isinstance(sphere, HemiSphere), + max_angle=radians(max_angle), step_size=step_size, - min_len=minlen, - max_len=maxlen, - return_all=False, - nbr_threads=int(n_threads), - **tracking_kwargs, - ), - total=len(seeds), - desc="Tracking, note that the total is an overestimate...", - ) - logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) + min_pts=minlen, + max_pts=maxlen, + rng_seed=tracking_kwargs["random_seed"], + chunk_size=jit_chunk_size, + ) as jit_tracker: + if trx: + res = jit_tracker.generate_trx(seeds, params_img) + else: + res = jit_tracker.generate_sft(seeds, params_img) - if trx: - return LazyTractogram(lambda: tracker, affine_to_rasmm=params_img.affine) + if n_threads != 0: + numba.set_num_threads(old_numba_n_threads) + return res else: - return StatefulTractogram(tracker, params_img, Space.RASMM) + if directions == "det": + tracker = deterministic_tracking + elif directions == "pft": + tracker = pft_tracking + else: + raise ValueError(f"Unrecognized direction '{directions}'.") + + logger.info("Note there will be a long initial delay as seeds are initialized") + + seeds = gen_seeds( + seed_mask, + seed_threshold, + n_seeds, + thresholds_as_percentages, + random_seeds, + rng_seed, + params_img.affine, + ) + + start_time = time() + tracker = tqdm( + tracker( + seeds, + stopping_criterion, + params_img.affine, + max_angle=max_angle, + sphere=sphere, + basis_type=basis_type, + legacy=legacy, + step_size=step_size, + min_len=minlen, + max_len=maxlen, + return_all=False, + nbr_threads=int(n_threads), + **tracking_kwargs, + ), + total=len(seeds), + desc="Tracking, note that the total is an overestimate...", + ) + logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) + + if trx: + return LazyTractogram(lambda: tracker, affine_to_rasmm=params_img.affine) + else: + return StatefulTractogram(tracker, params_img, Space.RASMM) diff --git a/AFQ/tractography/utils.py b/AFQ/tractography/utils.py index 50a97745..d60c07c6 100644 --- a/AFQ/tractography/utils.py +++ b/AFQ/tractography/utils.py @@ -64,5 +64,9 @@ def gen_seeds( # If user provided an array, we'll use n_seeds as the seeds: seeds = np.asarray(n_seeds) + if len(seeds.shape) == 1: + seeds = seeds[None, ...] + logger.info(f"Generated {len(seeds)} seeds in {time() - start_time:.2f} seconds.") + logger.info(f"Tracking with {len(seeds)} seeds...") return seeds diff --git a/AFQ/viz/plot.py b/AFQ/viz/plot.py index 2577c93e..1fcb621c 100644 --- a/AFQ/viz/plot.py +++ b/AFQ/viz/plot.py @@ -444,7 +444,6 @@ def reco_flip(df): self.bundles = bundles self.color_dict = vut.gen_color_dict([*self.bundles, "median"]) - # TODO: make these parameters self.scalar_markers = ["o", "x"] self.patterns = (None, "/", "o", "x", "-", ".", "+", "//", "\\", "*", "O", "|") diff --git a/docs/source/reference/kwargs.rst b/docs/source/reference/kwargs.rst index 3c17a0a4..bf537b56 100644 --- a/docs/source/reference/kwargs.rst +++ b/docs/source/reference/kwargs.rst @@ -17,9 +17,6 @@ Here are the arguments you can pass to kwargs, to customize the tractometry pipe ========================================================== STRUCTURAL ========================================================== -ray_n_cpus: int - The number of CPUs to use for parallel processing with Ray. If None, uses the number of available CPUs minus one. Tractography and MSMT use Ray. Default: None - numba_n_threads: int The number of threads to use for Numba. If None, uses the number of available CPUs minus one, but with a maximum of 16. ASYM fit uses Numba. Default: None diff --git a/docs/source/reference/methods.rst b/docs/source/reference/methods.rst index 75dd8ebc..a6435d20 100644 --- a/docs/source/reference/methods.rst +++ b/docs/source/reference/methods.rst @@ -46,11 +46,6 @@ pve_gm: pve_wm: White matter partial volume estimate map - -n_cpus: - Configure the number of CPUs to use for parallel processing with Ray - - n_threads: the number of threads to use for Numba diff --git a/examples/howto_examples/optic_tract.py b/examples/howto_examples/optic_tract.py index b68fdf11..ce760e9d 100644 --- a/examples/howto_examples/optic_tract.py +++ b/examples/howto_examples/optic_tract.py @@ -151,7 +151,6 @@ output_dir=op.join(study_dir, "derivatives", "afq_otoc"), pve=pve, tracking_params=tractography_params, - ray_n_cpus=4, segmentation_params=segmentation_params, bundle_info=otoc_bd) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index ce2d1ee9..8a456ea7 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -116,12 +116,7 @@ # the name of the t1 preprocessing pipeline we want to use (in this case, # its the same, qsiprep [3]), the participant labels we want to process # (in this case, just a single subject), the PVE images we defined above, and -# the tracking parameters we defined above. We set ray_n_cpus=1 and -# low_memory=True to avoid memory issues running this example on -# Github actions. If these settings are omitted, -# which can be done in most cases, the default behavior will -# parallelize processing, resulting in faster runtime, -# but also in higher memory usage. +# the tracking parameters we defined above. myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'HBN'), @@ -130,9 +125,7 @@ participant_labels=['NDARAA948VFH'], pve=pve, brain_mask_definition=brain_mask_definition, - tracking_params=tracking_params, - ray_n_cpus=1, - low_memory=True) + tracking_params=tracking_params) ########################################################################## # Calculating DKI FA (Diffusion Kurtosis Imaging Fractional Anisotropy) @@ -294,9 +287,7 @@ myafq = GroupAFQ.from_qsiprep( qsi_dir=op.join(afd.afq_home, 'HBN'), participant_labels=['NDARAA948VFH'], - tracking_params=tracking_params, - ray_n_cpus=1, - low_memory=True) + tracking_params=tracking_params) ############################################################################# # References diff --git a/examples/tutorial_examples/plot_002_participant_afq_api.py b/examples/tutorial_examples/plot_002_participant_afq_api.py index 3155a9dc..d684d051 100644 --- a/examples/tutorial_examples/plot_002_participant_afq_api.py +++ b/examples/tutorial_examples/plot_002_participant_afq_api.py @@ -131,12 +131,7 @@ # # To initialize the object, we will pass in the diffusion data files and specify # the output directory where we want to store the results. We will also -# pass in the tracking parameters we defined above. We set ray_n_cpus=1 -# and low_memory=True to avoid memory issues running this example on -# Github actions. If these settings are omitted, -# which can be done in most cases, the default behavior will -# parallelize processing, resulting in faster runtime, -# but also in higher memory usage. +# pass in the tracking parameters we defined above. myafq = ParticipantAFQ( dwi_data_file=dwi_data_file, @@ -147,7 +142,6 @@ tracking_params=tracking_params, pve=pve, brain_mask_definition=brain_mask_definition, - ray_n_cpus=1, ) ########################################################################## diff --git a/examples/tutorial_examples/plot_004_export.py b/examples/tutorial_examples/plot_004_export.py index af92be78..b3137a1d 100644 --- a/examples/tutorial_examples/plot_004_export.py +++ b/examples/tutorial_examples/plot_004_export.py @@ -67,7 +67,6 @@ t1_file=t1_file, output_dir=output_dir, pve=pve, - ray_n_cpus=1, tracking_params={ "n_seeds": 10000, "random_seeds": True, diff --git a/setup.cfg b/setup.cfg index 9fef175b..ace67f02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ install_requires = immlib trx-python bibtexparser + cuslines==2.2 # efficiency numba osqp @@ -94,7 +95,7 @@ aws = nn = onnxruntime gpu = - cuslines==2.0.0 + cuslines[cu13]==2.2 onnxruntime-gpu all = %(dev)s From beaf5ceea897400e3c9de8d81d62cd5a0ba8eecd Mon Sep 17 00:00:00 2001 From: John Kruper <36000@users.noreply.github.com> Date: Thu, 7 May 2026 15:32:55 -0700 Subject: [PATCH 20/26] typo Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- AFQ/tractography/tractography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 7aa5766e..e1fda001 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -284,7 +284,7 @@ def track( ) else: raise ValueError( - "gpu_backend must be one of 'auto', 'cuda', " + "jit_backend must be one of 'auto', 'cuda', " f"'metal', 'numba', or 'webgpu', not {jit_backend}" ) From 4b360d09cdc5954fa118b96540d179796d625bf7 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 7 May 2026 15:35:50 -0700 Subject: [PATCH 21/26] copilot BFs --- AFQ/tractography/tractography.py | 2 ++ examples/tutorial_examples/plot_001_group_afq_api.py | 2 +- examples/tutorial_examples/plot_002_participant_afq_api.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index e1fda001..e8f87931 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -156,6 +156,8 @@ def track( if isinstance(pve, str): pve_img = nib.load(pve) + if isinstance(pve, nib.Nifti1Image): + pve_img = pve pve_data = pve_img.get_fdata() model_params = params_img.get_fdata() diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 8a456ea7..79c354e1 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -52,7 +52,7 @@ # Set tractography parameters (optional) # --------------------------------------- # We make create a `tracking_params` variable, which we will pass to the -# GroupAFQ object which specifies that we want 200,000 seeds randomly +# GroupAFQ object which specifies that we want 500,000 seeds randomly # distributed in the white matter. We only do this to make this example faster # and consume less space; normally, we use more seeds. diff --git a/examples/tutorial_examples/plot_002_participant_afq_api.py b/examples/tutorial_examples/plot_002_participant_afq_api.py index d684d051..4fb195a9 100644 --- a/examples/tutorial_examples/plot_002_participant_afq_api.py +++ b/examples/tutorial_examples/plot_002_participant_afq_api.py @@ -70,11 +70,11 @@ # Set tractography parameters (optional) # --------------------------------------- # We make create a `tracking_params` variable, which we will pass to the -# ParticipantAFQ object which specifies that we want 200,000 seeds randomly +# ParticipantAFQ object which specifies that we want 500,000 seeds randomly # distributed in the white matter. We only do this to make this example faster # and consume less space; normally, we use more seeds. -tracking_params = dict(n_seeds=200000, +tracking_params = dict(n_seeds=500000, random_seeds=True, rng_seed=2025, trx=True) From 3214d9d627ad9533719b81da1b13a61f1bed46ff Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 7 May 2026 15:58:04 -0700 Subject: [PATCH 22/26] standardize n seeds to accept RASMM, update test_recognition to api with n_cpus removed --- AFQ/recognition/tests/test_recognition.py | 23 +++++--------- AFQ/tractography/tractography.py | 37 ++++++++++------------- setup.cfg | 4 +-- 3 files changed, 26 insertions(+), 38 deletions(-) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 3e8d74be..c81f7eb9 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -62,7 +62,7 @@ def test_segment(): - fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 2) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template) # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 2) @@ -100,9 +100,7 @@ def test_segment_mixed_roi(): "resample_subject_to cannot be False." ), ): - fiber_groups, _ = recognize( - tg, hardi_img, mapping, bundle_info, reg_template, 2 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundle_info, reg_template) bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) fiber_groups, _ = recognize( @@ -111,7 +109,6 @@ def test_segment_mixed_roi(): mapping, bundle_info, reg_template, - 2, dist_to_atlas=10, ) @@ -135,9 +132,7 @@ def test_segment_no_prob(): }, } - fiber_groups, _ = recognize( - tg, hardi_img, mapping, bundles_no_prob, reg_template, 1 - ) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles_no_prob, reg_template) # This condition should still hold npt.assert_equal(len(fiber_groups), 2) @@ -147,7 +142,7 @@ def test_segment_no_prob(): def test_segment_return_idx(): # Test with the return_idx kwarg set to True: fiber_groups, _ = recognize( - tg, hardi_img, mapping, bundles, reg_template, 1, return_idx=True + tg, hardi_img, mapping, bundles, reg_template, return_idx=True ) npt.assert_equal(len(fiber_groups), 2) @@ -163,7 +158,7 @@ def test_segment_return_idx(): def test_segment_clip_edges_api(): # Test with the clip_edges kwarg set to True: fiber_groups, _ = recognize( - tg, hardi_img, mapping, bundles, reg_template, 1, clip_edges=True + tg, hardi_img, mapping, bundles, reg_template, clip_edges=True ) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups["Right Corticospinal"]) > 0) @@ -184,7 +179,6 @@ def test_segment_reco(): mapping, bundles_reco, reg_template, - 1, rng=np.random.RandomState(seed=8), ) @@ -213,19 +207,19 @@ def test_exclusion_ROI(): hardi_img, Space.VOX, ) - fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template, 1) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 2) slf_bundle["Left Superior Longitudinal"]["exclude"] = [templates["SLFt_roi2_L"]] - fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template, 1) + fiber_groups, _ = recognize(slf_tg, hardi_img, mapping, slf_bundle, reg_template) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) def test_segment_sampled_streamlines(): - fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template, 1) + fiber_groups, _ = recognize(tg, hardi_img, mapping, bundles, reg_template) # Already using a subsampled tck # the Right Corticospinal has two streamlines and @@ -242,7 +236,6 @@ def test_segment_sampled_streamlines(): mapping, bundles, reg_template, - 1, nb_streamlines=nb_streamlines, rng=2026, ) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index e8f87931..62f6bbbf 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -87,8 +87,9 @@ def track( n_seeds : int or 2D array, optional. The seeding density: if this is an int, it is is how many seeds in each voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D - array, these are the coordinates of the seeds. Unless random_seeds is - set to True, in which case this is the total number of random seeds + array, these are the coordinates of the seeds in RASMM. + Unless random_seeds is set to True, + in which case this is the total number of random seeds to generate within the mask. Default: 1e7 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. @@ -251,6 +252,16 @@ def track( else: tracking_kwargs["random_seed"] = np.random.randint(0, 2**31 - 1) + seeds = gen_seeds( + seed_mask, + seed_threshold, + n_seeds, + thresholds_as_percentages, + random_seeds, + rng_seed, + params_img.affine, + ) + if directions == "prob": jit_backend = jit_backend.lower() if jit_backend == "auto": @@ -292,15 +303,9 @@ def track( dg = ProbDirectionGetter() - seeds = gen_seeds( - seed_mask, - seed_threshold, - n_seeds, - thresholds_as_percentages, - random_seeds, - rng_seed, - np.eye(4), # JIT expects seeds in voxel space - ) + inv_affine = np.linalg.inv(params_img.affine) + seeds = np.dot(seeds, inv_affine[:3, :3].T) + seeds += inv_affine[:3, 3] minlen = int(minlen / step_size) maxlen = int(maxlen / step_size) @@ -346,16 +351,6 @@ def track( logger.info("Note there will be a long initial delay as seeds are initialized") - seeds = gen_seeds( - seed_mask, - seed_threshold, - n_seeds, - thresholds_as_percentages, - random_seeds, - rng_seed, - params_img.affine, - ) - start_time = time() tracker = tqdm( tracker( diff --git a/setup.cfg b/setup.cfg index ace67f02..5c4c52c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ install_requires = immlib trx-python bibtexparser - cuslines==2.2 + cuslines==2.2.1 # efficiency numba osqp @@ -95,7 +95,7 @@ aws = nn = onnxruntime gpu = - cuslines[cu13]==2.2 + cuslines[cu13]==2.2.1 onnxruntime-gpu all = %(dev)s From 3272adfaf401f2fac7c33383b002c70e71274a1b Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 7 May 2026 18:30:52 -0700 Subject: [PATCH 23/26] update examples and docs --- docs/source/explanations/whats_new_3.rst | 2 +- docs/source/reference/kwargs.rst | 17 +++---------- docs/source/reference/methods.rst | 7 +++++- examples/howto_examples/pyAFQ_with_GPU.py | 25 +++++++++++++------ .../plot_001_group_afq_api.py | 10 +++----- .../plot_002_participant_afq_api.py | 4 +-- 6 files changed, 34 insertions(+), 31 deletions(-) diff --git a/docs/source/explanations/whats_new_3.rst b/docs/source/explanations/whats_new_3.rst index d39a9f55..3e42ba54 100644 --- a/docs/source/explanations/whats_new_3.rst +++ b/docs/source/explanations/whats_new_3.rst @@ -20,7 +20,7 @@ estimates (PVEs) from other pipelines such as FSLfast (https://web.mit.edu/fsl_v5.0.10/fsl/doc/wiki/FAST.html) or Freesurfer (https://surfer.nmr.mgh.harvard.edu/). If these are not provided, pyAFQ will generate them using SynthSeg -and the T1 :cite:`Tzourio-billot_synthseg_2023,billot_robust_2023`. +and the T1 :cite:`billot_synthseg_2023,billot_robust_2023`. These PVEs are used for tractography, which has been radically changed in pyAFQ 3.0, relative to previous versions: First, diff --git a/docs/source/reference/kwargs.rst b/docs/source/reference/kwargs.rst index bf537b56..b5489a8a 100644 --- a/docs/source/reference/kwargs.rst +++ b/docs/source/reference/kwargs.rst @@ -18,7 +18,7 @@ Here are the arguments you can pass to kwargs, to customize the tractometry pipe STRUCTURAL ========================================================== numba_n_threads: int - The number of threads to use for Numba. If None, uses the number of available CPUs minus one, but with a maximum of 16. ASYM fit uses Numba. Default: None + The number of threads to use for Numba and DIPY tracking. If None, uses the number of available CPUs minus one. Default: None low_memory: bool Whether to use low-memory versions of algorithms where available. Default: False @@ -45,6 +45,9 @@ max_bval: float b0_threshold: int The value of b under which it is considered to be b0. Default: 50. +min_b0_for_r1_approximation: float + The minimum value of b0 to consider when doing the division. This is to avoid dividing by small numbers. Default: 1e-2 + robust_tensor_fitting: bool Whether to use robust_tensor_fitting when doing dti. Only applies to dti. Default: False @@ -84,9 +87,6 @@ opdt_sh_order_max: int csa_sh_order_max: int Spherical harmonics order for CSA model. Must be even. Default: 8 -sphere: Sphere class instance - The sphere providing sample directions for the initial search of the maximal value of kurtosis. Default: 'repulsion100' - gtol: float This input is to refine kurtosis maxima under the precision of the directions sampled on the sphere class instance. The gradient of the convergence procedure must be less than gtol before successful termination. If gtol is None, fiber direction is directly taken from the initial sampled directions of the given sphere object. Default: 1e-2 @@ -148,15 +148,6 @@ tracking_params: dict import_tract: dict or str or None BIDS filters for inputing a user made tractography file, or a path to the tractography file. If None, DIPY is used to generate the tractography. Default: None -tractography_ngpus: int - Number of GPUs to use in tractography. If non-0, this algorithm is used for tractography, https://github.com/dipy/GPUStreamlines PTT, Prob can be used with any SHM model. Bootstrapped can be done with CSA/OPDT. Default: 0 - -gpu_backend: str - GPU backend to use for tractography. One of {"auto", "cuda", "metal", "webgpu"}. Default: "auto" - -chunk_size: int - Chunk size for GPU tracking. Default: 25000 - ========================================================== VIZ diff --git a/docs/source/reference/methods.rst b/docs/source/reference/methods.rst index a6435d20..d8bdc092 100644 --- a/docs/source/reference/methods.rst +++ b/docs/source/reference/methods.rst @@ -46,8 +46,9 @@ pve_gm: pve_wm: White matter partial volume estimate map + n_threads: - the number of threads to use for Numba + Configure the number of threads to use for Numba low_mem: @@ -98,6 +99,10 @@ b0: full path to a nifti file containing the mean b0 +t1w_over_b0: + full path to a nifti file containing the T1w over mean b0 which is a proxy for R1 [1]_ + + masked_b0: full path to a nifti file containing the mean b0 after applying the brain mask diff --git a/examples/howto_examples/pyAFQ_with_GPU.py b/examples/howto_examples/pyAFQ_with_GPU.py index 0884c34b..af234c23 100644 --- a/examples/howto_examples/pyAFQ_with_GPU.py +++ b/examples/howto_examples/pyAFQ_with_GPU.py @@ -4,12 +4,12 @@ ============================================ Running pyAFQ using the GPU for tractography is as simple as (1) Installing GPUStreamlines using `pip install` and -(2) passing in the ``tractography_ngpus`` parameter when you create your +(2) passing in the ``jit_backend`` parameter when you create your GroupAFQ object. To install GPUStreamlines, do: `pip install git+https://github.com/dipy/GPUStreamlines.git` That's step 1 complete! The rest of this example is the same as the GroupAFQ -example except with the ``tractography_ngpus`` parameter set. +example except with the ``jit_backend`` parameter set. """ from AFQ.api.group import GroupAFQ @@ -26,24 +26,33 @@ afd.organize_stanford_data() -tracking_params = dict(n_seeds=1000000, + +########################################################################## +# Set tractography parameters +# --------------------------- +# We make create a `tracking_params` variable to define the parameters for tractography. +# The only parameter we need to set to use the GPU is `jit_backend`, +# which we set to "cuda". Other backends include: "metal", "webgpu", or "numba". +# Numba is the default. +# Note that the GPU backend will only run for probabilistic tracking, +# which is the default. + +tracking_params = dict(n_seeds=1e7, random_seeds=True, rng_seed=2025, + jit_backend="cuda", trx=True) ###################### # Running with the GPU # -------------------- -# We will use the GPU for tractography. This is done by -# passing in `tractography_ngpus` +# Then, run pyAFQ normally. # That's it! myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), dwi_preproc_pipeline='vistasoft', t1_preproc_pipeline='freesurfer', - tracking_params=tracking_params, - tractography_ngpus=1) + tracking_params=tracking_params) -# From here, pyAFQ should run normally bundle_html = myafq.export("all_bundles_figure") plotly.io.show(bundle_html["01"][0]) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 79c354e1..5ab1e101 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -52,11 +52,11 @@ # Set tractography parameters (optional) # --------------------------------------- # We make create a `tracking_params` variable, which we will pass to the -# GroupAFQ object which specifies that we want 500,000 seeds randomly +# GroupAFQ object which specifies that we want 100,000 seeds randomly # distributed in the white matter. We only do this to make this example faster # and consume less space; normally, we use more seeds. -tracking_params = dict(n_seeds=500000, +tracking_params = dict(n_seeds=1e5, random_seeds=True, rng_seed=2025, trx=True) @@ -262,11 +262,9 @@ "NDARAA948VFH"]["HBNsiteRU"], index_col=[0]) for ind in bundle_counts.index: if ind == "Total Recognized": - threshold = 3000 - elif "Vertical Occipital" in ind: - threshold = 5 + threshold = 3e4 else: - threshold = 15 + threshold = 30 if bundle_counts["n_streamlines"][ind] < threshold: raise ValueError(( "Small number of streamlines found " diff --git a/examples/tutorial_examples/plot_002_participant_afq_api.py b/examples/tutorial_examples/plot_002_participant_afq_api.py index 4fb195a9..b810564d 100644 --- a/examples/tutorial_examples/plot_002_participant_afq_api.py +++ b/examples/tutorial_examples/plot_002_participant_afq_api.py @@ -70,11 +70,11 @@ # Set tractography parameters (optional) # --------------------------------------- # We make create a `tracking_params` variable, which we will pass to the -# ParticipantAFQ object which specifies that we want 500,000 seeds randomly +# ParticipantAFQ object which specifies that we want 100,000 seeds randomly # distributed in the white matter. We only do this to make this example faster # and consume less space; normally, we use more seeds. -tracking_params = dict(n_seeds=500000, +tracking_params = dict(n_seeds=1e5, random_seeds=True, rng_seed=2025, trx=True) From 4ce9b2da56af153a648c8b45cf67a4a6e16c3b60 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 7 May 2026 21:54:18 -0700 Subject: [PATCH 24/26] asym filtering norm fix --- AFQ/models/asym_filtering.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/AFQ/models/asym_filtering.py b/AFQ/models/asym_filtering.py index 5e1604d7..bfe07eb7 100644 --- a/AFQ/models/asym_filtering.py +++ b/AFQ/models/asym_filtering.py @@ -232,7 +232,6 @@ def _unified_filter_build_nx( # the direction controls the align weight if i == j == k == 0 or disable_align: # hack for main direction to have maximal weight - # w_align = np.ones((1, len(directions)), dtype=np.float32) w_align = np.zeros((1, len(directions)), dtype=np.float32) else: dxy /= len_xy @@ -252,7 +251,7 @@ def _unified_filter_build_nx( for ui in range(len(directions)): w_sum = np.sum(nx_weights[..., ui]) - nx_weights /= w_sum + nx_weights[..., ui] /= w_sum return nx_weights From c0e127d0f35cfbe188de96de8e56cb132088844a Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 8 May 2026 11:03:05 -0700 Subject: [PATCH 25/26] organize aodf code, update group example thresholds --- AFQ/models/asym_filtering.py | 4 +- AFQ/tests/test_csd.py | 52 +++++++++++++++++++ .../plot_001_group_afq_api.py | 2 +- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/AFQ/models/asym_filtering.py b/AFQ/models/asym_filtering.py index bfe07eb7..13dbef8f 100644 --- a/AFQ/models/asym_filtering.py +++ b/AFQ/models/asym_filtering.py @@ -47,7 +47,7 @@ def unified_filtering( sh_data, sphere, sh_basis="descoteaux07", - is_legacy=False, + is_legacy=True, sigma_spatial=1.0, sigma_align=0.8, sigma_angle=None, @@ -667,7 +667,7 @@ def compute_nufid_asym(sh_coeffs, sphere, csf, mask): sh_order_max=sh_order, basis_type="descoteaux07", full_basis=full_basis, - legacy=False, + legacy=True, ) # Guess at threshold from 2.0 * mean of ODF maxes in CSF diff --git a/AFQ/tests/test_csd.py b/AFQ/tests/test_csd.py index a07f65b6..82baa3d4 100644 --- a/AFQ/tests/test_csd.py +++ b/AFQ/tests/test_csd.py @@ -5,9 +5,20 @@ import nibabel.tmpdirs as nbtmp import numpy as np import numpy.testing as npt +import pytest from dipy.reconst.shm import calculate_max_order from AFQ.models import csd +from AFQ.models.asym_filtering import unified_filtering as pyafq_unified_filtering + +try: + from scilpy.denoise.asym_filtering import ( + unified_filtering as scilpy_unified_filtering, + ) + + has_scilpy = True +except ImportError: + has_scilpy = False def test_fit_csd(): @@ -30,3 +41,44 @@ def test_fit_csd(): npt.assert_(op.exists(fname)) sh_coeffs_img = nib.load(fname) npt.assert_equal(sh_order_max, calculate_max_order(sh_coeffs_img.shape[-1])) + + +# Note we do not want to run this by default, Scilpy +# Has many specific requirements for dependency versions +# That we do not want to interfere with pyAFQ testing generally +@pytest.mark.skipif(not has_scilpy, reason="scilpy is not installed") +def test_afod(): + fdata, fbval, fbvec = dpd.get_fnames("small_64D") + sphere = dpd.get_sphere("repulsion100") + with nbtmp.InTemporaryDirectory() as tmpdir: + # Convert from npy to txt: + bvals = np.loadtxt(fbval) + bvecs = np.loadtxt(fbvec) + np.savetxt(op.join(tmpdir, "bvals.txt"), bvals) + np.savetxt(op.join(tmpdir, "bvecs.txt"), bvecs) + fname = csd.fit_csd( + str(fdata), + op.join(tmpdir, "bvals.txt"), + op.join(tmpdir, "bvecs.txt"), + out_dir=tmpdir, + sh_order_max=6, + ) + + npt.assert_(op.exists(fname)) + sh_coeffs_img = nib.load(fname) + + aodf_pyafq = pyafq_unified_filtering( + sh_coeffs_img.get_fdata(), + sphere, + ) + + aodf_scilpy = scilpy_unified_filtering( + sh_coeffs_img.get_fdata(), + 6, + "descoteaux07", + True, + False, + "repulsion100", + ) + + npt.assert_allclose(aodf_pyafq, aodf_scilpy, atol=1e-6) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 5ab1e101..14949dcd 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -264,7 +264,7 @@ if ind == "Total Recognized": threshold = 3e4 else: - threshold = 30 + threshold = 20 if bundle_counts["n_streamlines"][ind] < threshold: raise ValueError(( "Small number of streamlines found " From 01a6ffabb90c0ba85061047b669ac4c065eadb1e Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 8 May 2026 12:12:08 -0700 Subject: [PATCH 26/26] add ptt back in --- AFQ/tasks/tractography.py | 5 ++++- AFQ/tractography/tractography.py | 30 ++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index e5ea36b4..8f6a565a 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -97,7 +97,10 @@ def streamlines( **this_tracking_params, ) - if this_tracking_params["directions"] == "prob": + if ( + this_tracking_params["directions"] == "prob" + or this_tracking_params["directions"] == "ptt" + ): # We do not count these as we go yet, # this needs to be implemented in GPUStreamlines n_streamlines = 0 diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 62f6bbbf..9057ad45 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -126,10 +126,12 @@ def track( (i.e., as a LazyTractogram class instance). Default: True jit_backend : str, optional - If directions is "prob", the JIT backend to use. One of {"auto", - "cuda", "metal", "webgpu", or "numba"}. Default: "numba" + If directions is "prob" or "ptt", the JIT backend to use. + One of {"auto", "cuda", "metal", "webgpu", or "numba"}. + Default: "numba" jit_chunk_size : int, optional - If directions is "prob", the chunk size to use for JIT tracking. + If directions is "prob" or "ptt", the chunk size to use + for JIT tracking. Default: 25000 Returns @@ -262,17 +264,21 @@ def track( params_img.affine, ) - if directions == "prob": + if directions == "prob" or directions == "ptt": jit_backend = jit_backend.lower() if jit_backend == "auto": from cuslines import ( ProbDirectionGetter, + PttDirectionGetter, Tracker, ) elif jit_backend == "cuda": + from cuslines.cuda_python import ( + GPUTracker as Tracker, + ) from cuslines.cuda_python import ( ProbDirectionGetter, - Tracker, + PttDirectionGetter, ) elif jit_backend == "metal": from cuslines.metal import ( @@ -281,10 +287,16 @@ def track( from cuslines.metal import ( MetalProbDirectionGetter as ProbDirectionGetter, ) + from cuslines.metal import ( + MetalPttDirectionGetter as PttDirectionGetter, + ) elif jit_backend == "webgpu": from cuslines.webgpu import ( WebGPUProbDirectionGetter as ProbDirectionGetter, ) + from cuslines.webgpu import ( + WebGPUPttDirectionGetter as PttDirectionGetter, + ) from cuslines.webgpu import ( WebGPUTracker as Tracker, ) @@ -292,6 +304,9 @@ def track( from cuslines.numba import ( CPUProbDirectionGetter as ProbDirectionGetter, ) + from cuslines.numba import ( + CPUPttDirectionGetter as PttDirectionGetter, + ) from cuslines.numba import ( CPUTracker as Tracker, ) @@ -301,7 +316,10 @@ def track( f"'metal', 'numba', or 'webgpu', not {jit_backend}" ) - dg = ProbDirectionGetter() + if directions == "ptt": + dg = PttDirectionGetter() + else: + dg = ProbDirectionGetter() inv_affine = np.linalg.inv(params_img.affine) seeds = np.dot(seeds, inv_affine[:3, :3].T)