diff --git a/cuslines/generic_tracker.py b/cuslines/generic_tracker.py index ec308aa..b2dbcb4 100644 --- a/cuslines/generic_tracker.py +++ b/cuslines/generic_tracker.py @@ -3,6 +3,7 @@ from tqdm import tqdm from trx.trx_file_memmap import TrxFile from dipy.io.stateful_tractogram import Space, StatefulTractogram +from dipy.tracking.streamlinespeed import compress_streamlines from nibabel.streamlines.array_sequence import ArraySequence from nibabel.streamlines.tractogram import Tractogram @@ -16,11 +17,55 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): return False + def set_compression_parameters(self, pos_dtype=np.float32, linearize=False, tol_error=0.1, max_segment_length=10): + """ + Set compression parameters to compress generated streamlines. + Only works with TRX. + + Parameters + ---------- + pos_dtype : dtype, optional + Data type to use for the positions of the streamlines. + Default: np.float32 + + linearize : bool, optional + Whether to linearize the streamlines using [1]. + Default: False + + tol_error : float, optional + If linearize is true, tolerance error in mm. + Default: 0.1 + + max_segment_length : float, optional + If linearize is true, maximum length in mm of any given segment produced by the compression. + Default: 10 + + References + ---------- + [1] Caroline Presseau, Pierre-Marc Jodoin, Jean-Christophe Houde, and Maxime Descoteaux. + A new compression format for fiber tracking datasets. + NeuroImage, 109:73-83, 2015. URL: 10.1016/j.neuroimage.2014.12.058 + """ + self.pos_dtype = pos_dtype + self.linearize = linearize + self.tol_error = tol_error + self.max_segment_length = max_segment_length + + def _ngpus(self): - if hasattr(self, "ngpus"): - return self.ngpus - else: - return 1 + return getattr(self, "ngpus", 1) + + def _pos_dtype(self): + return getattr(self, "pos_dtype", np.float16) + + def _linearize(self): + return getattr(self, "linearize", False) + + def _tol_error(self): + return getattr(self, "tol_error", 0.1) + + def _max_segment_length(self): + return getattr(self, "max_segment_length", np.inf) def _divide_chunks(self, seeds): global_chunk_sz = self.chunk_size * self._ngpus() @@ -58,7 +103,7 @@ def generate_trx(self, seeds, ref_img): # trx files use memory mapping trx_reference = TrxFile(reference=ref_img) trx_reference.streamlines._data = trx_reference.streamlines._data.astype( - np.float32 + self._pos_dtype() ) trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype( np.uint64 @@ -87,6 +132,14 @@ def generate_trx(self, seeds, ref_img): tractogram.to_world() sls = tractogram.streamlines + if self._linearize(): + sls = ArraySequence(compress_streamlines( + sls, + tol_error=self._tol_error(), + max_segment_length=self._max_segment_length(), + )) + sls._data = sls._data.astype(self._pos_dtype()) + new_offsets_idx = offsets_idx + len(sls._offsets) new_sls_data_idx = sls_data_idx + len(sls._data) diff --git a/run_gpu_streamlines.py b/run_gpu_streamlines.py index 7776763..0484a24 100644 --- a/run_gpu_streamlines.py +++ b/run_gpu_streamlines.py @@ -127,7 +127,7 @@ def get_img(ep2_seq): parser.add_argument( "--ngpus", type=int, default=1, help="number of GPUs to use if using gpu" ) -parser.add_argument("--write-method", type=str, default="trk", help="Can be trx or trk") +parser.add_argument("--write-method", type=str, default="trx", help="Can be trx or trk") parser.add_argument( "--max-angle", type=float, default=60, help="max angle (in degrees)" )