@@ -2,26 +2,26 @@
# Author: Kristupas Pranckietis, Vilnius University 05/2024
# Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
# Author: Vincenzo Eduardo Padulano, CERN 10/2024
# Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
# Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026

################################################################################
# Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. #
# Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
################################################################################

from __future__ import annotations

from typing import Any, Callable, Tuple, TYPE_CHECKING
import atexit

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:15:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
import numpy as np
import tensorflow as tf
import torch
import ROOT

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:21:5: I001 Import block is un-sorted or un-formatted


class BaseGenerator:
@@ -82,10 +82,10 @@

def __init__(
self,
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
rdataframes: ROOT.RDF.RNode | list[ROOT.RDF.RNode] = list(),
batch_size: int = 0,
chunk_size: int = 0,
block_size: int = 0,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
@@ -96,6 +96,8 @@
shuffle: bool = True,
drop_remainder: bool = True,
set_seed: int = 0,
load_eager: bool = False,
sampling_type: str = "random",
):
"""Wrapper around the Cpp RBatchGenerator

@@ -105,6 +107,10 @@
chunk_size (int):
The size of the chunks loaded from the ROOT file. Higher chunk size
results in better randomization, but also higher memory usage.
block_size (int):
The size of the blocks of consecutive entries from the dataframe.
A chunk is built up from multiple blocks. Lower block size results in
better randomization, but also higher memory usage.
columns (list[str], optional):
Columns to be returned. If not given, all columns are used.
max_vec_sizes (dict[str, int], optional):
@@ -134,13 +140,20 @@
For reproducibility: Set the seed for the random number generator used
to split the dataset into training and validation and to shuffle the chunks.
Defaults to 0, which means that the seed is set to the random device.
load_eager (bool):
Load the full dataframe(s) into memory (True) or
load chunks from the dataframe into memory (False).
Defaults to False.
sampling_type (str):
Describes the mode of sampling from the dataframe(s). Options: 'random'.
Defaults to 'random' and requires load_eager = True.
"""

import ROOT

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:152:16: F401 `ROOT` imported but unused
from ROOT import RDF

try:
import numpy as np

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:156:29: F401 `numpy` imported but unused; consider using `importlib.util.find_spec` to test for availability

except ImportError:
raise ImportError(
@@ -148,7 +161,7 @@
using RBatchGenerator"
)

if chunk_size < batch_size:
if load_eager == False and chunk_size < batch_size:

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:164:12: E712 Avoid equality comparisons to `False`; use `not load_eager:` for false checks
raise ValueError(
f"chunk_size cannot be smaller than batch_size: chunk_size: \
{chunk_size}, batch_size: {batch_size}"
@@ -160,7 +173,9 @@
given value is {validation_split}"
)

self.noded_rdf = RDF.AsRNode(rdataframe)
if not isinstance(rdataframes, list):
rdataframes = [rdataframes]
self.noded_rdfs = [RDF.AsRNode(rdf) for rdf in rdataframes]

if isinstance(target, str):
target = [target]
@@ -169,7 +184,7 @@
self.weights_column = weights

template, max_vec_sizes_list = self.get_template(
rdataframe, columns, max_vec_sizes
rdataframes[0], columns, max_vec_sizes
)

self.num_columns = len(self.all_columns)
@@ -222,7 +237,7 @@
EnableThreadSafety()

self.generator = TMVA.Experimental.Internal.RBatchGenerator(template)(
self.noded_rdf,
self.noded_rdfs,
chunk_size,
block_size,
batch_size,
@@ -234,6 +249,8 @@
shuffle,
drop_remainder,
set_seed,
load_eager,
sampling_type,
)

atexit.register(self.DeActivate)
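For orientation, here is a minimal usage sketch of the widened interface (a list of dataframes plus the new load_eager and sampling_type flags). It is not code from this patch: the file, tree, and column names are hypothetical, the helper is assumed to stay exposed as ROOT.TMVA.Experimental.CreateNumPyGenerators as in earlier ROOT releases, and target/validation_split are keywords of the wrapped BaseGenerator that are collapsed out of this diff.

import ROOT

# Hypothetical inputs: two dataframes whose files and trees are placeholders.
df_sig = ROOT.RDataFrame("events", "signal.root")
df_bkg = ROOT.RDataFrame("events", "background.root")

# Eager mode: both dataframes are loaded into memory up front and sampled
# randomly across them, so chunk_size/block_size are not needed (a reading of
# the docstring above, not a tested claim).
gen_train, gen_validation = ROOT.TMVA.Experimental.CreateNumPyGenerators(
    rdataframes=[df_sig, df_bkg],
    batch_size=128,
    target="label",            # assumed keyword, hidden in the collapsed part of the diff
    validation_split=0.2,      # assumed keyword, hidden in the collapsed part of the diff
    load_eager=True,
    sampling_type="random",
)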
@@ -317,7 +334,7 @@
np.zeros((self.batch_size)).reshape(-1, 1),
)

def ConvertBatchToNumpy(self, batch: "RTensor") -> np.ndarray:

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:337:43: F821 Undefined name `RTensor`
"""Convert a RTensor into a NumPy array

Args:
@@ -368,8 +385,8 @@
Returns:
torch.Tensor: converted batch
"""
import torch
import numpy as np

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:388:9: I001 Import block is un-sorted or un-formatted

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())
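The rest of this method is collapsed in the diff. As a standalone illustration of the underlying pattern (a flat float buffer plus a (batch_size, num_columns) shape turned into a framework tensor), here is a sketch under assumed float32, row-major storage; it is not the file's actual code.

import numpy as np
import torch

def flat_to_torch(buf, batch_size: int, num_columns: int) -> torch.Tensor:
    # View the flat float32 buffer as a (batch_size, num_columns) matrix and
    # copy it so the tensor's lifetime is independent of the source buffer.
    arr = np.frombuffer(buf, dtype=np.float32, count=batch_size * num_columns)
    return torch.from_numpy(arr.reshape(batch_size, num_columns).copy())

# Hypothetical usage with a plain bytes buffer:
# flat_to_torch(np.arange(8, dtype=np.float32).tobytes(), batch_size=2, num_columns=4)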
@@ -477,7 +494,7 @@
def __init__(self, base_generator: BaseGenerator):
self.base_generator = base_generator
# create training batches from the first chunk
self.base_generator.CreateTrainBatches();

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:497:49: E703 Statement ends with an unnecessary semicolon

def __enter__(self):
self.base_generator.ActivateTrainingEpoch()
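Together with the matching __exit__ hook (not shown in this diff), this suggests a per-epoch context-manager pattern roughly like the following. Whether batches are iterated inside the with block, and what they yield, is an assumption based on earlier releases of this generator, not something this diff shows.

# Hypothetical training loop; train_generator comes from one of the
# Create*Generators helpers further down in this file.
for epoch in range(10):
    with train_generator:            # re-arms the training batches for this epoch
        for x_batch, y_batch in train_generator:
            ...                      # placeholder for the actual training step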
@@ -652,10 +669,10 @@
return None

def CreateNumPyGenerators(
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
rdataframes: ROOT.RDF.RNode | list[ROOT.RDF.RNode] = list(),
batch_size: int = 0,
chunk_size: int = 0,
block_size: int = 0,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
@@ -666,6 +683,8 @@
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
load_eager: bool = False,
sampling_type: str = "random",
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
"""
Return two batch generators based on the given ROOT file and tree or RDataFrame
@@ -678,6 +697,10 @@
chunk_size (int):
The size of the chunks loaded from the ROOT file. Higher chunk size
results in better randomization, but also higher memory usage.
block_size (int):
The size of the blocks of consecutive entries from the dataframe.
A chunk is built up from multiple blocks. Lower block size results in
better randomization, but also higher memory usage.
columns (list[str], optional):
Columns to be returned. If not given, all columns are used.
max_vec_sizes (list[int], optional):
@@ -706,6 +729,20 @@
[4, 5, 6, 7] will be returned.
If drop_remainder = False, then three batches [0, 1, 2, 3],
[4, 5, 6, 7] and [8, 9] will be returned.
set_seed (int):
For reproducibility: Set the seed for the random number generator used
to split the dataset into training and validation and to shuffle the chunks.
Defaults to 0, which means that the seed is set to the random device.
load_eager (bool):
Load the full dataframe(s) into memory (True) or
load chunks from the dataframe into memory (False).
Defaults to False.
sampling_type (str):
Describes the mode of sampling from the dataframe(s). Options: 'random'.
Defaults to 'random' and requires load_eager = True.

Returns:
TrainRBatchGenerator or
@@ -718,10 +755,10 @@
validation generator will return no batches.
"""

import numpy as np

GitHub Actions / ruff: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py:758:21: F401 `numpy` imported but unused

base_generator = BaseGenerator(
rdataframe,
rdataframes,
batch_size,
chunk_size,
block_size,
@@ -734,7 +771,9 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
set_seed,
load_eager,
sampling_type,
)

train_generator = TrainRBatchGenerator(
@@ -752,10 +791,10 @@


def CreateTFDatasets(
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
rdataframes: ROOT.RDF.RNode | list[ROOT.RDF.RNode] = list(),
batch_size: int = 0,
chunk_size: int = 0,
block_size: int = 0,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
@@ -765,7 +804,9 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
set_seed: int = 0,
load_eager: bool = False,
sampling_type: str = "random",
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
"""
Return two Tensorflow Datasets based on the given ROOT file and tree or RDataFrame
@@ -778,6 +819,10 @@
chunk_size (int):
The size of the chunks loaded from the ROOT file. Higher chunk size
results in better randomization, but also higher memory usage.
block_size (int):
The size of the blocks of consecutive entries from the dataframe.
A chunk is built up from multiple blocks. Lower block size results in
better randomization, but also higher memory usage.
columns (list[str], optional):
Columns to be returned. If not given, all columns are used.
max_vec_sizes (list[int], optional):
@@ -806,6 +851,17 @@
[4, 5, 6, 7] will be returned.
If drop_remainder = False, then three batches [0, 1, 2, 3],
[4, 5, 6, 7] and [8, 9] will be returned.
set_seed (int):
For reproducibility: Set the seed for the random number generator used
to split the dataset into training and validation and to shuffle the chunks.
Defaults to 0, which means that the seed is set to the random device.
load_eager (bool):
Load the full dataframe(s) into memory (True) or
load chunks from the dataframe into memory (False).
Defaults to False.
sampling_type (str):
Describes the mode of sampling from the dataframe(s). Options: 'random'.
Defaults to 'random' and requires load_eager = True.

Returns:
TrainRBatchGenerator or
@@ -820,7 +876,7 @@
import tensorflow as tf

base_generator = BaseGenerator(
rdataframe,
rdataframes,
batch_size,
chunk_size,
block_size,
@@ -833,7 +889,9 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
set_seed,
load_eager,
sampling_type,
)

train_generator = TrainRBatchGenerator(
@@ -901,10 +959,10 @@


def CreatePyTorchGenerators(
rdataframe: ROOT.RDF.RNode,
batch_size: int,
chunk_size: int,
block_size: int,
rdataframes: ROOT.RDF.RNode | list[ROOT.RDF.RNode] = list(),
batch_size: int = 0,
chunk_size: int = 0,
block_size: int = 0,
columns: list[str] = list(),
max_vec_sizes: dict[str, int] = dict(),
vec_padding: int = 0,
@@ -914,7 +972,9 @@
max_chunks: int = 0,
shuffle: bool = True,
drop_remainder=True,
set_seed: int = 0,
set_seed: int = 0,
load_eager: bool = False,
sampling_type: str = "random",
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
"""
Return two PyTorch batch generators based on the given ROOT file and tree or RDataFrame
@@ -927,6 +987,10 @@
chunk_size (int):
The size of the chunks loaded from the ROOT file. Higher chunk size
results in better randomization, but also higher memory usage.
block_size (int):
The size of the blocks of consecutive entries from the dataframe.
A chunk is built up from multiple blocks. Lower block size results in
better randomization, but also higher memory usage.
columns (list[str], optional):
Columns to be returned. If not given, all columns are used.
max_vec_sizes (list[int], optional):
@@ -955,6 +1019,17 @@
[4, 5, 6, 7] will be returned.
If drop_remainder = False, then three batches [0, 1, 2, 3],
[4, 5, 6, 7] and [8, 9] will be returned.
set_seed (int):
For reproducibility: Set the seed for the random number generator used
to split the dataset into training and validation and to shuffle the chunks.
Defaults to 0, which means that the seed is set to the random device.
load_eager (bool):
Load the full dataframe(s) into memory (True) or
load chunks from the dataframe into memory (False).
Defaults to False.
sampling_type (str):
Describes the mode of sampling from the dataframe(s). Options: 'random'.
Defaults to 'random' and requires load_eager = True.

Returns:
TrainRBatchGenerator or
Expand All @@ -967,7 +1042,7 @@
validation generator will return no batches.
"""
base_generator = BaseGenerator(
rdataframe,
rdataframes,
batch_size,
chunk_size,
block_size,
@@ -980,7 +1055,9 @@
max_chunks,
shuffle,
drop_remainder,
set_seed,
set_seed,
load_eager,
sampling_type,
)

train_generator = TrainRBatchGenerator(
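The diff is truncated at this point. To round off, here is a chunked-mode counterpart to the eager example earlier, using the PyTorch entry point. As before, file, tree, and column names are hypothetical, target and validation_split are assumed keywords hidden in the collapsed part of the diff, and the size values are arbitrary illustrations of the documented relationships (batch_size no larger than chunk_size, chunks assembled from blocks).

import ROOT

# Hypothetical input dataframe.
df = ROOT.RDataFrame("events", "train.root")

# Chunked mode (load_eager=False): entries are streamed in chunks of 10000,
# each chunk assembled from blocks of 1000 consecutive entries, and served
# in batches of 256.
gen_train, gen_validation = ROOT.TMVA.Experimental.CreatePyTorchGenerators(
    df,
    batch_size=256,
    chunk_size=10000,
    block_size=1000,
    target="label",            # assumed keyword
    validation_split=0.2,      # assumed keyword
    shuffle=True,
    drop_remainder=True,
)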