diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py index 3e0030d243d28..9b0d778a360f1 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py @@ -175,19 +175,6 @@ def __init__( given value is {validation_split}" ) - if load_eager: - # TODO: overhead, check if we can improve the following lines - if sampling_type == "undersampling" and not replacement: - rdf_0 = rdataframes[0].Count().GetValue() - rdf_1 = rdataframes[1].Count().GetValue() - rdf_minor = min(rdf_0, rdf_1) - rdf_major = max(rdf_0, rdf_1) - if rdf_major < rdf_minor / sampling_ratio: - raise ValueError( - f"The sampling_ratio is too low: not enough entries in the majority class to sample from. \n \ - Choose sampling_ratio > {round(rdf_minor / rdf_major, 3)} or set replacement to False." - ) - if not hasattr(rdataframes, "__iter__"): rdataframes = [rdataframes] self.noded_rdfs = [RDF.AsRNode(rdf) for rdf in rdataframes] diff --git a/bindings/pyroot/pythonizations/test/ml_dataloader.py b/bindings/pyroot/pythonizations/test/ml_dataloader.py index eb6326dc2f19e..5445114f08125 100644 --- a/bindings/pyroot/pythonizations/test/ml_dataloader.py +++ b/bindings/pyroot/pythonizations/test/ml_dataloader.py @@ -4683,8 +4683,37 @@ def test(size_of_batch, num_of_entries_major, num_of_entries_minor, sampling_rat self.teardown_file(file_name2) raise + def test_raises(size_of_batch, num_of_entries_major, num_of_entries_minor, sampling_ratio): + define_rdf_major(num_of_entries_major, file_name1) + define_rdf_minor(num_of_entries_minor, file_name2) + + df1 = ROOT.RDataFrame(tree_name, file_name1) + df2 = ROOT.RDataFrame(tree_name, file_name2) + + with self.assertRaisesRegex( + Exception, r"The sampling_ratio is too low: not enough entries in the majority class to sample from." + ): + ROOT.Experimental.ML.CreateNumPyGenerators( + [df1, df2], + batch_size=size_of_batch, + target=["b3", "b5"], + weights="b1", + validation_split=0.3, + shuffle=False, + drop_remainder=False, + load_eager=True, + sampling_type="undersampling", + sampling_ratio=sampling_ratio, + replacement=False, + ) + + # test the functionality with a proper sampling ratio test(batch_size, entries_in_rdf_major, entries_in_rdf_minor, sampling_ratio) + bad_sampling_ratio = round(max(min_allowed_sampling_ratio - 0.01, 0.01), 2) + # test that an error is raised when the sampling ratio is too low + test_raises(batch_size, entries_in_rdf_major, entries_in_rdf_minor, bad_sampling_ratio) + def test14_big_data_replacement_true(self): file_name1 = "big_data_major.root" file_name2 = "big_data_minor.root" diff --git a/tree/ml/inc/ROOT/ML/RSampler.hxx b/tree/ml/inc/ROOT/ML/RSampler.hxx index 0dbbe812c6278..2c5f0e52723a4 100644 --- a/tree/ml/inc/ROOT/ML/RSampler.hxx +++ b/tree/ml/inc/ROOT/ML/RSampler.hxx @@ -103,6 +103,14 @@ public: fNumMajor = fDatasets[fMajor].GetRows(); fNumMinor = fDatasets[fMinor].GetRows(); fNumResampledMajor = static_cast(fNumMinor / fSampleRatio); + if (!fReplacement && fNumResampledMajor > fNumMajor) { + auto minRatio = std::to_string(std::round(double(fNumMinor) / double(fNumMajor) * 100.0) / 100.0); + minRatio.erase(minRatio.find('.') + 3); + throw std::invalid_argument( + "The sampling_ratio is too low: not enough entries in the majority class to sample from.\n" + "Choose sampling_ratio > " + + minRatio + " or set replacement to True."); + } fNumEntries = fNumMinor + fNumResampledMajor; }