From 33deed694792f6ee9f3ba975771dccf2da9a2d27 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Mon, 13 Apr 2026 16:19:32 -0400 Subject: [PATCH 1/2] add ValueError subtype for excision --- chirho/observational/ops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chirho/observational/ops.py b/chirho/observational/ops.py index 4ac82b899..d8bddd1cf 100644 --- a/chirho/observational/ops.py +++ b/chirho/observational/ops.py @@ -27,6 +27,10 @@ def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T: raise NotImplementedError(f"observe not implemented for type {type(rv)}") +class ExcisionError(ValueError): + pass + + class ExcisedNormal(TorchDistribution): """ A normal distribution with specified intervals excised (removed). @@ -140,7 +144,7 @@ def __init__( self._removed_pr_mass += interval_mass if torch.any(self._removed_pr_mass >= 1.0): - raise ValueError("Total probability mass in excised intervals >= 1.0!") + raise ExcisionError("Total probability mass in excised intervals >= 1.0!") self._normalization_constant = torch.ones_like(self._base_loc) - self._removed_pr_mass @@ -307,7 +311,7 @@ def __init__( ratio_all_neg_inf = num_all_neg_inf / all_neg_inf.numel() # <--- define ratio if num_all_neg_inf > 0: - raise ValueError( + raise ExcisionError( f"{num_all_neg_inf} batch elements ({ratio_all_neg_inf:.2%}) " "have all logits excised (-inf); cannot sample from these elements." ) From 58a9fa11ba1946b2da263236eb9755fb69ca60b7 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Mon, 13 Apr 2026 16:20:36 -0400 Subject: [PATCH 2/2] make test more precise --- tests/observational/test_excised.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/observational/test_excised.py b/tests/observational/test_excised.py index bee9e2eef..c9ddaba6a 100644 --- a/tests/observational/test_excised.py +++ b/tests/observational/test_excised.py @@ -1,7 +1,7 @@ import pytest import torch -from chirho.observational.ops import ExcisedCategorical, ExcisedNormal +from chirho.observational.ops import ExcisedCategorical, ExcisedNormal, ExcisionError # needed for testing interval CDFs @@ -260,6 +260,6 @@ def test_excised_categorical_all_excised_error_category_indices(): ) ] - # Expect ValueError because second batch element is fully excised - with pytest.raises(ValueError, match="have all logits excised"): + # Expect error because second batch element is fully excised + with pytest.raises(ExcisionError): ExcisedCategorical(intervals=intervals, logits=logits)