Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions cuslines/generic_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tqdm import tqdm
from trx.trx_file_memmap import TrxFile
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.tracking.streamlinespeed import compress_streamlines
from nibabel.streamlines.array_sequence import ArraySequence
from nibabel.streamlines.tractogram import Tractogram

Expand All @@ -16,11 +17,55 @@ def __enter__(self):
def __exit__(self, exc_type, exc, tb):
return False

def set_compression_parameters(self, pos_dtype=np.float32, linearize=False, tol_error=0.1, max_segment_length=10):
"""
Set compression parameters to compress generated streamlines.
Only works with TRX.

Parameters
----------
pos_dtype : dtype, optional
Data type to use for the positions of the streamlines.
Default: np.float32

linearize : bool, optional
Whether to linearize the streamlines using [1].
Default: False

tol_error : float, optional
If linearize is true, tolerance error in mm.
Default: 0.1

max_segment_length : float, optional
If linearize is true, maximum length in mm of any given segment produced by the compression.
Default: 10

References
----------
[1] Caroline Presseau, Pierre-Marc Jodoin, Jean-Christophe Houde, and Maxime Descoteaux.
A new compression format for fiber tracking datasets.
NeuroImage, 109:73-83, 2015. URL: 10.1016/j.neuroimage.2014.12.058
"""
self.pos_dtype = pos_dtype
self.linearize = linearize
self.tol_error = tol_error
self.max_segment_length = max_segment_length


def _ngpus(self):
if hasattr(self, "ngpus"):
return self.ngpus
else:
return 1
return getattr(self, "ngpus", 1)

def _pos_dtype(self):
return getattr(self, "pos_dtype", np.float16)

def _linearize(self):
return getattr(self, "linearize", False)

def _tol_error(self):
return getattr(self, "tol_error", 0.1)

def _max_segment_length(self):
return getattr(self, "max_segment_length", np.inf)

def _divide_chunks(self, seeds):
global_chunk_sz = self.chunk_size * self._ngpus()
Expand Down Expand Up @@ -58,7 +103,7 @@ def generate_trx(self, seeds, ref_img):
# trx files use memory mapping
trx_reference = TrxFile(reference=ref_img)
trx_reference.streamlines._data = trx_reference.streamlines._data.astype(
np.float32
self._pos_dtype()
)
trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(
np.uint64
Expand Down Expand Up @@ -87,6 +132,14 @@ def generate_trx(self, seeds, ref_img):
tractogram.to_world()
sls = tractogram.streamlines

if self._linearize():
sls = ArraySequence(compress_streamlines(
sls,
tol_error=self._tol_error(),
max_segment_length=self._max_segment_length(),
))
sls._data = sls._data.astype(self._pos_dtype())

new_offsets_idx = offsets_idx + len(sls._offsets)
new_sls_data_idx = sls_data_idx + len(sls._data)

Expand Down
2 changes: 1 addition & 1 deletion run_gpu_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_img(ep2_seq):
parser.add_argument(
"--ngpus", type=int, default=1, help="number of GPUs to use if using gpu"
)
parser.add_argument("--write-method", type=str, default="trk", help="Can be trx or trk")
parser.add_argument("--write-method", type=str, default="trx", help="Can be trx or trk")
parser.add_argument(
"--max-angle", type=float, default=60, help="max angle (in degrees)"
)
Expand Down
Loading