diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py index fd2be0da161..626b9b14f9a 100644 --- a/python/python/lance/torch/data.py +++ b/python/python/lance/torch/data.py @@ -193,7 +193,9 @@ def __init__( shard_granularity: Optional[Literal["fragment", "batch"]] = None, batch_readahead: int = 16, to_tensor_fn: Optional[ - Callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]] + Callable[ + [pa.RecordBatch, ...], Union[dict[str, torch.Tensor], torch.Tensor] + ] ] = _to_tensor, sampler: Optional[Sampler] = None, auto_detect_rank: bool = True, @@ -236,6 +238,9 @@ def __init__( A function that samples the dataset. to_tensor_fn : callable, optional A function that converts a pyarrow RecordBatch to torch.Tensor. + Should accept a batch (RecordBatch or Dict[str, pa.Array]) as the first + argument, plus optional keyword arguments ``hf_converter`` and + ``use_blob_api``. auto_detect_rank: bool = True, optional If set true, the rank and world_size will be detected automatically. """