[BUG] Fix probability mass loss and temporal alignment in _SksurvAdapter predict_proba function#990

Open
MurtuzaShaikh26 wants to merge 2 commits into
sktime:mainfrom
MurtuzaShaikh26:fix-sksurv-mass-loss

Conversation

@MurtuzaShaikh26

Reference Issues/PRs

Fixes #958

What does this implement/fix? Explain your changes.

The predict_proba method in _SksurvAdapter (used by estimators such as CoxPHSkSurv) computed the distribution weights with a plain np.diff over the survival function, ignoring both boundaries. As a result, probability mass was lost and the remaining mass was shifted to earlier times:

  1. Initial Mass Loss: The drop from 1.0 to the survival probability at the first timepoint was discarded.
  2. Tail Mass Loss: The survival probability remaining after the last timepoint was discarded instead of being assigned to the last step, so the Empirical distribution weights summed to far less than 1.0.
  3. Temporal Alignment: The last recorded survival time was dropped via [:-1], so each probability mass was attached to the timepoint before the one at which the drop actually occurred.
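
The three failure modes can be reproduced with plain NumPy on the survival values from the verification script. This is an illustrative sketch only: it assumes the previous logic was equivalent to a negated np.diff paired with unique_times_[:-1], and it does not call skpro.

```python
import numpy as np

# Survival probabilities for one sample at three observed event times.
surv = np.array([0.8, 0.5, 0.5])
times = np.array([10.0, 20.0, 30.0])

# Assumed shape of the old logic: interior differences only, last time dropped.
naive_weights = -np.diff(surv)  # only the 0.8 -> 0.5 and 0.5 -> 0.5 steps
naive_times = times[:-1]        # [10.0, 20.0]

# The initial drop (1.0 -> 0.8) and the tail mass (0.5) are both missing,
# and the 0.3 drop that happens at t=20 is paired with t=10.
print("weights:", naive_weights)
print("times:", naive_times)
print("total mass:", naive_weights.sum())
```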

Proposed Solution:
This PR replaces np.diff with _clip_surv from _common.py, which uses _surv_diff (applying prepend=1.0, append=0.0) to handle both boundaries. This conserves all probability mass, keeps the survival values monotone, and makes the weights sum to 1.0. The full, unmodified unique_times_ is also kept, so each weight is matched to the timestamp at which its drop occurs.
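
A boundary-aware weight computation can be sketched as follows. The helper name weights_from_survival is hypothetical and this is not the implementation of _clip_surv or _surv_diff; it only illustrates the behaviour described above, assuming the tail mass is assigned to the last timepoint.

```python
import numpy as np


def weights_from_survival(surv):
    """Hypothetical helper: turn survival values into point masses.

    surv has shape (..., n_times) and is assumed to be non-increasing
    along the last axis, with values in [0, 1].
    """
    surv = np.asarray(surv, dtype=float)
    # Drop into each timepoint, starting from S(0) = 1.0.
    weights = -np.diff(surv, prepend=1.0, axis=-1)
    # Assign the survival mass remaining after the last timepoint
    # to the last timepoint, so that nothing is lost.
    weights[..., -1] += surv[..., -1]
    return weights


times = np.array([10.0, 20.0, 30.0])
weights = weights_from_survival([0.8, 0.5, 0.5])

# One weight per timepoint, so the full time grid can be kept as-is.
print(dict(zip(times.tolist(), weights.round(6).tolist())))
print("total mass:", weights.sum())
```

With this construction there is one weight per entry of the time grid, which is why no [:-1] slicing of the timestamps is needed.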

Verification Script

For an estimator with survival probabilities [0.8, 0.5, 0.5] at timestamps [10.0, 20.0, 30.0]:
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
from skpro.survival.adapters.sksurv import _SksurvAdapter


class MockSksurvAdapter(_SksurvAdapter):
    """Minimal adapter subclass that avoids needing a real sksurv estimator."""

    def _get_sksurv_class(self):
        return MagicMock()

    def get_params(self, deep=True):
        return {}


X = pd.DataFrame({"feature1": [1.0]})

# Mock a fitted estimator: one sample, survival probabilities at three times.
adapter = MockSksurvAdapter()
adapter._estimator = MagicMock()
adapter._estimator.predict_survival_function = MagicMock(
    return_value=np.array([[0.8, 0.5, 0.5]])
)
adapter._estimator.unique_times_ = np.array([10.0, 20.0, 30.0])
adapter._y_cols = ["time"]

dist = adapter._predict_proba(X)

print("\n--- RESULTS ---")
print("Times in resulting Empirical distribution:", dist.spl.values.flatten())
print("Weights in resulting Empirical distribution:", dist.weights.values)
print(f"Total mass: {dist.weights.sum()}")

Does your contribution introduce a new dependency? If yes, which one?

No.

What should a reviewer concentrate their feedback on?

  • The replacement of the np.diff/[:-1] logic with _clip_surv to preserve the probability mass at both boundaries.
  • The new adapter test test_sksurv_adapter.py, which checks the resulting dist.weights and the timestamp assignments, and asserts that the total mass equals 1.0.

Did you add any tests for the change?

Any other comments?

N/A

PR checklist

For all contributions
  • I've added myself to the list of contributors with any new badges I've earned :-)
  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG].
For new estimators
  • I've added the estimator to the API reference
  • I've added one or more illustrative usage examples to the docstring
  • If the estimator relies on a soft dependency, I've set the python_dependencies tag

@fkiraly added the bug and module:survival&time-to-event labels Mar 26, 2026
Collaborator

@fkiraly fkiraly left a comment

Thanks!

Can you comment on why you did not use the _surv_diff function, as discussed in the issue?

@direkkakkar319-ops

This PR looks similar to PR #970.

Development

Successfully merging this pull request may close these issues.

[BUG] _SksurvAdapter fails to preserve total probability mass and shifts distributions