[ENH] add truncated_mean interface to BaseDistribution #995
patelchaitany wants to merge 1 commit into
JiwaniZakir left a comment
The sampling-based fallback in _truncated_mean calls self.sample(1) inside a loop of approx_spl_size (default 1000) iterations, which is orders of magnitude slower than calling self.sample(approx_spl_size) once and filtering. The existing _mean fallback at the end of the file does exactly this with _sample_mean, so the pattern is already established — the loop approach here is inconsistent with the rest of the codebase and will be a significant performance bottleneck.
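For illustration, the two patterns can be sketched side by side. This is a standalone sketch, not the PR's code: NumPy's `Generator.standard_normal` stands in for `self.sample`, and both function names are made up for the example.

```python
import numpy as np


def truncated_mean_loop(rng, lower, upper, n=1000):
    """Slow pattern: one sampler call per Python-level iteration."""
    kept = []
    for _ in range(n):
        x = rng.standard_normal(1)[0]  # stand-in for self.sample(1)
        if lower < x < upper:
            kept.append(x)
    return float(np.mean(kept)) if kept else float("nan")


def truncated_mean_vectorized(rng, lower, upper, n=1000):
    """Fast pattern: one sampler call, then filter with a boolean mask."""
    x = rng.standard_normal(n)  # stand-in for self.sample(n)
    kept = x[(x > lower) & (x < upper)]
    return float(kept.mean()) if kept.size else float("nan")


loop_est = truncated_mean_loop(np.random.default_rng(0), 0.0, 1.0, n=20_000)
vec_est = truncated_mean_vectorized(np.random.default_rng(0), 0.0, 1.0, n=20_000)
```

Both functions estimate the same quantity; the second avoids the per-iteration call overhead, which is the point of the comment above.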
In _energy_x, the added docstring block (around line 1235) describes the truncated-mean identity used in the new code path, but it's inserted into the existing _energy_x docstring rather than being clearly separated from the original parameter/return docs. This makes the docstring structurally confusing — the formula section appears before the "Parameters" block, which is fine, but it references _truncated_mean and cdf without explaining that this only triggers when both are exact capabilities, which is important context for subclass implementors.
The zero_mass guard in the ppf-based approximation (np.abs(cdf_u - cdf_l) < 1e-15) correctly handles degenerate intervals, but the same guard is absent in the _energy_x shortcut path — if a point x lies exactly at a probability-0 region boundary, right_mean or left_mean could return nan and the nan_to_num(..., nan=0.0) silently masks what might be a meaningful divergence rather than a true zero-contribution interval.
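A minimal sketch of the distinction, under the assumption that each side of the shortcut contributes a term of the form mass times distance. The helper names `term_masked` and `term_guarded` are hypothetical and do not appear in the PR.

```python
import math


def term_masked(mass, cond_mean, c):
    """Blanket approach: any nan becomes 0, whatever its cause."""
    value = mass * abs(cond_mean - c)
    return 0.0 if math.isnan(value) else value


def term_guarded(mass, cond_mean, c, tol=1e-15):
    """Explicit approach: only a (near-)zero-mass side may contribute 0."""
    if abs(mass) < tol:
        return 0.0  # empty interval: nothing to contribute
    if math.isnan(cond_mean):
        # positive mass but undefined conditional mean: surface it
        raise ValueError("nan conditional mean on an interval with positive mass")
    return mass * abs(cond_mean - c)
```

With a zero-mass side both versions return 0. With positive mass and a nan conditional mean, the masked version still returns 0, while the guarded version raises.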
fkiraly left a comment
Very nice!
Can you please add a test to TestAllDistributions that checks that the method works? I think the method just has to be added to the right places.
    if spl_arr.ndim == 1:
        spl_arr = spl_arr.reshape(-1, 1, 1)
    elif spl_arr.ndim == 2:
        spl_arr = spl_arr.reshape(spl_arr.shape[0], 1, -1)
why is this correct? or is it a mistake?
Yeah, those reshape branches are dead code on my end. sample(1) always returns a DataFrame, so .values gives 2D and stacking gives 3D, which means neither the ndim == 1 nor the ndim == 2 branch ever fires. I added them as defensive guards, but they're unnecessary. Happy to remove them.
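The shape argument can be checked in isolation. In this sketch plain 2D arrays stand in for `sample(1).values`, and the 2x3 shape is arbitrary.

```python
import numpy as np

# Five stand-ins for sample(1).values, each a 2D array (2 rows, 3 columns).
draws = [np.zeros((2, 3)) for _ in range(5)]

# Stacking 2D arrays along a new leading axis always yields a 3D array.
spl_arr = np.stack(draws, axis=0)
stacked_shape = spl_arr.shape  # (number of draws, rows, columns)
```

Since `spl_arr.ndim` is 3 whenever every draw is 2D, the `ndim == 1` and `ndim == 2` branches are unreachable under that assumption.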
dde2b4d to da87e95
Signed-off-by: Chaitany patel <patelchaitany93@gmail.com>
da87e95 to 4c21cf3
@fkiraly, I have added truncated_mean to METHODS_SCALAR, plus two new tests: one checks that unbounded truncated_mean() matches mean(), the other verifies the output format with actual bounds.
fkiraly left a comment
Reviewed in more detail - this is very very nice!
I think we could in principle merge, but I would like to discuss two things:
- have you considered a design where, instead of adding a new public API point truncated_mean, you instead add two arguments to mean? Where, of course, internally, it could still link to _mean and _truncated_mean (or _mean_truncated), similar to how energy dispatches to two different private methods. I am not saying this is how we should do it, but I would like to see your thoughts on a discussion of pros/cons.
- can you quickly give me a summary where your analytical formulae are coming from for the various distributions? I agree with the quadrature in the base class, but am asking for the various distributions. Did you ask AI; Wolfram; or did you do this on paper? In either case, it may be worth documenting the exact formula somewhere in an md file, so we can add it to the docstring once the docstring injection PR is resolved.
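To make the first design option concrete, here is a rough sketch of a `mean` that takes two optional bounds and dispatches to two private methods. `ToyExponential` is a hypothetical standalone class, not skpro's `BaseDistribution`, and the argument names `lower` and `upper` are assumed to match the PR. The closed form is the standard result for an exponential with rate r, obtained by integrating x * r * exp(-r x) over the interval and dividing by the interval's mass.

```python
import math


class ToyExponential:
    """Hypothetical toy distribution used only to illustrate the dispatch."""

    def __init__(self, rate=1.0):
        self.rate = rate

    def mean(self, lower=None, upper=None):
        """Plain mean if no bound is given, truncated mean otherwise."""
        if lower is None and upper is None:
            return self._mean()
        lo = 0.0 if lower is None else max(lower, 0.0)  # support starts at 0
        hi = math.inf if upper is None else upper
        return self._truncated_mean(lo, hi)

    def _mean(self):
        return 1.0 / self.rate

    def _truncated_mean(self, lower, upper):
        # 1/r + (l*S(l) - u*S(u)) / (S(l) - S(u)), with S(x) = exp(-r*x)
        r = self.rate
        s_lo = math.exp(-r * lower)
        s_hi = 0.0 if math.isinf(upper) else math.exp(-r * upper)
        x_hi = 0.0 if math.isinf(upper) else upper * s_hi
        return 1.0 / r + (lower * s_lo - x_hi) / (s_lo - s_hi)
```

For example, `ToyExponential(2.0).mean(lower=1.0)` gives the lower bound plus `1 / rate`, as memorylessness requires. The sketch does not validate `upper > lower`; a real implementation would need to.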
Reference Issues/PRs
Addresses #221. Closes #221.
What does this implement/fix? Explain your changes.
Adds truncated_mean(lower, upper) to BaseDistribution. Returns E[X | lower < X < upper]. Closed-form implementations for Normal, Exponential, Laplace, and Uniform. Other distributions fall back to a ppf-based numerical approximation.
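As a rough illustration of the two routes for a Normal, the sketch below compares the textbook closed form with a ppf-based midpoint rule. It uses the standard library's `statistics.NormalDist` rather than skpro. The function names are hypothetical, and the PR's actual formulas and quadrature scheme may differ.

```python
import math
from statistics import NormalDist


def truncated_mean_closed_form(dist, lower, upper):
    """Textbook closed form: mu + sigma^2 * (f(lower) - f(upper)) / mass."""
    mass = dist.cdf(upper) - dist.cdf(lower)
    return dist.mean + dist.variance * (dist.pdf(lower) - dist.pdf(upper)) / mass


def truncated_mean_ppf(dist, lower, upper, n=10_000):
    """Midpoint-rule average of the ppf over (F(lower), F(upper))."""
    p_lo, p_hi = dist.cdf(lower), dist.cdf(upper)
    step = (p_hi - p_lo) / n
    return sum(dist.inv_cdf(p_lo + (i + 0.5) * step) for i in range(n)) / n


d = NormalDist(mu=1.0, sigma=2.0)
exact = truncated_mean_closed_form(d, -1.0, 2.0)
approx = truncated_mean_ppf(d, -1.0, 2.0)
# With infinite bounds the closed form reduces to the ordinary mean.
unbounded = truncated_mean_closed_form(d, -math.inf, math.inf)
```

Neither function guards against a zero-mass interval, where the closed form divides by zero.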
The base class _energy_x now uses truncated means when available: if a distribution has exact truncated_mean and cdf, it computes E[|X - c|] from the energy identity instead of Monte Carlo sampling.
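One common way to write such an identity for a continuous X is E|X - c| = F(c) * (c - E[X | X < c]) + (1 - F(c)) * (E[X | X > c] - c). Whether the PR uses exactly this form is an assumption. The sketch below applies it to Uniform(0, 1), with hypothetical helper names, and skips a side explicitly when it carries no mass.

```python
import math


def abs_moment_from_truncated_means(c, cdf, truncated_mean):
    """E|X - c| from the cdf and the two one-sided truncated means."""
    p_left = cdf(c)
    p_right = 1.0 - p_left
    total = 0.0
    if p_left > 0.0:  # skip an empty left side instead of multiplying by nan
        total += p_left * (c - truncated_mean(-math.inf, c))
    if p_right > 0.0:  # same for an empty right side
        total += p_right * (truncated_mean(c, math.inf) - c)
    return total


def uniform_cdf(x):
    """cdf of Uniform(0, 1)."""
    return min(max(x, 0.0), 1.0)


def uniform_truncated_mean(lower, upper):
    """Midpoint of the interval clipped to [0, 1]; assumes positive mass."""
    return 0.5 * (max(lower, 0.0) + min(upper, 1.0))


value = abs_moment_from_truncated_means(0.3, uniform_cdf, uniform_truncated_mean)
```

For c inside [0, 1] this can be checked against the direct integral, which equals (c^2 + (1 - c)^2) / 2.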
TruncatedDistribution gets a _mean() that calls the inner distribution's truncated_mean, so TruncatedDistribution(Normal(...), lower=0).mean() is exact now.
Changes
- base/_base.py - new truncated_mean / _truncated_mean with ppf/MC default, updated _energy_x
- normal.py - exact _truncated_mean
- exponential.py - exact _truncated_mean
- laplace.py - exact _truncated_mean
- uniform.py - exact _truncated_mean
- truncated.py - _mean() via inner distribution's truncated_mean, updated tag logic
Does your contribution introduce a new dependency? If yes, which one?
No
What should a reviewer concentrate their feedback on?
Did you add any tests for the change?
No
Any other comments?
PR checklist
For all contributions
How to: add yourself to the all-contributors file in the skpro root directory (not the CONTRIBUTORS.md). Common badges:
- code - fixing a bug, or adding code logic.
- doc - writing or improving documentation or docstrings.
- bug - reporting or diagnosing a bug (get this plus code if you also fixed the bug in the PR).
- maintenance - CI, test framework, release.
See here for full badge reference
For new estimators
- docs/source/api_reference/taskname.rst, follow the pattern.
- Examples section.
- python_dependencies tag and ensured dependency isolation, see the estimator dependencies guide.