Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions chirho/observational/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T:
raise NotImplementedError(f"observe not implemented for type {type(rv)}")


class ExcisionError(ValueError):
pass


class ExcisedNormal(TorchDistribution):
"""
A normal distribution with specified intervals excised (removed).
Expand Down Expand Up @@ -140,7 +144,7 @@ def __init__(
self._removed_pr_mass += interval_mass

if torch.any(self._removed_pr_mass >= 1.0):
raise ValueError("Total probability mass in excised intervals >= 1.0!")
raise ExcisionError("Total probability mass in excised intervals >= 1.0!")

self._normalization_constant = torch.ones_like(self._base_loc) - self._removed_pr_mass

Expand Down Expand Up @@ -307,7 +311,7 @@ def __init__(
ratio_all_neg_inf = num_all_neg_inf / all_neg_inf.numel() # <--- define ratio

if num_all_neg_inf > 0:
raise ValueError(
raise ExcisionError(
f"{num_all_neg_inf} batch elements ({ratio_all_neg_inf:.2%}) "
"have all logits excised (-inf); cannot sample from these elements."
)
Expand Down
6 changes: 3 additions & 3 deletions tests/observational/test_excised.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from chirho.observational.ops import ExcisedCategorical, ExcisedNormal
from chirho.observational.ops import ExcisedCategorical, ExcisedNormal, ExcisionError


# needed for testing interval CDFs
Expand Down Expand Up @@ -260,6 +260,6 @@ def test_excised_categorical_all_excised_error_category_indices():
)
]

# Expect ValueError because second batch element is fully excised
with pytest.raises(ValueError, match="have all logits excised"):
# Expect error because second batch element is fully excised
with pytest.raises(ExcisionError):
ExcisedCategorical(intervals=intervals, logits=logits)
Loading