diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index cdf2d43..4fe70a5 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-26, ubuntu-latest] + os: [macos-26] python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index ae0fec9..19becbd 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -25,7 +25,8 @@ class SEQopts: :param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]`` :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, modelling, and survival steps - :param glm_package: Backend for fitting logistic (outcome/competing-event) models ["statsmodels", "glum", or "jax"], default "statsmodels". + :param glm_package: Backend for fitting logistic (outcome/competing-event) + models ["statsmodels", "glum", or "jax"], default "statsmodels". :param followup_class: Boolean to force followup values to be treated as classes :param followup_include: Boolean to force regular followup values into model covariates :param followup_spline: Boolean to force followup values to be fit to cubic spline diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index cecf28c..bcce6a7 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -8,7 +8,7 @@ import polars as pl from statsmodels.base.wrapper import ResultsWrapper -from .helpers import _build_md, _build_pdf +from .helpers import Offloader, _build_md, _build_pdf from .SEQopts import SEQopts @@ -41,7 +41,13 @@ class SEQoutput: :type risk_difference: pl.DataFrame or None :param time: Timings for every step of the process completed thus far :type time: dict or None - :param diagnostic_tables: Diagnostic tables for unique and nonunique outcome events and treatment switches + :param diagnostic_tables: Diagnostic tables (outcome, follow-up, switch, and + competing-event counts where applicable), each split by baseline treatment + arm. The "unique" tables count distinct subjects; the "nonunique" tables + count rows: total outcome events for the outcome tables, and total + person-time intervals (expanded follow-up rows) for the follow-up tables. + For a one-time (terminal) outcome the unique and nonunique outcome counts + coincide, since each subject contributes at most one event row. :type diagnostic_tables: dict or None """ @@ -72,7 +78,10 @@ def plot(self) -> None: plt.show() def summary( - self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] + self, + type: Optional[ + Literal["numerator", "denominator", "outcome", "compevent"] + ] = None, ) -> List: """ Returns a list of model summaries of either the numerator, denominator, outcome, or competing event models @@ -90,11 +99,26 @@ def summary( case _: models = self.outcome_models - return [model.summary() for model in models if model is not None] + if models is None: + return [] + + # Under offload=True the stored entries are path refs; load them back. + loader = None + if self.options is not None and self.options.offload: + loader = Offloader(enabled=True, dir=self.options.offload_dir) + + summaries = [] + for model in models: + if model is None: + continue + if loader is not None: + model = loader.load_model(model) + summaries.append(model.summary()) + return summaries def retrieve_data( self, - type=Optional[ + type: Optional[ Literal[ "km_data", "hazard", @@ -109,11 +133,24 @@ def retrieve_data( "unique_switches", "nonunique_switches", ] - ], + ] = None, ) -> pl.DataFrame: """ Getter for data stored within ``SEQoutput`` + The diagnostic tables come in "unique" and "nonunique" variants that count + different things, each broken down by baseline treatment arm: + + - ``unique_outcomes`` / ``nonunique_outcomes``: distinct subjects who had + the outcome vs. the total number of outcome events. These coincide for a + one-time (terminal) outcome, since each subject contributes at most one + event row. + - ``unique_followup`` / ``nonunique_followup``: distinct subjects + contributing follow-up vs. the total number of person-time intervals + (expanded rows). The nonunique count is much larger because each subject + contributes one row per follow-up period; it is the denominator that, + with ``nonunique_outcomes``, gives the per-arm event rate. + :param type: Data which you would like to access, ['km_data', 'hazard', 'risk_ratio', 'risk_difference', 'unique_outcomes', 'nonunique_outcomes', 'unique_followup', 'nonunique_followup', @@ -141,19 +178,13 @@ def retrieve_data( case "nonunique_compevent": data = self.diagnostic_tables.get("nonunique_compevent") case "unique_switches": - if self.diagnostic_tables.has_key("unique_switches"): - data = self.diagnostic_tables["unique_switches"] - else: - data = None + data = self.diagnostic_tables.get("unique_switches") case "nonunique_switches": - if self.diagnostic_tables.has_key("nonunique_switches"): - data = self.diagnostic_tables["nonunique_switches"] - else: - data = None + data = self.diagnostic_tables.get("nonunique_switches") case _: data = self.km_data if data is None: - raise ValueError("Data {type} was not created in the SEQuential process") + raise ValueError(f"Data {type} was not created in the SEQuential process") return data def to_md(self, filename="SEQuential_results.md") -> None: diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 111e92b..d979e67 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -64,8 +64,11 @@ def __init__( self.eligible_col = eligible_col self.treatment_col = treatment_col self.outcome_col = outcome_col - self.time_varying_cols = time_varying_cols - self.fixed_cols = fixed_cols + # Normalize the documented-Optional covariate lists to [] once, so the + # many downstream `for col in self.fixed_cols` / set() sites need no + # None guards. + self.time_varying_cols = time_varying_cols if time_varying_cols else [] + self.fixed_cols = fixed_cols if fixed_cols else [] self.method = method self._time_initialized = datetime.datetime.now() @@ -111,6 +114,16 @@ def __init__( _param_checker(self) _data_checker(self) + def __getstate__(self): + # The glum design-info cache (_outcome_fit) holds patsy DesignInfo + # objects, which can't be pickled (patsy #26). It is a per-process speed + # cache rebuilt lazily on first fit, so drop it when crossing a process + # boundary (parallel bootstrap / offload); workers repopulate it. Without + # this, parallel=True + glm_package="glum" crashes on pickling. + state = self.__dict__.copy() + state.pop("_patsy_design_cache", None) + return state + def expand(self): """ Creates the sequentially nested, emulated target trial structure. @@ -187,6 +200,17 @@ def expand(self): if self.verbose: n, m = self.DT.shape print(f"Final analysis dataset: {n:,} observations, {m} variables") + # Under censoring the outcome model is fit only on the un-censored + # rows (switch != 1, matching _outcome_fit); the rest are retained in + # the dataset but artificially censored. Report the split so the + # count lines up with implementations that print only the modelled + # rows (e.g. Stata seqtte). + if self.method == "censoring" and "switch" in self.DT.columns: + n_censored = self.DT.filter(pl.col("switch") == 1).height + print( + f" entering outcome model (uncensored): {n - n_censored:,}\n" + f" artificially censored (treatment switch): {n_censored:,}" + ) end = time.perf_counter() self._expansion_time = _format_time(start, end) @@ -250,32 +274,50 @@ def fit(self) -> None: boot_idx = self._current_boot_idx if self.weighted: - WDT_pl = _weight_setup(self) - if not self.weight_preexpansion and not self.excused: - WDT_pl = WDT_pl.filter(pl.col("followup") > 0) - - # The weight-fit helpers (_fit_LTFU etc.) use pandas-style indexing - # and pass pandas frames to glum/statsmodels, so we convert once. - # The fits don't mutate WDT_pd - they store models on `self` - so - # we keep the original polars frame for the downstream steps - # rather than paying a pl.from_pandas() round-trip per replicate. - WDT_pd = WDT_pl.to_pandas() - for col in self.fixed_cols: - if col in WDT_pd.columns: - WDT_pd[col] = WDT_pd[col].astype("category") - - _fit_LTFU(self, WDT_pd) - _fit_visit(self, WDT_pd) - _fit_numerator(self, WDT_pd) - _fit_denominator(self, WDT_pd) - - if self.offload: - _offload_weights(self, boot_idx) - - del WDT_pd - WDT = _weight_predict(self, WDT_pl) - _weight_bind(self, WDT) - self.weight_stats = _weight_stats(self) + # With weight_preexpansion the weight models are fit on the + # un-resampled pre-expansion data, so every bootstrap replicate + # would refit bit-identical models and re-predict identical + # weights. Cache the predicted weight frame from the main fit and + # reuse it on replicates; only the join onto the resampled DT + # (_weight_bind) and the resulting weight stats differ. + cached_WDT = ( + getattr(self, "_main_weight_WDT", None) + if boot_idx is not None and self.weight_preexpansion + else None + ) + if cached_WDT is not None: + _weight_bind(self, cached_WDT) + self.weight_stats = _weight_stats(self) + else: + WDT_pl = _weight_setup(self) + if not self.weight_preexpansion and not self.excused: + WDT_pl = WDT_pl.filter(pl.col("followup") > 0) + + # The weight-fit helpers (_fit_LTFU etc.) use pandas-style + # indexing and pass pandas frames to glum/statsmodels, so we + # convert once. The fits don't mutate WDT_pd - they store + # models on `self` - so we keep the original polars frame for + # the downstream steps rather than paying a pl.from_pandas() + # round-trip per replicate. + WDT_pd = WDT_pl.to_pandas() + for col in self.fixed_cols: + if col in WDT_pd.columns: + WDT_pd[col] = WDT_pd[col].astype("category") + + _fit_LTFU(self, WDT_pd) + _fit_visit(self, WDT_pd) + _fit_numerator(self, WDT_pd) + _fit_denominator(self, WDT_pd) + + if self.offload: + _offload_weights(self, boot_idx) + + del WDT_pd + WDT = _weight_predict(self, WDT_pl) + if self.weight_preexpansion and boot_idx is None: + self._main_weight_WDT = WDT + _weight_bind(self, WDT) + self.weight_stats = _weight_stats(self) is_boot = boot_idx is not None start = getattr(self, "_outcome_start_params", None) if is_boot else None @@ -365,6 +407,14 @@ def hazard(self) -> None: """ start = time.perf_counter() + if self.method == "dose-response": + raise NotImplementedError( + "Hazard ratio estimation is not supported for method='dose-response': " + "the counterfactual simulation only sets the baseline treatment, but " + "the dose-response outcome model depends on the cumulative dose, so " + "both arms would simulate identical outcomes (HR ≈ 1)." + ) + if not hasattr(self, "outcome_model") or not self.outcome_model: raise ValueError( "Outcome model not found. Please run the 'fit' method before calculating hazard ratio." @@ -429,13 +479,17 @@ def collect(self) -> SEQoutput: "collection_time": self._time_collected, } - if self.compevent_colname is not None: - compevent_models = [model["compevent"] for model in self.outcome_model] - else: - compevent_models = None - if self.outcome_model is not None: outcome_models = [model["outcome"] for model in self.outcome_model] + if self.compevent_colname is not None: + compevent_models = [model["compevent"] for model in self.outcome_model] + else: + compevent_models = None + else: + # collect() before fit(): no models to report, but the rest of the + # output (diagnostics, timings) is still valid. + outcome_models = None + compevent_models = None if self.risk_estimates is None: risk_ratio = risk_difference = None diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index a8331bf..eb16eba 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -1,4 +1,6 @@ +import copy import warnings +from concurrent.futures import ProcessPoolExecutor import numpy as np import polars as pl @@ -33,8 +35,6 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(None, None, None, val, self) if self.bootstrap_nboot > 0: - boot_log_hrs = [] - # outcome_model[model_pos + 1] was fit on _boot_samples[sample_idx]; # skipped replicates make this mapping non-identity, so iterate it # explicitly rather than assuming model index == sample index. @@ -42,31 +42,18 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if boot_sample_idx is None: boot_sample_idx = list(range(len(self._boot_samples))) - for model_pos, sample_idx in enumerate(boot_sample_idx): - if self.seed is not None: - self._rng = np.random.RandomState(self.seed + sample_idx + 1) - id_counts = self._boot_samples[sample_idx] - - counts = pl.DataFrame( - { - self.id_col: list(id_counts.keys()), - "_count": list(id_counts.values()), - } - ) - boot_data = ( - data.lazy() - .join(counts.lazy(), on=self.id_col, how="inner") - .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) - .explode("_rep") - .drop("_count", "_rep") - .collect() - ) - - boot_log_hr = _hazard_handler( - self, boot_data, idx, model_pos + 1, self._rng - ) - if boot_log_hr is not None and not np.isnan(boot_log_hr): - boot_log_hrs.append(boot_log_hr) + # The per-replicate hazard simulation is GIL-bound (patsy design build), + # so spread it over a process pool when parallel=True. Needs a concrete + # seed (always set since SEQuential pins a default) so each replicate's + # RNG — and therefore the result — is identical to the serial path. + if getattr(self, "parallel", False) and self.seed is not None: + boot_log_hrs = _parallel_boot_log_hrs(self, data, idx, boot_sample_idx) + else: + boot_log_hrs = [] + for model_pos, sample_idx in enumerate(boot_sample_idx): + boot_log_hr = _one_boot_log_hr(self, data, idx, model_pos, sample_idx) + if boot_log_hr is not None and not np.isnan(boot_log_hr): + boot_log_hrs.append(boot_log_hr) if len(boot_log_hrs) == 0: return _create_hazard_output(np.exp(full_log_hr), None, None, val, self) @@ -87,6 +74,122 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(np.exp(full_log_hr), lci, uci, val, self) +def _truncate_to_first_event(tmp, id_col, event_col): + """Reduce a simulated counterfactual grid to one survival row per (id, trial). + + Keeps the FIRST row in which ``event_col`` fires (status 1 at the first event + time); if the unit never has an event it keeps the final follow-up row + (status 0, censored at max follow-up). + + Outcomes are simulated independently at every follow-up row, so a unit may + have ``event_col == 1`` at several rows. We therefore keep only rows whose + cumulative event count *strictly before* the current row is 0 — i.e. every + row up to and including the first event — and then take the last of those, + which is the first-event row (or the max-follow-up row when there is no + event). + + NOTE: the inclusive form ``cum_sum(event_col) <= 1`` is WRONG here: it + retains post-event rows (the cumulative count stays at 1 until a second + event), so ``.last()`` returns the final follow-up row and a single event is + silently recorded as censored. That dropped ~99% of simulated events and + inflated the marginal-HR variance ~8x relative to SEQTaRget (R). See + tests/test_hazard_truncation.py. + """ + return ( + tmp.with_columns( + ( + pl.col(event_col).cum_sum().over([id_col, "trial"]) - pl.col(event_col) + ).alias("_event_prior") + ) + .filter(pl.col("_event_prior") == 0) + .group_by([id_col, "trial"]) + .last() + .drop("_event_prior") + ) + + +def _one_boot_log_hr(self, data, idx, model_pos, sample_idx): + """Build one bootstrap resample of ``data`` and return its log hazard ratio. + + The RNG is rebuilt from ``seed + sample_idx + 1`` (matching the serial loop + exactly), so this is bit-identical whether called serially or in a worker. + """ + seed = getattr(self, "seed", None) + rng = ( + np.random.RandomState(seed + sample_idx + 1) if seed is not None else self._rng + ) + + id_counts = self._boot_samples[sample_idx] + counts = pl.DataFrame( + { + self.id_col: list(id_counts.keys()), + "_count": list(id_counts.values()), + } + ) + boot_data = ( + data.lazy() + .join(counts.lazy(), on=self.id_col, how="inner") + .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) + .explode("_rep") + .drop("_count", "_rep") + .collect() + ) + return _hazard_handler(self, boot_data, idx, model_pos + 1, rng) + + +# Process-pool worker state. Set once per worker process by the initializer so +# each task ships only small integers, not the (slimmed) SEQuential object or +# the analysis frame. +_HZ_WORKER_OBJ = None +_HZ_WORKER_DATA = None + + +def _hazard_pool_init(obj, data_ref): + global _HZ_WORKER_OBJ, _HZ_WORKER_DATA + _HZ_WORKER_OBJ = obj + _HZ_WORKER_DATA = obj._offloader.load_dataframe(data_ref) + + +def _hazard_pool_task(idx, model_pos, sample_idx): + return _one_boot_log_hr(_HZ_WORKER_OBJ, _HZ_WORKER_DATA, idx, model_pos, sample_idx) + + +def _parallel_boot_log_hrs(self, data, idx, boot_sample_idx): + """Run the bootstrap hazard simulations over a process pool. + + The analysis frame is handed to each worker process once (via the offloader + ref + pool initializer), and a slimmed copy of ``self`` carries the fitted + models. Results are gathered in submission order, matching the serial loop; + NaN/None replicates are dropped the same way. + """ + data_ref = self._offloader.save_dataframe(data, f"_haz_DT_{idx}") + + # Slim copy for the pool: drop the large frames workers reload from data_ref; + # keep the fitted models, bootstrap samples, and config. _GlumFit and + # SEQuential each drop their unpicklable / heavy state on pickle. + slim = copy.copy(self) + slim.DT = None + slim.data = None + slim._rng = None + + boot_log_hrs = [] + with ProcessPoolExecutor( + max_workers=self.ncores, + initializer=_hazard_pool_init, + initargs=(slim, data_ref), + ) as executor: + futures = [ + executor.submit(_hazard_pool_task, idx, model_pos, sample_idx) + for model_pos, sample_idx in enumerate(boot_sample_idx) + ] + for future in futures: + result = future.result() + if result is not None and not np.isnan(result): + boot_log_hrs.append(result) + + return boot_log_hrs + + def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ "followup", @@ -140,36 +243,18 @@ def _hazard_handler(self, data, idx, boot_idx, rng): ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) - tmp = ( - tmp.with_columns( - [ - pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) - .then(1) - .otherwise(0) - .alias("any_event") - ] - ) - .with_columns( - [ - pl.col("any_event") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ] - ) - .filter(pl.col("event_cumsum") <= 1) - ) - else: tmp = tmp.with_columns( [ - pl.col("outcome") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") + pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) + .then(1) + .otherwise(0) + .alias("any_event") ] - ).filter(pl.col("event_cumsum") <= 1) + ) + tmp = _truncate_to_first_event(tmp, self.id_col, "any_event") + else: + tmp = _truncate_to_first_event(tmp, self.id_col, "outcome") - tmp = tmp.group_by([self.id_col, "trial"]).last() all_treatments.append(tmp) sim_data = pl.concat(all_treatments) diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py index 217be1c..ec6bc8a 100644 --- a/pySEQTarget/error/_data_checker.py +++ b/pySEQTarget/error/_data_checker.py @@ -37,6 +37,14 @@ def _data_checker(self): ) for col in self.excused_colnames: + # _param_checker pads the list with None up to len(treatment_level) + # when fewer excused columns are supplied. + if col is None: + continue + if col not in self.data.columns: + raise ValueError( + f"excused_colnames entry '{col}' not found in data columns." + ) violations = ( self.data.sort([self.id_col, self.time_col]) .group_by(self.id_col) diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index 14b0b94..9efa54a 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -41,9 +41,18 @@ def _param_checker(self): if self.km_curves and self.hazard_estimate: raise ValueError("km_curves and hazard cannot both be set to True.") + if self.hazard_estimate and self.method == "dose-response": + raise ValueError( + "Hazard ratio estimation is not supported for method='dose-response': " + "the counterfactual simulation only sets the baseline treatment, but " + "the dose-response outcome model depends on the cumulative dose, so " + "both arms would simulate identical outcomes (HR ≈ 1)." + ) + if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1: raise ValueError( - "Only one of followup_class or followup_include can be set to True." + "Only one of followup_class, followup_include, or followup_spline " + "can be set to True." ) if self.followup_spline_df < 2: diff --git a/pySEQTarget/expansion/_selection.py b/pySEQTarget/expansion/_selection.py index 63b7361..adaedf2 100644 --- a/pySEQTarget/expansion/_selection.py +++ b/pySEQTarget/expansion/_selection.py @@ -38,8 +38,10 @@ def _random_selection(self): ) .filter( pl.col("trialID").is_in(sample) - | pl.col(f"{self.treatment_col}{self.indicator_baseline}") - != self.treatment_level[0] + | ( + pl.col(f"{self.treatment_col}{self.indicator_baseline}") + != self.treatment_level[0] + ) ) .drop("trialID") ) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 733be5e..a80fecb 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -10,6 +10,22 @@ from ._format_time import _format_time +# Side-effect attributes set by the main fit that bootstrap replicates overwrite +# when they run in-process (the serial path: each replicate calls the fit body +# again, and _fit_numerator/_fit_denominator do `self.X_model = fits`). Snapshot +# them after the main fit and restore after the replicate loop so summaries +# reflect the main fit, not the last replicate. The parallel path already keeps +# them (replicates run in worker copies), so restore is a no-op there. +_MAIN_FIT_ATTRS = ( + "numerator_model", + "denominator_model", + "cense_numerator_model", + "cense_denominator_model", + "visit_numerator_model", + "visit_denominator_model", + "weight_stats", +) + def _prepare_boot_data(self, data, boot_id): id_counts = self._boot_samples[boot_id] @@ -80,12 +96,44 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): # Disable bootstrapping to prevent recursion obj.bootstrap_nboot = 0 + # Call the raw, undecorated fit body — not the @bootstrap_loop-wrapped + # method — so it returns this replicate's single model dict. Going through + # the wrapper would re-enter bootstrap_loop and return a list ([model_dict]), + # which the serial path never does, breaking the hazard/survival consumers + # that index outcome_model[i]["outcome"]. method = getattr(obj, method_name) - result = method(*args, **kwargs) + raw = getattr(method, "__wrapped__", None) + if raw is not None: + result = raw(obj, *args, **kwargs) + else: + result = method(*args, **kwargs) obj._rng = None return result +# Process-pool worker state for the parallel bootstrap fit. Set once per +# worker process by the initializer so each task ships only the replicate +# index — not the (slimmed) SEQuential object or the full analysis frame, +# which previously crossed the process boundary once per task. +_FIT_WORKER_OBJ = None +_FIT_WORKER_DATA = None +_FIT_WORKER_CALL = None + + +def _fit_pool_init(obj, data_ref, method_name, seed, args, kwargs): + global _FIT_WORKER_OBJ, _FIT_WORKER_DATA, _FIT_WORKER_CALL + _FIT_WORKER_OBJ = obj + _FIT_WORKER_DATA = obj._offloader.load_dataframe(data_ref) + _FIT_WORKER_CALL = (method_name, seed, args, kwargs) + + +def _fit_pool_task(i): + method_name, seed, args, kwargs = _FIT_WORKER_CALL + return _bootstrap_worker( + _FIT_WORKER_OBJ, method_name, _FIT_WORKER_DATA, i, seed, args, kwargs + ) + + def bootstrap_loop(method): @wraps(method) def wrapper(self, *args, **kwargs): @@ -104,6 +152,12 @@ def wrapper(self, *args, **kwargs): full = method(self, *args, **kwargs) results.append(full) + # Snapshot the main-fit weight models before any in-process replicate + # can overwrite them; restored just before returning. + main_fit_state = { + attr: getattr(self, attr) for attr in _MAIN_FIT_ATTRS if hasattr(self, attr) + } + if getattr(self, "bootstrap_nboot") > 0 and getattr( self, "_boot_samples", None ): @@ -119,19 +173,13 @@ def wrapper(self, *args, **kwargs): self._rng = None self.DT = None - with ProcessPoolExecutor(max_workers=ncores) as executor: + with ProcessPoolExecutor( + max_workers=ncores, + initializer=_fit_pool_init, + initargs=(self, original_DT_ref, method_name, seed, args, kwargs), + ) as executor: futures = { - executor.submit( - _bootstrap_worker, - self, - method_name, - original_DT_ref, - i, - seed, - args, - kwargs, - ): i - for i in range(nboot) + executor.submit(_fit_pool_task, i): i for i in range(nboot) } skipped = 0 boot_sample_idx = [] @@ -154,12 +202,12 @@ def wrapper(self, *args, **kwargs): self._rng = original_rng self.DT = self._offloader.load_dataframe(original_DT_ref) else: - # Keep original data in memory if offloading is disabled to avoid unnecessary I/O + # original_DT_ref already holds the parquet ref (offload on) or + # the frame itself (offload off) from the save above — don't + # write the parquet a second time. With offload on, drop the + # in-memory frame; replicates reload from the ref. if self._offloader.enabled: - original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT") del original_DT - else: - original_DT_ref = original_DT skipped = 0 boot_sample_idx = [] @@ -205,6 +253,11 @@ def wrapper(self, *args, **kwargs): end = time.perf_counter() self._model_time = _format_time(start, end) + # Restore the main-fit weight models so numerator/denominator summaries + # reflect the main fit rather than the last in-process replicate. + for attr, value in main_fit_state.items(): + setattr(self, attr, value) + self.outcome_model = results return results diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index b9f2a73..4263467 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -40,22 +40,82 @@ class _GlumFit: just like statsmodels keeps model.exog, so memory use is comparable. """ - def __init__(self, glum_model, design_info, feature_names, X_design, sample_weight): + def __init__( + self, + glum_model, + design_info, + feature_names, + X_design, + sample_weight, + formula=None, + ref_frame=None, + ): self._glum = glum_model self._design_info = design_info self._X_design = X_design # includes the intercept column + self._nobs = X_design.shape[0] self._sample_weight = sample_weight + # Lazily-filled cache of the (small) coefficient covariance matrix. It + # lets __getstate__ drop the full design matrix (_X_design can be 100s + # of MB) while keeping bse/summary working after unpickle — important + # for the process pool and offload, which ship many fitted models. + self._cov_cached = None + # Inputs to rebuild ``design_info`` on unpickle: the patsy DesignInfo + # itself cannot be pickled (patsy #26), so we keep the formula and a + # tiny reference frame (which preserves each categorical column's full, + # ordered dtype categories) and re-parse on __setstate__. + self._formula = formula + self._ref_frame = ref_frame - self.model = types.SimpleNamespace( - exog_names=feature_names, - data=types.SimpleNamespace(design_info=design_info), - ) + self._build_model_namespace(design_info, feature_names) self.exog_names = feature_names # statsmodels convention: intercept first all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) self.params = pd.Series(all_coefs, index=feature_names) + def _build_model_namespace(self, design_info, feature_names): + self.model = types.SimpleNamespace( + exog_names=feature_names, + data=types.SimpleNamespace(design_info=design_info), + ) + + def __getstate__(self): + # Drop the unpicklable patsy DesignInfo and the SimpleNamespaces that + # reference it; __setstate__ rebuilds them from the formula + ref_frame. + state = self.__dict__.copy() + state.pop("_design_info", None) + state.pop("model", None) + # Replace the full design matrix with the small cached covariance so the + # pickled model stays lightweight (the design matrix can be 100s of MB). + # bse/summary still work via _cov_cached; predict never needs _X_design. + if state.get("_cov_cached") is None: + state["_cov_cached"] = self.cov_params() + state["_X_design"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if self._formula is None or self._ref_frame is None: + raise RuntimeError( + "Cannot unpickle _GlumFit fitted before formula/ref_frame were " + "recorded; refit with the current pySEQTarget version." + ) + _, X_mat = patsy.dmatrices( + self._formula, self._ref_frame, return_type="dataframe" + ) + if list(X_mat.columns) != list(self.exog_names): + # The reference frame's categorical ordering must reproduce the + # frozen column structure exactly, or glum's coefficients would be + # paired with the wrong design columns on predict. Fail loudly + # rather than return silently wrong predictions. + raise RuntimeError( + "_GlumFit design columns changed on unpickle: " + f"{list(X_mat.columns)} != {list(self.exog_names)}" + ) + self._design_info = X_mat.design_info + self._build_model_namespace(self._design_info, self.exog_names) + def predict(self, data, transform=True): if transform: # data is a pandas DataFrame — build design matrix via stored patsy info @@ -69,12 +129,20 @@ def predict(self, data, transform=True): return self._glum.predict(X_arr) def cov_params(self): + if self._cov_cached is not None: + return self._cov_cached X = self._X_design + if X is None: + raise RuntimeError( + "cov_params unavailable: design matrix was dropped on pickle and " + "no covariance was cached." + ) mu = self._glum.predict(X[:, 1:]) w = mu * (1.0 - mu) if self._sample_weight is not None: w = w * np.asarray(self._sample_weight) - return np.linalg.pinv(X.T @ (w[:, None] * X)) + self._cov_cached = np.linalg.pinv(X.T @ (w[:, None] * X)) + return self._cov_cached @property def bse(self): @@ -110,7 +178,7 @@ def summary(self): "GLM (glum backend)", "Binomial", "logit", - str(self._X_design.shape[0]), + str(self._nobs), ] }, index=["Model:", "Family:", "Link:", "No. Observations:"], @@ -185,4 +253,24 @@ def _fit_glum(formula, data, var_weights=None, start_params=None, design_cache=N fit_kwargs["sample_weight"] = sample_weight glm.fit(X_arr, y_arr, **fit_kwargs) - return _GlumFit(glm, design_info, feature_names, X_design, sample_weight) + + # Keep a minimal reference frame so the (unpicklable) design_info can be + # rebuilt on unpickle. Two rows suffice ONLY when each categorical factor's + # full, ordered level set lives in the column dtype — patsy derives the + # contrasts from pd.Categorical dtype categories, but for plain string + # columns it falls back to the observed values, and two rows rarely cover + # every level. Freeze the design's levels into the frame's dtypes so the + # re-parse reproduces the frozen column structure regardless of source + # dtype. (The codebase uses only stateless transforms — precomputed + # squares, explicit-knot splines — so no other fit-time state needs + # preserving.) + ref_frame = _align_categories(design_info, data.head(2).copy()) + return _GlumFit( + glm, + design_info, + feature_names, + X_design, + sample_weight, + formula=formula, + ref_frame=ref_frame, + ) diff --git a/pySEQTarget/helpers/_jax_fit.py b/pySEQTarget/helpers/_jax_fit.py index 201c503..e7bbf63 100644 --- a/pySEQTarget/helpers/_jax_fit.py +++ b/pySEQTarget/helpers/_jax_fit.py @@ -25,6 +25,16 @@ def __init__( self._design_info = design_info self._feature_names = feature_names self._X_design = X_mat.values + self._nobs = X_mat.shape[0] + # Lazily-filled coefficient covariance cache: lets __getstate__ drop the + # full design matrix while keeping bse/summary working after unpickle. + self._cov_cached = None + # Inputs to rebuild ``design_info`` on unpickle: patsy DesignInfo cannot + # be pickled (patsy #26), so keep the formula plus a tiny reference + # frame (which preserves each categorical column's full, ordered dtype + # categories) and re-parse on __setstate__. Mirrors _GlumFit. + self._formula = formula + self._ref_frame = df_pd.head(2).copy() X_arr = X_mat.drop(columns=["Intercept"], errors="ignore").values y_raw = y_mat.values.ravel() @@ -51,12 +61,46 @@ def __init__( ) # statsmodels 'like' exposure + self._build_model_namespace(design_info, feature_names) + self.exog_names = feature_names + self.params = self._build_params() + + def _build_model_namespace(self, design_info, feature_names): self.model = types.SimpleNamespace( exog_names=feature_names, data=types.SimpleNamespace(design_info=design_info), ) - self.exog_names = feature_names - self.params = self._build_params() + + def __getstate__(self): + # Drop the unpicklable patsy DesignInfo and the SimpleNamespaces that + # reference it; __setstate__ rebuilds them from the formula + ref_frame. + state = self.__dict__.copy() + state.pop("_design_info", None) + state.pop("model", None) + # Replace the full design matrix with the small cached covariance so + # the pickled model stays lightweight. Covariance is only implemented + # for binary fits; multiclass keeps None (bse raises either way). + if state.get("_cov_cached") is None and self._n_classes == 2: + state["_cov_cached"] = self.cov_params() + state["_X_design"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + _, X_mat = patsy.dmatrices( + self._formula, self._ref_frame, return_type="dataframe" + ) + if list(X_mat.columns) != list(self._feature_names): + # The reference frame's categorical ordering must reproduce the + # frozen column structure exactly, or the coefficients would pair + # with the wrong design columns on predict. Fail loudly rather + # than return silently wrong predictions. + raise RuntimeError( + "_JaxFit design columns changed on unpickle: " + f"{list(X_mat.columns)} != {list(self._feature_names)}" + ) + self._design_info = X_mat.design_info + self._build_model_namespace(self._design_info, self._feature_names) def _coef_components(self): W, b = self._jax.params @@ -120,12 +164,20 @@ def cov_params(self): raise NotImplementedError( "Standard errors are only implemented for binary jax fits." ) + if self._cov_cached is not None: + return self._cov_cached X = self._X_design + if X is None: + raise RuntimeError( + "cov_params unavailable: design matrix was dropped on pickle " + "and no covariance was cached." + ) mu = np.asarray(self._jax.predict(X[:, 1:]))[:, 1] w = mu * (1.0 - mu) if self._sample_weight is not None: w = w * self._sample_weight - return np.linalg.pinv(X.T @ (w[:, None] * X)) + self._cov_cached = np.linalg.pinv(X.T @ (w[:, None] * X)) + return self._cov_cached @property def bse(self): @@ -161,7 +213,7 @@ def summary(self): "GLM (jax backend)", "Binomial", "logit", - str(self._X_design.shape[0]), + str(self._nobs), ] }, index=["Model:", "Family:", "Link:", "No. Observations:"], diff --git a/pySEQTarget/helpers/_output_files.py b/pySEQTarget/helpers/_output_files.py index aab4177..dd4cbb7 100644 --- a/pySEQTarget/helpers/_output_files.py +++ b/pySEQTarget/helpers/_output_files.py @@ -12,6 +12,19 @@ def _build_md(self, img_path: str = None) -> str: lines = [] + def _model_section(title, kind): + # SEQoutput.summary handles absent model lists (e.g. no numerator + # models under weighted ITT) and loads offloaded path refs. + summaries = self.summary(kind) + if not summaries: + return + lines.append(f"### {title}") + lines.append("") + lines.append("```") + lines.append(str(summaries[0])) + lines.append("```") + lines.append("") + lines.append(f"# SEQuential Analysis: {datetime.date.today()}: {self.method}") lines.append("") @@ -19,42 +32,22 @@ def _build_md(self, img_path: str = None) -> str: lines.append("## Weighting") lines.append("") - lines.append("### Numerator Model") - lines.append("") - lines.append("```") - lines.append(str(self.numerator_models[0].summary())) - lines.append("```") - lines.append("") + _model_section("Numerator Model", "numerator") + _model_section("Denominator Model", "denominator") - lines.append("### Denominator Model") - lines.append("") - lines.append("```") - lines.append(str(self.denominator_models[0].summary())) - lines.append("```") - lines.append("") + if self.options.compevent_colname is not None: + _model_section("Competing Event Model", "compevent") - if self.options.compevent_colname is not None and self.compevent_models: - lines.append("### Competing Event Model") + if self.weight_statistics is not None: + lines.append("### Weighting Statistics") lines.append("") - lines.append("```") - lines.append(str(self.compevent_models[0].summary())) - lines.append("```") + lines.append(self.weight_statistics.to_pandas().to_markdown(index=False)) lines.append("") - lines.append("### Weighting Statistics") - lines.append("") - lines.append(self.weight_statistics.to_pandas().to_markdown(index=False)) - lines.append("") - lines.append("## Outcome") lines.append("") - lines.append("### Outcome Model") - lines.append("") - lines.append("```") - lines.append(str(self.outcome_models[0].summary())) - lines.append("```") - lines.append("") + _model_section("Outcome Model", "outcome") if self.options.hazard_estimate and self.hazard is not None: lines.append("### Hazard") @@ -85,10 +78,26 @@ def _build_md(self, img_path: str = None) -> str: lines.append("") if self.diagnostic_tables: + # Clarify what each unique/nonunique table actually counts, so the + # rendered headings are not ambiguous (see SEQoutput.retrieve_data). + diag_descriptions = { + "unique_outcomes": "distinct subjects who had the outcome", + "nonunique_outcomes": "total outcome events", + "unique_followup": "distinct subjects contributing follow-up", + "nonunique_followup": "person-time intervals", + "unique_compevent": "distinct subjects with a competing event", + "nonunique_compevent": "total competing-event intervals", + "unique_switches": "distinct subjects who switched", + "nonunique_switches": "total switch intervals", + } lines.append("## Diagnostic Tables") lines.append("") for name, table in self.diagnostic_tables.items(): - lines.append(f"### {name.replace('_', ' ').title()}") + heading = name.replace("_", " ").title() + description = diag_descriptions.get(name) + if description: + heading = f"{heading} ({description})" + lines.append(f"### {heading}") lines.append("") lines.append(table.to_pandas().to_markdown(index=False)) lines.append("") diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 7ccb3b1..13c2655 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -14,7 +14,9 @@ def _safe_predict(model, data, clip_probs=True): data : pandas DataFrame Data to predict on clip_probs : bool - If True, clip probabilities to [0, 1] and replace NaN with 0.5 + If True, clip probabilities to [0, 1]. Raises ValueError if any + predicted probability is NaN (this signals a train/predict dtype + mismatch or coefficient overflow, not a value to silently impute). """ try: probs = model.predict(data) @@ -39,14 +41,22 @@ def _safe_predict(model, data, clip_probs=True): return probs -def _predict_model(self, model, newdata): - newdata = newdata.to_pandas() +def _prep_predict_frame(self, newdata): + """Convert a polars frame to pandas with fixed_cols cast to category. - # Original behavior - convert fixed_cols to category + Split out from _predict_model so callers predicting with several models on + the same rows (e.g. numerator + denominator in _weight_predict) can pay + the conversion once and share the frame. + """ + newdata = newdata.to_pandas() for col in self.fixed_cols: if col in newdata.columns: newdata[col] = newdata[col].astype("category") + return newdata + +def _predict_model_pd(model, newdata): + """Predict on an already-prepared pandas frame, with category fix retry.""" try: return np.array(model.predict(newdata)) except Exception as e: @@ -55,3 +65,7 @@ def _predict_model(self, model, newdata): return np.array(model.predict(newdata)) else: raise + + +def _predict_model(self, model, newdata): + return _predict_model_pd(model, _prep_predict_frame(self, newdata)) diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index d159af5..0269a1c 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -2,6 +2,7 @@ def _weight_bind(self, WDT): + drop_after_join = [] if self.weight_preexpansion: join = "inner" on = [self.id_col, "period"] @@ -9,23 +10,29 @@ def _weight_bind(self, WDT): # On a bootstrap pass _prepare_boot_data transformed id_col so that # each replicate has a unique value -- integer math (orig_id * id_mult # + replicate) for int IDs, "{orig_id}_{replicate}" for string IDs. - # Recover the original ID here so the join to WDT (which still carries - # un-resampled originals) lines up. No-op on the main fit pass. + # WDT still carries the un-resampled originals, so join on a recovered + # original-ID key. Do NOT overwrite id_col itself: the weight cum_prod + # below groups on (id_col, trial), and collapsing replicate copies of a + # multiply-sampled subject into one group would interleave their rows + # and corrupt the cumulative weights (each copy must accumulate its own + # product independently). No-op on the main fit pass. is_boot = getattr(self, "_current_boot_idx", None) is not None if is_boot: if self.DT.schema[self.id_col].is_integer(): - self.DT = self.DT.with_columns( - (pl.col(self.id_col) // self._boot_id_mult).alias(self.id_col) - ) + orig_id = pl.col(self.id_col) // self._boot_id_mult else: - self.DT = self.DT.with_columns( - pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col) - ) + orig_id = pl.col(self.id_col).str.replace(r"_\d+$", "") + self.DT = self.DT.with_columns(orig_id.alias("_orig_id")) + WDT = WDT.rename({self.id_col: "_orig_id"}) + on = ["_orig_id", "period"] + drop_after_join = ["_orig_id"] else: join = "left" on = [self.id_col, "trial", "followup"] WDT = self.DT.join(WDT, on=on, how=join) + if drop_after_join: + WDT = WDT.drop(drop_after_join) if self.visit_colname is not None: visit = pl.col(self.visit_colname) == 0 diff --git a/pySEQTarget/weighting/_weight_offload.py b/pySEQTarget/weighting/_weight_offload.py index 04603e8..2879cc3 100644 --- a/pySEQTarget/weighting/_weight_offload.py +++ b/pySEQTarget/weighting/_weight_offload.py @@ -1,19 +1,29 @@ def _offload_weights(self, boot_idx): - """Helper to offload weight models to disk""" - weight_models = [ + """Offload fitted weight models to disk, replacing them with path refs. + + numerator_model/denominator_model are lists with one fit per treatment + level; the cense/visit models are single fits. Entries already offloaded + (str refs) or never fit (None) are left as-is. Consumers go through + Offloader.load_model, which passes non-str values through. + """ + for attr, name in ( ("numerator_model", "numerator"), ("denominator_model", "denominator"), - ("LTFU_model", "LTFU"), - ("visit_model", "visit"), - ] - - for model_attr, model_name in weight_models: - if hasattr(self, model_attr): - model_list = getattr(self, model_attr) - if model_list and isinstance(model_list, list) and len(model_list) > 0: - latest_model = model_list[-1] - if latest_model is not None: - offloaded = self._offloader.save_model( - latest_model, model_name, boot_idx + ): + model_list = getattr(self, attr, None) + if isinstance(model_list, list): + for i, model in enumerate(model_list): + if model is not None and not isinstance(model, str): + model_list[i] = self._offloader.save_model( + model, f"{name}{i}", boot_idx ) - model_list[-1] = offloaded + + for attr, name in ( + ("cense_numerator_model", "cense_numerator"), + ("cense_denominator_model", "cense_denominator"), + ("visit_numerator_model", "visit_numerator"), + ("visit_denominator_model", "visit_denominator"), + ): + model = getattr(self, attr, None) + if model is not None and not isinstance(model, str): + setattr(self, attr, self._offloader.save_model(model, name, boot_idx)) diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 6a2bca9..56341b5 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -2,6 +2,7 @@ import polars as pl from ..helpers import _predict_model +from ..helpers._predict_model import _predict_model_pd, _prep_predict_frame def _extract_class_probability(p, level_idx, is_binary): @@ -59,17 +60,22 @@ def _weight_predict(self, WDT): denom_model = self._offloader.load_model(self.denominator_model[i]) num_model = self._offloader.load_model(self.numerator_model[i]) - if denom_model is not None and lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, denom_model, subset) + if (denom_model is None and num_model is None) or lag_mask.sum() == 0: + continue + + # Numerator and denominator predict on the same rows — pay the + # filter + pandas conversion once and share the frame. + subset_pd = _prep_predict_frame(self, WDT.filter(pl.Series(lag_mask))) + + if denom_model is not None: + p = _predict_model_pd(denom_model, subset_pd) p_class = _extract_class_probability(p, i, is_binary) pred_denom[lag_mask] = np.where( switched_treatment[lag_mask], 1.0 - p_class, p_class ) - if num_model is not None and lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, num_model, subset) + if num_model is not None: + p = _predict_model_pd(num_model, subset_pd) p_class = _extract_class_probability(p, i, is_binary) pred_num[lag_mask] = np.where( switched_treatment[lag_mask], 1.0 - p_class, p_class @@ -167,12 +173,17 @@ def _weight_predict(self, WDT): .otherwise(pl.col("numerator")) .alias("numerator") ) + # Full-frame pandas conversion shared by the cense and visit predictions + # (built lazily — only when at least one of those model pairs exists). + WDT_pd = None + if self.cense_colname is not None: cense_num_model = self._offloader.load_model(self.cense_numerator_model) cense_denom_model = self._offloader.load_model(self.cense_denominator_model) if cense_num_model is not None and cense_denom_model is not None: - p_num = _predict_model(self, cense_num_model, WDT).flatten() - p_denom = _predict_model(self, cense_denom_model, WDT).flatten() + WDT_pd = _prep_predict_frame(self, WDT) + p_num = _predict_model_pd(cense_num_model, WDT_pd).flatten() + p_denom = _predict_model_pd(cense_denom_model, WDT_pd).flatten() WDT = ( WDT.with_columns( [ @@ -196,8 +207,12 @@ def _weight_predict(self, WDT): visit_num_model = self._offloader.load_model(self.visit_numerator_model) visit_denom_model = self._offloader.load_model(self.visit_denominator_model) if visit_num_model is not None and visit_denom_model is not None: - p_num = _predict_model(self, visit_num_model, WDT).flatten() - p_denom = _predict_model(self, visit_denom_model, WDT).flatten() + # The visit formulas don't reference the _cense column added above, + # so the frame converted before the cense block is still valid. + if WDT_pd is None: + WDT_pd = _prep_predict_frame(self, WDT) + p_num = _predict_model_pd(visit_num_model, WDT_pd).flatten() + p_denom = _predict_model_pd(visit_denom_model, WDT_pd).flatten() WDT = ( WDT.with_columns( [ diff --git a/pyproject.toml b/pyproject.toml index 046f356..08e7c2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.7" +version = "0.13.8" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} @@ -24,7 +24,7 @@ classifiers = [ authors = [ {name = "Ryan O'Dea", email = "ryan.odea@psi.ch"}, {name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"}, - {name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"}, + {name = "Tom Palmer", email = "remlapmot@hotmail.com"}, {name = "Miguel Hernán", email = "mhernan@hsph.harvard.edu"}, ] @@ -76,12 +76,14 @@ Repository = "https://github.com/CausalInference/pySEQTarget" dev = [ "black", "isort", + "jax", "pytest", "myst-parser", "piccolo_theme", "sphinx", "sphinx-copybutton", "sphinx-autodoc-typehints", + "tabulate", ] [tool.setuptools.packages.find] diff --git a/tests/test_accessor.py b/tests/test_accessor.py index ab9b796..15c6767 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -25,3 +25,52 @@ def test_ITT_collector(): collector.retrieve_data("unique_outcomes") with pytest.raises(ValueError): collector.retrieve_data("km_data") + # ITT produces no switch diagnostics: a clean ValueError, not the + # Python-2 dict.has_key AttributeError this used to raise. + with pytest.raises(ValueError, match="not created"): + collector.retrieve_data("unique_switches") + + +def test_collect_before_fit(): + # collect() without fit() must return an SEQoutput with None models rather + # than raising UnboundLocalError. Diagnostics from expand() still come + # through. + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(), + ) + s.expand() + collector = s.collect() + assert collector.outcome_models is None + assert collector.compevent_models is None + assert collector.retrieve_data("unique_outcomes").height > 0 + + +def test_censoring_collector_switch_diagnostics(): + # Under method="censoring" the switch diagnostics exist and must be + # retrievable (regression for dict.has_key). + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(), + ) + s.expand() + s.fit() + collector = s.collect() + assert collector.retrieve_data("unique_switches").height > 0 + assert collector.retrieve_data("nonunique_switches").height > 0 diff --git a/tests/test_bootstrap_weights.py b/tests/test_bootstrap_weights.py new file mode 100644 index 0000000..4ebfb33 --- /dev/null +++ b/tests/test_bootstrap_weights.py @@ -0,0 +1,76 @@ +"""Regression test: bootstrap weights with weight_preexpansion=True. + +_weight_bind joins the pre-expansion weight frame (un-resampled original IDs) +onto the bootstrap-resampled DT. It must do so WITHOUT collapsing the resampled +IDs back to originals: the weight cum_prod groups on (id, trial), and merging +the replicate copies of a multiply-sampled subject into one group interleaves +their rows — turning weights a, ab into a, a², a²b, a²b²… (each copy compounds +the other's). Every replicate copy duplicates the same source rows, so the +correct cumulative weights are identical across copies. +""" + +import sys + +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_boot_weights_identical_across_replicate_copies(monkeypatch): + # The package __init__ re-exports shadow the submodule names, so patch the + # name inside the SEQuential module via sys.modules. + seq_mod = sys.modules["pySEQTarget.SEQuential"] + wb_mod = sys.modules["pySEQTarget.weighting._weight_bind"] + + captured = [] + orig = wb_mod._weight_bind + + def spy(self, WDT): + result = orig(self, WDT) + if getattr(self, "_current_boot_idx", None) is not None: + captured.append(self.DT) + return result + + monkeypatch.setattr(seq_mod, "_weight_bind", spy) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + bootstrap_nboot=1, + bootstrap_sample=1.0, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + + assert len(captured) == 1 + DT = captured[0] + + # The resampled encoded IDs must survive the bind (replicate copies stay + # distinct groups for the cum_prod) ... + id_mult = s._boot_id_mult + orig_ids = set(s.data["ID"].unique().to_list()) + assert not set(DT["ID"].unique().to_list()) <= orig_ids + + # ... and with replicate sampling (sample=1.0 guarantees duplicated + # subjects), every copy of the same original (id, trial, followup) row must + # carry the SAME cumulative weight. + decoded = DT.with_columns((pl.col("ID") // id_mult).alias("_orig")) + dup = decoded.group_by(["_orig", "trial", "followup"]).agg( + [pl.len().alias("n"), pl.col("weight").n_unique().alias("n_weights")] + ) + assert dup.filter(pl.col("n") > 1).height > 0 # duplicates actually present + assert dup.filter(pl.col("n_weights") > 1).height == 0 diff --git a/tests/test_excused_colnames.py b/tests/test_excused_colnames.py new file mode 100644 index 0000000..47793e5 --- /dev/null +++ b/tests/test_excused_colnames.py @@ -0,0 +1,38 @@ +"""Validation of excused_colnames in _data_checker. + +_param_checker pads excused_colnames with None up to len(treatment_level); +the data checker used to feed that None into pl.col() and crash with a +confusing TypeError before the analysis even started. +""" + +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(**opts): + return SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(**opts), + ) + + +def test_excused_colnames_shorter_than_treatment_level(): + # One excused column for two treatment levels: the padded None entry must + # be skipped, not validated. + s = _build(excused=True, excused_colnames=["excusedZero"]) + assert s.excused_colnames == ["excusedZero", None] + + +def test_excused_colnames_missing_column_raises_clearly(): + with pytest.raises(ValueError, match="not found in data columns"): + _build(excused=True, excused_colnames=["nonexistent_col"]) diff --git a/tests/test_glum.py b/tests/test_glum.py index cb3b3c5..c42a078 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -358,6 +358,57 @@ def test_glum_design_cache_handles_categorical_level_reordering(): assert list(mb.params.index) == list(m.params.index) +def test_glum_model_pickle_roundtrip_preserves_predictions(): + # _GlumFit holds a patsy DesignInfo, which cannot be pickled (patsy #26). + # It must rebuild the DesignInfo on unpickle from the stored formula + + # reference frame so fitted models can cross a process boundary (parallel + # bootstrap, offload). The roundtrip must preserve params and predictions + # exactly and yield a usable design_info. + import pickle + + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 500 + df = pd.DataFrame( + { + "y": (rng.random(n) < 0.4).astype(int), + "x": rng.standard_normal(n), + "g": pd.Categorical( + rng.choice(["a", "b", "c"], n), categories=["a", "b", "c"] + ), + } + ) + + m = _fit_glum("y ~ x + g", df) + pred = m.predict(df) + + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params.index) == list(m.params.index) + assert list(m2.params.values) == approx(list(m.params.values), rel=1e-12, abs=1e-12) + assert list(m2.exog_names) == list(m.exog_names) + # design_info is rebuilt and reproduces the frozen column structure + assert list(m2.model.data.design_info.column_names) == list( + m.model.data.design_info.column_names + ) + # predictions are bit-identical through both predict paths. The design + # matrix is an external input (the unpickled model drops its own _X_design + # to stay lightweight), so feed the original to both for the transform=False + # path. + np.testing.assert_array_equal(m2.predict(df), pred) + np.testing.assert_array_equal( + m2.predict(m._X_design, transform=False), + m.predict(m._X_design, transform=False), + ) + # bse still works after unpickle even though _X_design was dropped (cached cov) + assert m2._X_design is None + np.testing.assert_allclose(m2.bse.values, m.bse.values, rtol=0, atol=0) + + def test_glum_warm_start_dropped_when_design_columns_mismatch(): # The defensive guard in _fit_glum: a (values, names) tuple whose names # don't line up with the patsy design matrix must be ignored, falling back @@ -384,3 +435,73 @@ def test_glum_warm_start_dropped_when_design_columns_mismatch(): assert list(bogus_fit.params.values) == approx( list(ref.params.values), rel=1e-8, abs=1e-12 ) + + +def test_glum_pickle_with_plain_string_covariate(): + # ref_frame is data.head(2): for a plain object/string column patsy derives + # the categorical levels from the OBSERVED values, so two rows cannot cover + # a 4-level factor and the unpickle column check used to fail with + # RuntimeError. The design's levels must be frozen into the reference + # frame's dtypes instead. + import pickle + + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 2000 + levels = ["a", "b", "c", "d"] + # First two rows share one level so head(2) observes a strict subset + grp = ["a", "a"] + list(rng.choice(levels, n - 2)) + df = pd.DataFrame( + { + "grp": grp, # plain object dtype, NOT pd.Categorical + "x": rng.standard_normal(n), + "y": (rng.random(n) < 0.4).astype(int), + } + ) + + m = _fit_glum("y ~ grp + x", df) + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params) == approx(list(m.params), rel=1e-12, abs=1e-12) + assert list(m2.predict(df)) == approx(list(m.predict(df)), rel=1e-10, abs=1e-12) + + +def test_glum_offload_with_string_time_varying_covariate(): + # End-to-end: offload=True round-trips the weight models through joblib. + # A plain string time-varying covariate in the denominator formula must + # survive the pickle/unpickle cycle. + import polars as pl + + data = load_data("SEQdata").with_columns( + pl.when(pl.col("P") < 9) + .then(pl.lit("low")) + .when(pl.col("P") < 10) + .then(pl.lit("mid")) + .otherwise(pl.lit("high")) + .alias("P_grp") + ) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P_grp"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + glm_package="glum", + weighted=True, + weight_preexpansion=True, + offload=True, + seed=42, + ), + ) + s.expand() + s.fit() + assert s.DT["weight"].is_finite().all() diff --git a/tests/test_hazard.py b/tests/test_hazard.py index f3f1854..746f119 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -119,3 +119,36 @@ def flaky_outcome_fit(seq_self, *args, **kwargs): assert hr["Hazard ratio"][0] is not None and np.isfinite(hr["Hazard ratio"][0]) assert hr["LCI"][0] is not None assert hr["UCI"][0] is not None + + +def _dose_response_model(**opts): + return SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True, **opts), + ) + + +def test_dose_response_hazard_estimate_rejected_at_construction(): + # The counterfactual hazard simulation only sets the baseline treatment, + # but the dose-response outcome model depends on cumulative dose — both + # arms would simulate identical outcomes and the HR would silently be ~1. + with pytest.raises(ValueError, match="dose-response"): + _dose_response_model(hazard_estimate=True) + + +def test_dose_response_hazard_call_rejected(): + # hazard() can be called regardless of the hazard_estimate flag, so the + # method itself must refuse too. + s = _dose_response_model() + s.expand() + s.fit() + with pytest.raises(NotImplementedError, match="dose-response"): + s.hazard() diff --git a/tests/test_hazard_truncation.py b/tests/test_hazard_truncation.py new file mode 100644 index 0000000..a106d3e --- /dev/null +++ b/tests/test_hazard_truncation.py @@ -0,0 +1,111 @@ +"""Regression tests for the survival-time reduction in the hazard g-formula. + +`_truncate_to_first_event` collapses the simulated counterfactual grid (outcomes +drawn independently at every follow-up row) to one survival row per (id, trial): +the first-event row, or the max-follow-up row when there is no event. + +The earlier implementation used the inclusive `cum_sum(outcome) <= 1` then +`.last()`, which kept post-event rows and returned the final follow-up row, +silently recording single events as censored. That dropped ~99% of simulated +events and inflated the marginal-HR variance ~8x relative to SEQTaRget (R). +""" + +import polars as pl + +from pySEQTarget.analysis._hazard import _truncate_to_first_event + + +def _grid(rows): + # rows: list of (id, trial, [outcome per follow-up 0..T]) + recs = [] + for uid, trial, outs in rows: + for f, o in enumerate(outs): + recs.append((uid, trial, f, o)) + return pl.DataFrame( + recs, schema=["id", "trial", "followup", "outcome"], orient="row" + ) + + +def test_first_event_row_is_kept_for_each_pattern(): + grid = _grid( + [ + (1, 0, [0, 0, 1, 0, 0]), # single interior event -> (followup=2, event=1) + (2, 0, [0, 0, 0, 0, 0]), # no event -> (followup=4, event=0) + (3, 0, [0, 1, 0, 1, 0]), # two events; first -> (followup=1, event=1) + (4, 0, [1, 0, 0, 0, 0]), # event at time 0 -> (followup=0, event=1) + (5, 0, [0, 0, 0, 0, 1]), # event at last row -> (followup=4, event=1) + ] + ) + + out = ( + _truncate_to_first_event(grid, "id", "outcome") + .sort("id") + .select(["id", "followup", "outcome"]) + ) + + assert out.to_dicts() == [ + {"id": 1, "followup": 2, "outcome": 1}, + {"id": 2, "followup": 4, "outcome": 0}, + {"id": 3, "followup": 1, "outcome": 1}, + {"id": 4, "followup": 0, "outcome": 1}, + {"id": 5, "followup": 4, "outcome": 1}, + ] + + +def test_no_events_are_dropped(): + # Every unit that has >=1 simulated outcome must end up with event=1; only the + # all-zero unit (id=2) is censored. This is the property the old idiom broke. + grid = _grid( + [ + (1, 0, [0, 0, 1, 0, 0]), + (2, 0, [0, 0, 0, 0, 0]), + (3, 0, [0, 1, 0, 1, 0]), + (4, 0, [1, 0, 0, 0, 0]), + (5, 0, [0, 0, 0, 0, 1]), + ] + ) + out = _truncate_to_first_event(grid, "id", "outcome") + true_units_with_event = ( + grid.group_by("id").agg(pl.col("outcome").max().alias("ever"))["ever"].sum() + ) + assert out["outcome"].sum() == true_units_with_event == 4 + + +def test_grouping_is_per_id_and_trial(): + # Same id, two trials with different first-event times must be reduced + # independently. + grid = _grid( + [ + (1, 0, [0, 0, 1, 0]), # trial 0: event at 2 + (1, 1, [1, 0, 0, 0]), # trial 1: event at 0 + ] + ) + out = ( + _truncate_to_first_event(grid, "id", "outcome") + .sort(["id", "trial"]) + .select(["id", "trial", "followup", "outcome"]) + ) + assert out.to_dicts() == [ + {"id": 1, "trial": 0, "followup": 2, "outcome": 1}, + {"id": 1, "trial": 1, "followup": 0, "outcome": 1}, + ] + + +def test_beats_the_old_buggy_idiom(): + # Lock the regression: the previous `cum_sum <= 1` then `.last()` loses the + # single interior events that the fixed helper retains. + grid = _grid([(uid, 0, [0, 0, 1, 0, 0]) for uid in range(1, 11)]) + + fixed = _truncate_to_first_event(grid, "id", "outcome")["outcome"].sum() + + old = ( + grid.with_columns(pl.col("outcome").cum_sum().over(["id", "trial"]).alias("cs")) + .filter(pl.col("cs") <= 1) + .group_by(["id", "trial"]) + .last()["outcome"] + .sum() + ) + + assert fixed == 10 # every unit's single event retained + assert old == 0 # old idiom dropped all of them + assert fixed > old diff --git a/tests/test_jax.py b/tests/test_jax.py index 7461437..ae6bfa2 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -7,9 +7,9 @@ # every platform — skip the whole module rather than erroring at collection. pytest.importorskip("jax") -from pySEQTarget import SEQopts, SEQuential -from pySEQTarget.data import load_data -from pySEQTarget.helpers._jax_fit import _JaxFit +from pySEQTarget import SEQopts, SEQuential # noqa: E402 +from pySEQTarget.data import load_data # noqa: E402 +from pySEQTarget.helpers._jax_fit import _JaxFit # noqa: E402 def _fit(method, glm_package, dataset="SEQdata", **opts): @@ -120,3 +120,73 @@ def test_jax_warm_start_reaches_same_optimum(): start_params=(cold.params.values, list(cold.model.exog_names)), ) assert list(warm.params) == approx(list(cold.params), rel=1e-3, abs=1e-3) + + +def test_jax_fit_pickle_roundtrip(): + # _JaxFit holds a patsy DesignInfo, which cannot be pickled; offload and + # the parallel bootstrap both pickle fitted models. The wrapper must + # rebuild the design info on unpickle (same strategy as _GlumFit) and keep + # predict/bse/summary working. + import pickle + + df = _binary_frame() + m = _JaxFit("y ~ x1 + x2", df) + + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params) == approx(list(m.params), rel=1e-12, abs=1e-12) + assert m2.predict(df) == approx(m.predict(df), rel=1e-10, abs=1e-12) + assert list(m2.bse) == approx(list(m.bse), rel=1e-10, abs=1e-12) + assert str(m2.summary()) + + +def test_jax_offload_bootstrap_survival(): + # End-to-end: offload=True pickles every fitted model to disk via joblib. + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package="jax", + bootstrap_nboot=2, + seed=7, + km_curves=True, + offload=True, + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + assert s.km_data.height > 0 + + +def test_jax_parallel_bootstrap(): + # End-to-end: parallel=True pickles the SEQuential object into worker + # processes and the fitted models back. + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package="jax", bootstrap_nboot=2, seed=7, parallel=True, ncores=2 + ), + ) + s.expand() + s.bootstrap() + s.fit() + assert len(s.outcome_model) == 3 # main + 2 replicates diff --git a/tests/test_offload.py b/tests/test_offload.py index 1e82858..ddffe2f 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -40,3 +40,80 @@ def test_compevent_offload(): warnings.filterwarnings("ignore") model.fit() model.survival() + + +def test_weight_models_fully_offloaded(tmp_path): + # _offload_weights used to check nonexistent attributes (LTFU_model, + # visit_model) and only offload the LAST treatment level's model. All + # fitted weight models must end up as path refs, and summaries must load + # them back transparently. + data = load_data("SEQdata_LTFU") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=False, + cense_colname="LTFU", + offload=True, + offload_dir=str(tmp_path), + seed=42, + ), + ) + s.expand() + s.fit() + + for m in s.numerator_model + s.denominator_model: + assert m is None or isinstance(m, str) + assert isinstance(s.cense_numerator_model, str) + assert isinstance(s.cense_denominator_model, str) + + out = s.collect() + for kind in ("numerator", "denominator", "outcome"): + summaries = out.summary(kind) + assert len(summaries) >= 1 + assert all(str(smry) for smry in summaries) + + +def test_serial_bootstrap_offload_writes_DT_once(monkeypatch, tmp_path): + # The serial bootstrap path used to save the _DT parquet twice per fit. + from pySEQTarget.helpers._offloader import Offloader + + writes = [] + real_save = Offloader.save_dataframe + + def spy(self, df, name): + writes.append(name) + return real_save(self, df, name) + + monkeypatch.setattr(Offloader, "save_dataframe", spy) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + bootstrap_nboot=2, + seed=42, + offload=True, + offload_dir=str(tmp_path), + ), + ) + s.expand() + s.bootstrap() + s.fit() + + assert writes.count("_DT") == 1 diff --git a/tests/test_optional_covariate_args.py b/tests/test_optional_covariate_args.py new file mode 100644 index 0000000..7660c89 --- /dev/null +++ b/tests/test_optional_covariate_args.py @@ -0,0 +1,30 @@ +"""Regression test: time_varying_cols and fixed_cols are documented Optional. + +Constructing without them used to crash in _param_checker (`set(None)`), and +several downstream sites iterate self.fixed_cols directly. Omitting both must +work through the whole pipeline. +""" + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_pipeline_runs_without_covariate_args(): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + method="ITT", + parameters=SEQopts(km_curves=True, seed=42), + ) + s.expand() + s.fit() + s.survival() + + # The auto-built outcome formula contains no covariate terms beyond the + # treatment/followup/trial defaults. + assert "sex" not in s.covariates + assert s.km_data.height > 0 diff --git a/tests/test_output_files.py b/tests/test_output_files.py new file mode 100644 index 0000000..8ff47ac --- /dev/null +++ b/tests/test_output_files.py @@ -0,0 +1,71 @@ +"""Markdown report generation (SEQoutput.to_md / _build_md). + +_build_md used to index numerator_models[0]/outcome_models[0] directly, which +crashed for weighted ITT analyses (no treatment-weight models exist — the +attribute is None) and for offloaded models (path refs, not fitted objects). +It now routes through SEQoutput.summary, which handles both. +""" + +import pytest + +# pandas.DataFrame.to_markdown needs tabulate (the "output" extra). +pytest.importorskip("tabulate") + +from pySEQTarget import SEQopts, SEQuential # noqa: E402 +from pySEQTarget.data import load_data # noqa: E402 + + +def test_to_md_weighted_ITT_without_numerator_models(tmp_path): + s = SEQuential( + load_data("SEQdata_LTFU"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(weighted=True, cense_colname="LTFU", seed=42), + ) + s.expand() + s.fit() + out = s.collect() + + md_file = tmp_path / "report.md" + out.to_md(str(md_file)) + content = md_file.read_text() + assert "Outcome Model" in content + # No treatment-weight models under ITT: the section is skipped, not a crash. + assert "Numerator Model" not in content + + +def test_to_md_with_offloaded_models(tmp_path): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + offload=True, + offload_dir=str(tmp_path / "models"), + seed=42, + ), + ) + s.expand() + s.fit() + out = s.collect() + + md_file = tmp_path / "report.md" + out.to_md(str(md_file)) + content = md_file.read_text() + assert "Numerator Model" in content + assert "Denominator Model" in content + assert "Outcome Model" in content diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 907121d..a6384f9 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -43,3 +43,47 @@ def test_parallel_ITT(): ], abs=1e-6, ) + + +def _hazard_run(parallel, glm_package): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package=glm_package, + hazard_estimate=True, + bootstrap_nboot=4, + ncores=2, + parallel=parallel, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.hazard() + hr = s.hazard_ratio + return (hr["Hazard ratio"][0], hr["LCI"][0], hr["UCI"][0]) + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" +) +@pytest.mark.parametrize("glm_package", ["statsmodels", "glum"]) +def test_parallel_hazard_matches_serial(glm_package): + # The process-pool bootstrap must produce the same hazard ratio + CI as the + # serial loop. Locks two fixes: (1) glum's _GlumFit is now picklable so the + # fitted models survive crossing the process boundary, and (2) the worker + # calls the raw fit body, so outcome_model[i] is a model dict (not a list) + # for the hazard consumer to index. Previously crashed for both backends. + serial = _hazard_run(parallel=False, glm_package=glm_package) + parallel = _hazard_run(parallel=True, glm_package=glm_package) + assert parallel == pytest.approx(serial, rel=1e-9, abs=1e-12) diff --git a/tests/test_selection_random.py b/tests/test_selection_random.py index 6dbe2b8..3607dcc 100644 --- a/tests/test_selection_random.py +++ b/tests/test_selection_random.py @@ -54,3 +54,36 @@ def test_selection_random_is_reproducible_with_fixed_seed(): a = _build(selection_random=True, selection_sample=0.5, seed=7) b = _build(selection_random=True, selection_sample=0.5, seed=7) assert a.DT.equals(b.DT) + + +def test_selection_random_nonzero_control_level(): + # Regression: the filter used `is_in(sample) | col != level`, which parses + # as `(is_in | col) != level` and silently dropped every sampled control + # trial whenever treatment_level[0] != 0 (e.g. [1, 2]) — the whole control + # arm vanished. Sampled controls must be retained. + prob = 0.5 + + def build(**opts): + opts.setdefault("seed", 1) + s = SEQuential( + load_data("SEQdata_multitreatment"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(treatment_level=[1, 2], **opts), + ) + s.expand() + return s + + base = _arm_trial_starts(build().DT) + sel = _arm_trial_starts(build(selection_random=True, selection_sample=prob).DT) + + # Non-control arm (level 2) fully retained; control arm (level 1) + # subsampled to the requested fraction — not dropped entirely. + assert sel[2] == base[2] + assert sel[1] == int(prob * base[1]) diff --git a/tests/test_verbose_censored_counts.py b/tests/test_verbose_censored_counts.py new file mode 100644 index 0000000..b007957 --- /dev/null +++ b/tests/test_verbose_censored_counts.py @@ -0,0 +1,49 @@ +import re + +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _expand(method, capsys): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(verbose=True), + ) + s.expand() + return s, capsys.readouterr().out + + +def test_verbose_reports_uncensored_censored_split(capsys): + # Under censoring the verbose output reports how many expanded rows enter the + # outcome model (un-censored) vs are artificially censored, so the count + # lines up with implementations that print only the modelled rows. + s, out = _expand("censoring", capsys) + + total = int( + re.search(r"Final analysis dataset: ([\d,]+)", out).group(1).replace(",", "") + ) + unc = int(re.search(r"uncensored\): ([\d,]+)", out).group(1).replace(",", "")) + cen = int(re.search(r"treatment switch\): ([\d,]+)", out).group(1).replace(",", "")) + + assert unc + cen == total + assert cen > 0 # SEQdata has treatment switches, so some rows are censored + # The reported un-censored count must equal the rows _outcome_fit fits on. + assert unc == s.DT.filter(pl.col("switch") != 1).height + + +def test_verbose_no_censored_split_for_itt(capsys): + # ITT applies no artificial censoring, so the split is not reported. + _, out = _expand("ITT", capsys) + assert "uncensored" not in out + assert "artificially censored" not in out diff --git a/tests/test_weight_fit_cached_across_bootstrap.py b/tests/test_weight_fit_cached_across_bootstrap.py new file mode 100644 index 0000000..f3a1cbd --- /dev/null +++ b/tests/test_weight_fit_cached_across_bootstrap.py @@ -0,0 +1,77 @@ +"""With weight_preexpansion=True the weight models are fit on the un-resampled +pre-expansion data, so bootstrap replicates would refit identical models +and re-predict identical weights every iteration. The main fit's predicted +weight frame is cached and replicates only redo the join onto their resampled +DT — results must be unchanged (to numerical precision), the weight fitters +must run exactly once. +""" + +import sys + +import numpy as np + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _run(monkeypatch=None, disable_cache=False): + seq_mod = sys.modules["pySEQTarget.SEQuential"] + fit_calls = [] + + real_fit_denominator = seq_mod._fit_denominator + + def spy_fit_denominator(self, WDT): + fit_calls.append(getattr(self, "_current_boot_idx", None)) + return real_fit_denominator(self, WDT) + + if monkeypatch is not None: + monkeypatch.setattr(seq_mod, "_fit_denominator", spy_fit_denominator) + if disable_cache: + real_bind = seq_mod._weight_bind + + def bind_no_cache(self, WDT): + result = real_bind(self, WDT) + # Drop the cache after every bind so each replicate refits. + self._main_weight_WDT = None + return result + + monkeypatch.setattr(seq_mod, "_weight_bind", bind_no_cache) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + bootstrap_nboot=3, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + coefs = np.concatenate([np.asarray(m["outcome"].params) for m in s.outcome_model]) + return fit_calls, coefs + + +def test_weight_models_fit_once_across_bootstrap(monkeypatch): + fit_calls, _ = _run(monkeypatch) + assert fit_calls == [None] # main fit only, no replicate refits + + +def test_cached_weights_match_refit_weights(monkeypatch): + _, cached = _run(monkeypatch) + fit_calls, refit = _run(monkeypatch, disable_cache=True) + assert len(fit_calls) == 4 # cache disabled: main + 3 replicate refits + # Identical to numerical precision: the cached and refit paths assemble the + # GLM input via different code routes, so multi-threaded BLAS can differ in + # the last few ULPs (passes bit-identical on CI, not always locally). A tight + # tolerance still catches any real divergence, which would be far larger. + assert np.allclose(cached, refit, rtol=0, atol=1e-10) diff --git a/tests/test_weight_main_fit_preserved.py b/tests/test_weight_main_fit_preserved.py new file mode 100644 index 0000000..3e05c1c --- /dev/null +++ b/tests/test_weight_main_fit_preserved.py @@ -0,0 +1,76 @@ +import os + +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit_weighted(bootstrap_nboot, parallel=False, glm_package="statsmodels", seed=42): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + glm_package=glm_package, + weighted=True, + # post-expansion weights are refit on every (resampled) replicate, so + # this is the case where an in-process bootstrap clobbers the stored + # main-fit weight models. + weight_preexpansion=False, + bootstrap_nboot=bootstrap_nboot, + parallel=parallel, + ncores=2, + seed=seed, + ), + ) + s.expand() + if bootstrap_nboot > 0: + s.bootstrap() + s.fit() + return s + + +def _weight_params(s): + return { + "numerator": [list(m.params) for m in s.numerator_model if m is not None], + "denominator": [list(m.params) for m in s.denominator_model if m is not None], + } + + +def _assert_same(a, b): + for kind in ("numerator", "denominator"): + assert len(a[kind]) == len(b[kind]) + for pa, pb in zip(a[kind], b[kind]): + assert pb == approx(pa, rel=1e-9, abs=1e-9) + + +@pytest.mark.parametrize("glm_package", ["statsmodels", "glum"]) +def test_main_weight_models_preserved_after_bootstrap(glm_package): + # _fit_numerator/_fit_denominator overwrite self.X_model on every fit, and a + # serial bootstrap runs replicates in-process — so without preservation the + # stored weight model (used by the numerator/denominator summary) would be + # the last resampled replicate, not the main fit. The main-fit models must + # survive the bootstrap loop. + main = _weight_params(_fit_weighted(0, glm_package=glm_package)) + booted = _weight_params(_fit_weighted(3, glm_package=glm_package)) + _assert_same(main, booted) + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" +) +def test_main_weight_models_match_across_serial_and_parallel(): + main = _weight_params(_fit_weighted(0, glm_package="glum")) + serial = _weight_params(_fit_weighted(3, parallel=False, glm_package="glum")) + parallel = _weight_params(_fit_weighted(3, parallel=True, glm_package="glum")) + _assert_same(main, serial) + _assert_same(main, parallel)