Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,6 @@ def __init__(
given value is {validation_split}"
)

if load_eager:
# TODO: overhead, check if we can improve the following lines
if sampling_type == "undersampling" and not replacement:
rdf_0 = rdataframes[0].Count().GetValue()
rdf_1 = rdataframes[1].Count().GetValue()
rdf_minor = min(rdf_0, rdf_1)
rdf_major = max(rdf_0, rdf_1)
if rdf_major < rdf_minor / sampling_ratio:
raise ValueError(
f"The sampling_ratio is too low: not enough entries in the majority class to sample from. \n \
Choose sampling_ratio > {round(rdf_minor / rdf_major, 3)} or set replacement to False."
)

if not hasattr(rdataframes, "__iter__"):
rdataframes = [rdataframes]
self.noded_rdfs = [RDF.AsRNode(rdf) for rdf in rdataframes]
Expand Down
29 changes: 29 additions & 0 deletions bindings/pyroot/pythonizations/test/ml_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4683,8 +4683,37 @@ def test(size_of_batch, num_of_entries_major, num_of_entries_minor, sampling_rat
self.teardown_file(file_name2)
raise

def test_raises(size_of_batch, num_of_entries_major, num_of_entries_minor, sampling_ratio):
    """Assert that generator creation fails when the sampling ratio is too low.

    Prepares the majority and minority inputs, opens both as RDataFrames and
    requests eager-loading NumPy generators with undersampling and no
    replacement, expecting an exception whose message matches the
    "sampling_ratio is too low" text.
    """
    # NOTE(review): presumably these helpers write the trees to file_name1 /
    # file_name2 -- they are defined outside this block, confirm there.
    define_rdf_major(num_of_entries_major, file_name1)
    define_rdf_minor(num_of_entries_minor, file_name2)

    rdf_major = ROOT.RDataFrame(tree_name, file_name1)
    rdf_minor = ROOT.RDataFrame(tree_name, file_name2)

    # Pattern the raised exception's message has to match.
    expected_message = r"The sampling_ratio is too low: not enough entries in the majority class to sample from."

    # Keyword options forwarded unchanged to CreateNumPyGenerators.
    generator_options = dict(
        batch_size=size_of_batch,
        target=["b3", "b5"],
        weights="b1",
        validation_split=0.3,
        shuffle=False,
        drop_remainder=False,
        load_eager=True,
        sampling_type="undersampling",
        sampling_ratio=sampling_ratio,
        replacement=False,
    )

    with self.assertRaisesRegex(Exception, expected_message):
        ROOT.Experimental.ML.CreateNumPyGenerators([rdf_major, rdf_minor], **generator_options)

# test the functionality with a proper sampling ratio
test(batch_size, entries_in_rdf_major, entries_in_rdf_minor, sampling_ratio)

bad_sampling_ratio = round(max(min_allowed_sampling_ratio - 0.01, 0.01), 2)
# test that an error is raised when the sampling ratio is too low
test_raises(batch_size, entries_in_rdf_major, entries_in_rdf_minor, bad_sampling_ratio)

def test14_big_data_replacement_true(self):
file_name1 = "big_data_major.root"
file_name2 = "big_data_minor.root"
Expand Down
8 changes: 8 additions & 0 deletions tree/ml/inc/ROOT/ML/RSampler.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ public:
fNumMajor = fDatasets[fMajor].GetRows();
fNumMinor = fDatasets[fMinor].GetRows();
fNumResampledMajor = static_cast<std::size_t>(fNumMinor / fSampleRatio);
if (!fReplacement && fNumResampledMajor > fNumMajor) {
auto minRatio = std::to_string(std::round(double(fNumMinor) / double(fNumMajor) * 100.0) / 100.0);
minRatio.erase(minRatio.find('.') + 3);
throw std::invalid_argument(
"The sampling_ratio is too low: not enough entries in the majority class to sample from.\n"
"Choose sampling_ratio > " +
minRatio + " or set replacement to True.");
}
fNumEntries = fNumMinor + fNumResampledMajor;
}

Expand Down
Loading