From 57d293101a46a2c6856d28cb85e5b7feb4152591 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 14 May 2026 16:14:49 +0200 Subject: [PATCH 1/3] Add local H5 worker service boundary --- changelog.d/976.added | 1 + docs/engineering/stages/build_outputs.md | 18 + modal_app/worker_script.py | 241 ++++------ .../build_outputs/__init__.py | 2 +- .../build_outputs/worker_service.py | 420 ++++++++++++++++++ .../test_worker_script_tiny_fixture.py | 15 + .../unit/build_outputs/test_worker_service.py | 318 +++++++++++++ tests/unit/test_modal_worker_script.py | 41 +- 8 files changed, 883 insertions(+), 173 deletions(-) create mode 100644 changelog.d/976.added create mode 100644 policyengine_us_data/build_outputs/worker_service.py create mode 100644 tests/unit/build_outputs/test_worker_service.py diff --git a/changelog.d/976.added b/changelog.d/976.added new file mode 100644 index 000000000..3bcc2ceee --- /dev/null +++ b/changelog.d/976.added @@ -0,0 +1 @@ +Add a reusable local H5 worker service boundary and keep the Modal worker script as a thin adapter. diff --git a/docs/engineering/stages/build_outputs.md b/docs/engineering/stages/build_outputs.md index 1c2c9e0fe..9be5dc2e3 100644 --- a/docs/engineering/stages/build_outputs.md +++ b/docs/engineering/stages/build_outputs.md @@ -32,6 +32,24 @@ source-variable cloning, postprocessing, or writing concern. Do not place country-specific payload mutation in `build_h5()` when it can be represented as a postprocessor. +## Worker Chunk Execution + +`LocalH5WorkerService` is the reusable Stage 4 boundary for executing one +prepared local-H5 worker chunk. It consumes a `WorkerSession`, typed +`AreaBuildRequest` objects, and a `WorkerExecutionConfig`, then returns a +structured `WorkerResult`. + +`modal_app.worker_script` should remain a thin CLI/JSON adapter around this +service. It may parse legacy `--work-items` and typed `--requests-json`, prepare +the worker session, and print the legacy coordinator JSON shape, but it should +not regain build-loop, write-loop, or validation-loop logic. + +For now, `WorkerResult.to_legacy_dict()` preserves the existing coordinator +contract with `completed`, `failed`, `errors`, `validation_rows`, and +`validation_summary`. New code should prefer the structured `results` and +`issues` fields. Removing the legacy shape and moving the coordinator off worker +subprocess JSON is a later migration step. + ## Payload Postprocessors Payload postprocessors are ordered, country- or product-specific transformations diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index d89d22b1c..eae96a0f0 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -10,7 +10,6 @@ import sys import traceback from pathlib import Path -from typing import Any def parse_args(argv: list[str] | None = None): @@ -148,38 +147,6 @@ def _build_publishing_inputs(*, args, run_id: str): return worker_inputs.to_publishing_input_bundle(run_id=run_id) -def _build_kwargs_from_request(request) -> dict[str, Any]: - """Translate a typed request into `build_h5(...)` keyword arguments.""" - - if request.area_type == "national": - return {} - - if len(request.filters) != 1: - raise ValueError( - f"{request.area_type} requests must carry exactly one build filter" - ) - - build_filter = request.filters[0] - if ( - request.area_type in {"state", "district"} - and build_filter.geography_field == "cd_geoid" - and build_filter.op == "in" - ): - return {"cd_subset": [str(item) for item in build_filter.value]} - - if ( - request.area_type == "city" - and build_filter.geography_field == "county_fips" - and build_filter.op == "in" - ): - return {"county_fips_filter": {str(item) for item in build_filter.value}} - - raise ValueError( - f"Unsupported build filter for {request.area_type}: " - f"{build_filter.geography_field}:{build_filter.op}" - ) - - def _request_key(request) -> str: """Return the stable completion key used by worker/coordinator flows.""" @@ -196,20 +163,6 @@ def _work_item_key(work_item) -> str: return f"{item_type}:{item_id}" -def _resolve_output_path(*, output_dir: Path, output_relative_path: str) -> Path: - """Resolve one request output path and reject attempts to escape the run dir.""" - - candidate_path = (output_dir / output_relative_path).resolve(strict=False) - output_dir_path = output_dir.resolve(strict=False) - try: - candidate_path.relative_to(output_dir_path) - except ValueError as exc: - raise ValueError( - "output_relative_path must stay within the worker output_dir" - ) from exc - return candidate_path - - def _resolve_request_input( *, request_input_mode, @@ -232,6 +185,51 @@ def _resolve_request_input( return _request_key(request), request +def _resolve_worker_requests( + *, + request_input_mode, + request_inputs, + area_catalog, + geography, +) -> tuple[tuple, tuple]: + """Resolve queued CLI inputs into typed requests plus conversion issues.""" + + from policyengine_us_data.build_outputs.worker_service import WorkerIssue + + if request_input_mode == "requests": + return tuple(request_inputs), () + + requests = [] + issues = [] + for request_input in request_inputs: + request_key = _work_item_key(request_input) + try: + request_key, request = _resolve_request_input( + request_input_mode=request_input_mode, + request_input=request_input, + area_catalog=area_catalog, + geography=geography, + ) + except Exception as exc: + issues.append( + WorkerIssue( + item=request_key, + phase="request", + message=str(exc), + traceback=traceback.format_exc(), + ) + ) + continue + if request is None: + print( + f"Skipping {request_key}: no matching geography in legacy work item", + file=sys.stderr, + ) + continue + requests.append(request) + return tuple(requests), tuple(issues) + + def _log_worker_session_ready(*, scope: str, session, geography) -> None: """Write worker-session setup details to stderr for Modal diagnostics.""" @@ -252,7 +250,6 @@ def _log_worker_session_ready(*, scope: str, session, geography) -> None: def main(argv: list[str] | None = None): args = parse_args(argv) - dataset_path = Path(args.dataset_path) output_dir = Path(args.output_dir) run_id = args.run_id or output_dir.name or "local-worker" @@ -265,15 +262,17 @@ def main(argv: list[str] | None = None): original_stdout = sys.stdout sys.stdout = sys.stderr - from policyengine_us_data.calibration.publish_local_area import ( - build_h5, - ) from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog from policyengine_us_data.build_outputs.requests import AreaBuildRequest from policyengine_us_data.build_outputs.validation import ( AreaValidationService, ValidationPolicy, ) + from policyengine_us_data.build_outputs.worker_service import ( + LocalH5WorkerService, + WorkerExecutionConfig, + WorkerResult, + ) from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory area_catalog = USAreaCatalog.default() @@ -297,8 +296,6 @@ def main(argv: list[str] | None = None): artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None, expected_scope_fingerprint=args.scope_fingerprint, ) - weights = session.weights.values - n_records = session.weights.n_records geography = session.geography validation_context = session.validation_context _log_worker_session_ready(scope=scope, session=session, geography=geography) @@ -312,111 +309,55 @@ def main(argv: list[str] | None = None): file=sys.stderr, ) - results = { - "completed": [], - "failed": [], - "errors": [], - "validation_rows": [], - "validation_summary": {}, - } - - for request_input in request_inputs: - try: - request_key = ( - _work_item_key(request_input) - if request_input_mode == "work_items" - else None - ) - request_key, request = _resolve_request_input( - request_input_mode=request_input_mode, - request_input=request_input, - area_catalog=area_catalog, - geography=geography, - ) - if request is None: - print( - f"Skipping {request_key}: no matching geography in legacy work item", - file=sys.stderr, - ) - continue + requests, request_issues = _resolve_worker_requests( + request_input_mode=request_input_mode, + request_inputs=request_inputs, + area_catalog=area_catalog, + geography=geography, + ) + worker_result = LocalH5WorkerService( + validation_service=validation_service, + ).execute( + session=session, + requests=requests, + config=WorkerExecutionConfig( + output_dir=output_dir, + takeup_filter=tuple(takeup_filter), + validate=not args.no_validate, + ), + ) + if request_issues: + worker_result = WorkerResult( + area_results=worker_result.area_results, + issues=(*request_issues, *worker_result.issues), + ) - output_path = _resolve_output_path( - output_dir=output_dir, - output_relative_path=request.output_relative_path, + for area_result in worker_result.area_results: + if area_result.status == "completed": + print(f"Completed {area_result.key}", file=sys.stderr) + else: + message = ( + area_result.issues[0].message if area_result.issues else "unknown error" ) - output_path.parent.mkdir(parents=True, exist_ok=True) - build_kwargs = _build_kwargs_from_request(request) - if request.area_type == "national": - n_clones_from_weights = weights.shape[0] // n_records - if n_clones_from_weights != geography.n_clones: - raise ValueError( - f"National weights have {n_clones_from_weights} clones " - f"but geography has {geography.n_clones}. " - "Use the matching saved geography artifact." - ) - path = build_h5( - weights=weights, - geography=geography, - dataset_path=dataset_path, - output_path=output_path, - ) - else: - path = build_h5( - weights=weights, - geography=geography, - dataset_path=dataset_path, - output_path=output_path, - takeup_filter=takeup_filter, - **build_kwargs, - ) - - if path: - results["completed"].append(request_key) - print( - f"Completed {request_key}", - file=sys.stderr, - ) - - if not args.no_validate and validation_context is not None: - try: - validation_result = validation_service.validate_request( - context=validation_context, - h5_path=str(path), - request=request, - ) - v_rows = list(validation_result.rows) - results["validation_rows"].extend(v_rows) - summary = dict(validation_result.summary) - results["validation_summary"][request_key] = summary - print( - f" Validated {request_key}: " - f"{summary['n_targets']} targets, " - f"{summary['n_sanity_fail']} sanity fails, " - f"mean RAE={summary['mean_rel_abs_error']:.4f}", - file=sys.stderr, - ) - except Exception as ve: - print( - f" Validation failed for {request_key}: {ve}", - file=sys.stderr, - ) - - except Exception as e: - results["failed"].append(request_key) - results["errors"].append( - { - "item": request_key, - "error": str(e), - "traceback": traceback.format_exc(), - } + print(f"FAILED {area_result.key}: {message}", file=sys.stderr) + if area_result.validation_status == "passed" and area_result.validation_summary: + summary = area_result.validation_summary + print( + f" Validated {area_result.key}: " + f"{summary['n_targets']} targets, " + f"{summary['n_sanity_fail']} sanity fails, " + f"mean RAE={summary['mean_rel_abs_error']:.4f}", + file=sys.stderr, ) + elif area_result.validation_status == "error" and area_result.issues: print( - f"FAILED {request_key}: {e}", + f" Validation failed for {area_result.key}: " + f"{area_result.issues[-1].message}", file=sys.stderr, ) sys.stdout = original_stdout - print(json.dumps(results)) + print(json.dumps(worker_result.to_legacy_dict())) if __name__ == "__main__": diff --git a/policyengine_us_data/build_outputs/__init__.py b/policyengine_us_data/build_outputs/__init__.py index 93a32b1ab..9f4e1eeb5 100644 --- a/policyengine_us_data/build_outputs/__init__.py +++ b/policyengine_us_data/build_outputs/__init__.py @@ -8,5 +8,5 @@ artifacts, worker-scoped session and validation context setup, microsimulation access helpers, clone selection, entity reindexing, source-variable cloning, validated H5 payload contracts, ordered output postprocessing, one-area payload -building, and H5 writing. +building, H5 writing, and worker chunk execution. """ diff --git a/policyengine_us_data/build_outputs/worker_service.py b/policyengine_us_data/build_outputs/worker_service.py new file mode 100644 index 000000000..d9c17e010 --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_service.py @@ -0,0 +1,420 @@ +"""Worker chunk execution boundary for local H5 publication.""" + +from __future__ import annotations + +import traceback as traceback_module +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Mapping, Sequence + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .builder import LocalAreaDatasetBuilder +from .requests import AreaBuildRequest +from .us_augmentations import default_us_postprocessors +from .validation import AreaValidationService +from .worker_session import WorkerSession +from .writer import H5Writer + +WorkerAreaStatus = Literal["completed", "failed", "skipped"] +WorkerIssuePhase = Literal["request", "build", "write", "validation"] +WorkerValidationStatus = Literal["not_run", "passed", "error"] + +__all__ = [ + "LocalH5WorkerService", + "WorkerAreaResult", + "WorkerAreaStatus", + "WorkerExecutionConfig", + "WorkerIssue", + "WorkerIssuePhase", + "WorkerResult", + "WorkerValidationStatus", +] + + +@pipeline_node( + id="local_h5_worker_execution_config", + label="WorkerExecutionConfig", + node_type="library", + description="Runtime policy for one local H5 worker-service execution.", + source_file="policyengine_us_data/build_outputs/worker_service.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_service.py" + ], +) +@dataclass(frozen=True) +class WorkerExecutionConfig: + """Execution policy for one worker chunk.""" + + output_dir: Path + takeup_filter: tuple[str, ...] = () + validate: bool = True + fail_on_validation_error: bool = False + + def __post_init__(self) -> None: + object.__setattr__(self, "output_dir", Path(self.output_dir)) + object.__setattr__( + self, + "takeup_filter", + tuple(str(item) for item in self.takeup_filter), + ) + + +@pipeline_node( + id="local_h5_worker_issue", + label="WorkerIssue", + node_type="library", + description="Structured issue reported by one local H5 worker request.", + source_file="policyengine_us_data/build_outputs/worker_service.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_service.py" + ], +) +@dataclass(frozen=True) +class WorkerIssue: + """Structured worker issue for request, build, write, or validation failures.""" + + item: str + phase: WorkerIssuePhase + message: str + traceback: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Serialize the issue to worker JSON output.""" + + payload: dict[str, Any] = { + "item": self.item, + "phase": self.phase, + "error": self.message, + } + if self.traceback: + payload["traceback"] = self.traceback + return payload + + +@pipeline_node( + id="local_h5_worker_area_result", + label="WorkerAreaResult", + node_type="library", + description="Structured result for one local H5 worker request.", + source_file="policyengine_us_data/build_outputs/worker_service.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_service.py" + ], +) +@dataclass(frozen=True) +class WorkerAreaResult: + """Structured result for one area handled by a worker.""" + + key: str + request: AreaBuildRequest + status: WorkerAreaStatus + output_relative_path: str + output_path: Path | None = None + validation_status: WorkerValidationStatus = "not_run" + validation_rows: tuple[Mapping[str, Any], ...] = () + validation_summary: Mapping[str, Any] = field(default_factory=dict) + issues: tuple[WorkerIssue, ...] = () + + def __post_init__(self) -> None: + object.__setattr__( + self, + "output_path", + Path(self.output_path) if self.output_path is not None else None, + ) + object.__setattr__(self, "validation_rows", tuple(self.validation_rows)) + object.__setattr__(self, "validation_summary", dict(self.validation_summary)) + object.__setattr__(self, "issues", tuple(self.issues)) + + def to_dict(self) -> dict[str, Any]: + """Serialize the area result to worker JSON output.""" + + return { + "key": self.key, + "request": self.request.to_dict(), + "status": self.status, + "output_relative_path": self.output_relative_path, + "output_path": str(self.output_path) if self.output_path else None, + "validation_status": self.validation_status, + "validation_rows": [dict(row) for row in self.validation_rows], + "validation_summary": dict(self.validation_summary), + "issues": [issue.to_dict() for issue in self.issues], + } + + +@pipeline_node( + id="local_h5_worker_result", + label="WorkerResult", + node_type="library", + description="Structured result for one local H5 worker chunk.", + source_file="policyengine_us_data/build_outputs/worker_service.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_service.py" + ], +) +@dataclass(frozen=True) +class WorkerResult: + """Structured result for a worker chunk.""" + + area_results: tuple[WorkerAreaResult, ...] = () + issues: tuple[WorkerIssue, ...] = () + + def __post_init__(self) -> None: + object.__setattr__(self, "area_results", tuple(self.area_results)) + object.__setattr__(self, "issues", tuple(self.issues)) + + def to_legacy_dict(self) -> dict[str, Any]: + """Serialize to the existing worker/coordinator JSON contract.""" + + completed = [ + result.key for result in self.area_results if result.status == "completed" + ] + failed = [ + result.key for result in self.area_results if result.status == "failed" + ] + failed.extend(issue.item for issue in self.issues) + validation_rows: list[dict[str, Any]] = [] + validation_summary: dict[str, Mapping[str, Any]] = {} + legacy_errors = [issue.to_dict() for issue in self.issues] + structured_issues = [issue.to_dict() for issue in self.issues] + + for result in self.area_results: + validation_rows.extend(dict(row) for row in result.validation_rows) + if result.validation_summary: + validation_summary[result.key] = dict(result.validation_summary) + issue_dicts = [issue.to_dict() for issue in result.issues] + structured_issues.extend(issue_dicts) + if result.status == "failed": + legacy_errors.extend(issue_dicts) + + return { + "completed": completed, + "failed": failed, + "errors": legacy_errors, + "validation_rows": validation_rows, + "validation_summary": validation_summary, + "results": [result.to_dict() for result in self.area_results], + "issues": structured_issues, + } + + +@pipeline_node( + id="local_h5_worker_service", + label="LocalH5WorkerService", + node_type="library", + description="Execute one worker chunk of local H5 build requests.", + source_file="policyengine_us_data/build_outputs/worker_service.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_service.py", + "uv run pytest tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py", + ], +) +@dataclass(frozen=True) +class LocalH5WorkerService: + """Execute typed local H5 requests for one prepared worker session.""" + + builder: Any = field( + default_factory=lambda: LocalAreaDatasetBuilder( + postprocessors=default_us_postprocessors() + ) + ) + writer: Any = field(default_factory=H5Writer) + validation_service: AreaValidationService = field( + default_factory=AreaValidationService + ) + + def execute( + self, + *, + session: WorkerSession, + requests: Sequence[AreaBuildRequest], + config: WorkerExecutionConfig, + ) -> WorkerResult: + """Build and optionally validate every request in one worker chunk.""" + + area_results = tuple( + self._execute_request(session=session, request=request, config=config) + for request in requests + ) + return WorkerResult(area_results=area_results) + + def _execute_request( + self, + *, + session: WorkerSession, + request: AreaBuildRequest, + config: WorkerExecutionConfig, + ) -> WorkerAreaResult: + key = _request_key(request) + try: + output_path = _resolve_output_path( + output_dir=config.output_dir, + output_relative_path=request.output_relative_path, + ) + except Exception as exc: + return _failed_result( + key=key, + request=request, + phase="request", + error=exc, + ) + + try: + if request.area_type == "national": + _validate_national_weight_scope(session) + build_result = self.builder.build( + source=session.source, + simulation=_source_simulation(session), + weights=session.weights, + geography=session.geography, + request=request, + takeup_filter=( + None + if request.area_type == "national" + else tuple(config.takeup_filter) + ), + ) + except Exception as exc: + return _failed_result( + key=key, + request=request, + phase="build", + error=exc, + output_path=output_path, + ) + + try: + write_result = self.writer.write( + payload=build_result.payload, + output_path=output_path, + ) + written_path = Path(getattr(write_result, "path", output_path)) + except Exception as exc: + return _failed_result( + key=key, + request=request, + phase="write", + error=exc, + output_path=output_path, + ) + + validation_status: WorkerValidationStatus = "not_run" + validation_rows: tuple[Mapping[str, Any], ...] = () + validation_summary: Mapping[str, Any] = {} + issues: tuple[WorkerIssue, ...] = () + if config.validate and session.validation_context is not None: + try: + validation_result = self.validation_service.validate_request( + context=session.validation_context, + h5_path=written_path, + request=request, + ) + validation_rows = tuple(validation_result.rows) + validation_summary = dict(validation_result.summary) + validation_status = "passed" + except Exception as exc: + issue = _issue(key=key, phase="validation", error=exc) + issues = (issue,) + validation_status = "error" + if config.fail_on_validation_error: + return WorkerAreaResult( + key=key, + request=request, + status="failed", + output_relative_path=request.output_relative_path, + output_path=written_path, + validation_status=validation_status, + issues=issues, + ) + + return WorkerAreaResult( + key=key, + request=request, + status="completed", + output_relative_path=request.output_relative_path, + output_path=written_path, + validation_status=validation_status, + validation_rows=validation_rows, + validation_summary=validation_summary, + issues=issues, + ) + + +def _request_key(request: AreaBuildRequest) -> str: + return f"{request.area_type}:{request.area_id}" + + +def _resolve_output_path(*, output_dir: Path, output_relative_path: str) -> Path: + candidate_path = (Path(output_dir) / output_relative_path).resolve(strict=False) + output_dir_path = Path(output_dir).resolve(strict=False) + try: + candidate_path.relative_to(output_dir_path) + except ValueError as exc: + raise ValueError( + "output_relative_path must stay within the worker output_dir" + ) from exc + return candidate_path + + +def _source_simulation(session: WorkerSession) -> Any: + provider = getattr(session.source, "variable_provider", None) + simulation = getattr(provider, "simulation", None) + if simulation is None: + raise ValueError("Worker session source does not expose a simulation") + return simulation + + +def _validate_national_weight_scope(session: WorkerSession) -> None: + if session.weights.n_clones != session.geography.n_clones: + raise ValueError( + f"National weights have {session.weights.n_clones} clones " + f"but geography has {session.geography.n_clones}. " + "Use the matching saved geography artifact." + ) + + +def _failed_result( + *, + key: str, + request: AreaBuildRequest, + phase: WorkerIssuePhase, + error: Exception, + output_path: Path | None = None, +) -> WorkerAreaResult: + return WorkerAreaResult( + key=key, + request=request, + status="failed", + output_relative_path=request.output_relative_path, + output_path=output_path, + issues=(_issue(key=key, phase=phase, error=error),), + ) + + +def _issue( + *, + key: str, + phase: WorkerIssuePhase, + error: Exception, +) -> WorkerIssue: + return WorkerIssue( + item=key, + phase=phase, + message=str(error), + traceback=traceback_module.format_exc(), + ) diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 5cfb7635e..24abb0179 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -140,6 +140,9 @@ def test_worker_builds_district_h5_from_saved_geography(tmp_path): assert result["failed"] == [] assert result["errors"] == [] assert result["completed"] == [f"district:{request.area_id}"] + assert result["issues"] == [] + assert result["results"][0]["key"] == f"district:{request.area_id}" + assert result["results"][0]["status"] == "completed" assert (output_dir / request.output_relative_path).exists() @@ -158,6 +161,9 @@ def test_worker_builds_state_h5_from_package_geography(tmp_path): assert result["failed"] == [] assert result["errors"] == [] assert result["completed"] == [f"state:{request.area_id}"] + assert result["issues"] == [] + assert result["results"][0]["key"] == f"state:{request.area_id}" + assert result["results"][0]["status"] == "completed" assert (output_dir / request.output_relative_path).exists() @@ -177,6 +183,9 @@ def test_worker_builds_national_h5_from_package_geography(tmp_path): assert result["failed"] == [] assert result["errors"] == [] assert result["completed"] == ["national:US"] + assert result["issues"] == [] + assert result["results"][0]["key"] == "national:US" + assert result["results"][0]["status"] == "completed" assert (output_dir / request.output_relative_path).exists() @@ -213,6 +222,12 @@ def test_worker_validation_runs_for_tiny_district_state_and_national_h5s(tmp_pat assert parsed["failed"] == [] assert parsed["errors"] == [] assert parsed["completed"] == ["district:NC-01", "state:NC", "national:US"] + assert parsed["issues"] == [] + assert [item["key"] for item in parsed["results"]] == [ + "district:NC-01", + "state:NC", + "national:US", + ] assert len(parsed["validation_rows"]) == 3 assert set(parsed["validation_summary"]) == { "district:NC-01", diff --git a/tests/unit/build_outputs/test_worker_service.py b/tests/unit/build_outputs/test_worker_service.py new file mode 100644 index 000000000..182e25f94 --- /dev/null +++ b/tests/unit/build_outputs/test_worker_service.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np + +from policyengine_us_data.build_outputs.requests import AreaBuildRequest, AreaFilter +from policyengine_us_data.build_outputs.validation import AreaValidationResult +from policyengine_us_data.build_outputs.worker_service import ( + LocalH5WorkerService, + WorkerAreaResult, + WorkerExecutionConfig, + WorkerIssue, + WorkerResult, +) +from policyengine_us_data.build_outputs.worker_session import WorkerSession + + +def _request( + area_type: str = "district", + area_id: str = "NC-01", + output_relative_path: str = "districts/NC-01.h5", +) -> AreaBuildRequest: + return AreaBuildRequest( + area_type=area_type, + area_id=area_id, + display_name=area_id, + output_relative_path=output_relative_path, + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=("3701",), + ), + ) + if area_type != "national" + else (), + validation_geo_level="district" if area_type != "national" else "national", + validation_geographic_ids=(area_id,), + ) + + +def _session( + *, + validation_context=None, + weight_clones: int = 2, + geography_clones: int = 2, +) -> WorkerSession: + source = SimpleNamespace( + variable_provider=SimpleNamespace(simulation=SimpleNamespace()), + ) + return WorkerSession( + inputs=SimpleNamespace(), + scope="regional", + source=source, + weights=SimpleNamespace( + values=np.ones(weight_clones), + n_records=1, + n_clones=weight_clones, + ), + geography=SimpleNamespace(n_records=1, n_clones=geography_clones), + validation_context=validation_context, + ) + + +class FakeBuilder: + def __init__(self, *, fail: bool = False): + self.fail = fail + self.calls = [] + + def build(self, **kwargs): + self.calls.append(kwargs) + if self.fail: + raise RuntimeError("build failed") + return SimpleNamespace(payload=SimpleNamespace(name="payload")) + + +class FakeWriter: + def __init__(self, *, fail: bool = False): + self.fail = fail + self.calls = [] + + def write(self, **kwargs): + self.calls.append(kwargs) + if self.fail: + raise RuntimeError("write failed") + return SimpleNamespace(path=Path(kwargs["output_path"])) + + +class FakeValidationService: + def __init__(self, *, fail: bool = False): + self.fail = fail + self.calls = [] + + def validate_request(self, **kwargs): + self.calls.append(kwargs) + if self.fail: + raise RuntimeError("validation failed") + return AreaValidationResult( + rows=( + { + "variable": "household_count", + "sanity_check": "PASS", + "rel_abs_error": 0.0, + }, + ), + summary={ + "n_targets": 1, + "n_sanity_fail": 0, + "mean_rel_abs_error": 0.0, + }, + ) + + +def test_worker_result_preserves_legacy_and_structured_shapes(tmp_path): + request = _request() + result = WorkerResult( + area_results=( + WorkerAreaResult( + key="district:NC-01", + request=request, + status="completed", + output_relative_path=request.output_relative_path, + output_path=tmp_path / "districts" / "NC-01.h5", + validation_status="passed", + validation_rows=({"variable": "household_count"},), + validation_summary={"n_targets": 1}, + ), + ), + issues=( + WorkerIssue( + item="district:bad", + phase="request", + message="bad request", + ), + ), + ) + + payload = result.to_legacy_dict() + + assert payload["completed"] == ["district:NC-01"] + assert payload["failed"] == ["district:bad"] + assert payload["errors"] == [ + {"item": "district:bad", "phase": "request", "error": "bad request"} + ] + assert payload["validation_rows"] == [{"variable": "household_count"}] + assert payload["validation_summary"] == {"district:NC-01": {"n_targets": 1}} + assert payload["results"][0]["key"] == "district:NC-01" + assert payload["issues"][0]["item"] == "district:bad" + + +def test_worker_service_builds_writes_and_validates_request(tmp_path): + builder = FakeBuilder() + writer = FakeWriter() + validation_service = FakeValidationService() + request = _request() + + result = LocalH5WorkerService( + builder=builder, + writer=writer, + validation_service=validation_service, + ).execute( + session=_session(validation_context=SimpleNamespace()), + requests=(request,), + config=WorkerExecutionConfig( + output_dir=tmp_path, + takeup_filter=("takes_up_snap",), + validate=True, + ), + ) + + payload = result.to_legacy_dict() + + assert payload["completed"] == ["district:NC-01"] + assert payload["failed"] == [] + assert payload["errors"] == [] + assert payload["validation_summary"]["district:NC-01"]["n_targets"] == 1 + assert builder.calls[0]["request"] == request + assert builder.calls[0]["takeup_filter"] == ("takes_up_snap",) + assert writer.calls[0]["output_path"] == tmp_path / "districts" / "NC-01.h5" + assert validation_service.calls[0]["request"] == request + + +def test_worker_service_does_not_apply_takeup_filter_to_national_request(tmp_path): + builder = FakeBuilder() + request = _request("national", "US", "national/US.h5") + + result = LocalH5WorkerService( + builder=builder, + writer=FakeWriter(), + validation_service=FakeValidationService(), + ).execute( + session=_session(), + requests=(request,), + config=WorkerExecutionConfig( + output_dir=tmp_path, + takeup_filter=("takes_up_snap",), + validate=False, + ), + ) + + assert result.to_legacy_dict()["completed"] == ["national:US"] + assert builder.calls[0]["takeup_filter"] is None + + +def test_worker_service_reports_build_failures(tmp_path): + request = _request() + + result = LocalH5WorkerService( + builder=FakeBuilder(fail=True), + writer=FakeWriter(), + validation_service=FakeValidationService(), + ).execute( + session=_session(), + requests=(request,), + config=WorkerExecutionConfig(output_dir=tmp_path, validate=False), + ) + + payload = result.to_legacy_dict() + + assert payload["completed"] == [] + assert payload["failed"] == ["district:NC-01"] + assert payload["errors"][0]["phase"] == "build" + assert payload["errors"][0]["error"] == "build failed" + assert payload["results"][0]["status"] == "failed" + + +def test_worker_service_records_validation_errors_without_failing_by_default( + tmp_path, +): + request = _request() + + result = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validation_service=FakeValidationService(fail=True), + ).execute( + session=_session(validation_context=SimpleNamespace()), + requests=(request,), + config=WorkerExecutionConfig(output_dir=tmp_path, validate=True), + ) + + payload = result.to_legacy_dict() + + assert payload["completed"] == ["district:NC-01"] + assert payload["failed"] == [] + assert payload["errors"] == [] + assert payload["issues"][0]["phase"] == "validation" + assert payload["results"][0]["validation_status"] == "error" + + +def test_worker_service_can_fail_on_validation_error(tmp_path): + request = _request() + + result = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validation_service=FakeValidationService(fail=True), + ).execute( + session=_session(validation_context=SimpleNamespace()), + requests=(request,), + config=WorkerExecutionConfig( + output_dir=tmp_path, + validate=True, + fail_on_validation_error=True, + ), + ) + + payload = result.to_legacy_dict() + + assert payload["completed"] == [] + assert payload["failed"] == ["district:NC-01"] + assert payload["errors"][0]["phase"] == "validation" + + +def test_worker_service_rejects_output_path_escape(tmp_path): + request = _request() + object.__setattr__(request, "output_relative_path", "../escape.h5") + + result = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validation_service=FakeValidationService(), + ).execute( + session=_session(), + requests=(request,), + config=WorkerExecutionConfig(output_dir=tmp_path, validate=False), + ) + + payload = result.to_legacy_dict() + + assert payload["failed"] == ["district:NC-01"] + assert payload["errors"][0]["phase"] == "request" + assert "worker output_dir" in payload["errors"][0]["error"] + + +def test_worker_service_rejects_national_weight_geography_mismatch(tmp_path): + request = _request("national", "US", "national/US.h5") + + result = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validation_service=FakeValidationService(), + ).execute( + session=_session(weight_clones=2, geography_clones=3), + requests=(request,), + config=WorkerExecutionConfig(output_dir=tmp_path, validate=False), + ) + + payload = result.to_legacy_dict() + + assert payload["failed"] == ["national:US"] + assert payload["errors"][0]["phase"] == "build" + assert ( + "National weights have 2 clones but geography has 3" + in payload["errors"][0]["error"] + ) diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index 2ce7d26e7..3a73dc90a 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -177,31 +177,28 @@ def test_resolve_request_input_skips_legacy_work_item_without_request(): assert catalog.received_item == (work_item, geography) -def test_resolve_output_path_keeps_outputs_under_worker_directory(tmp_path): - output_dir = tmp_path / "worker-out" - output_dir.mkdir() +def test_resolve_worker_requests_records_legacy_conversion_issues(): + catalog = FakeAreaCatalog() + geography = object() + bad_work_item = {"type": "district", "id": "bad"} + good_work_item = {"type": "national", "id": "US"} + catalog.raise_for = bad_work_item - resolved = worker_script._resolve_output_path( - output_dir=output_dir, - output_relative_path="states/CA.h5", + requests, issues = worker_script._resolve_worker_requests( + request_input_mode="work_items", + request_inputs=(bad_work_item, good_work_item), + area_catalog=catalog, + geography=geography, ) - assert resolved == output_dir / "states" / "CA.h5" - - -def test_resolve_output_path_rejects_escaped_request_path(tmp_path): - output_dir = tmp_path / "worker-out" - output_dir.mkdir() - - try: - worker_script._resolve_output_path( - output_dir=output_dir, - output_relative_path="../escaped.h5", - ) - except ValueError as exc: - assert "must stay within the worker output_dir" in str(exc) - else: - raise AssertionError("Expected _resolve_output_path to reject traversal") + assert len(requests) == 1 + assert requests[0].area_type == "national" + assert requests[0].area_id == "US" + assert len(issues) == 1 + assert issues[0].item == "district:bad" + assert issues[0].phase == "request" + assert issues[0].message == "bad work item" + assert "ValueError: bad work item" in issues[0].traceback def test_log_worker_session_ready_includes_bootstrap_fallback_reason(capsys): From c0ebe4e1514220686f487a5125a86c3ad8522dd0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 18 May 2026 18:44:41 +0200 Subject: [PATCH 2/3] Fix worker source isolation --- docs/engineering/stages/build_outputs.md | 13 +++++- .../build_outputs/worker_service.py | 28 +++++++------ .../build_outputs/worker_session.py | 42 ++++++++++++++----- .../test_worker_script_tiny_fixture.py | 9 ++++ .../unit/build_outputs/test_worker_service.py | 42 ++++++++++++++++++- 5 files changed, 108 insertions(+), 26 deletions(-) diff --git a/docs/engineering/stages/build_outputs.md b/docs/engineering/stages/build_outputs.md index 9be5dc2e3..12e132bec 100644 --- a/docs/engineering/stages/build_outputs.md +++ b/docs/engineering/stages/build_outputs.md @@ -39,6 +39,13 @@ prepared local-H5 worker chunk. It consumes a `WorkerSession`, typed `AreaBuildRequest` objects, and a `WorkerExecutionConfig`, then returns a structured `WorkerResult`. +`WorkerSession` owns worker-scoped setup that is safe to reuse across the +queued requests, such as weights, geography, validation context, and bootstrap +metadata. Source dataset snapshots are loaded per request through +`WorkerSession.load_source()` because the PolicyEngine microsimulation behind a +snapshot is mutable; reusing it across multiple H5 outputs can leak calculated +holder state into later outputs. + `modal_app.worker_script` should remain a thin CLI/JSON adapter around this service. It may parse legacy `--work-items` and typed `--requests-json`, prepare the worker session, and print the legacy coordinator JSON shape, but it should @@ -47,8 +54,10 @@ not regain build-loop, write-loop, or validation-loop logic. For now, `WorkerResult.to_legacy_dict()` preserves the existing coordinator contract with `completed`, `failed`, `errors`, `validation_rows`, and `validation_summary`. New code should prefer the structured `results` and -`issues` fields. Removing the legacy shape and moving the coordinator off worker -subprocess JSON is a later migration step. +`issues` fields. Validation exceptions remain visible in legacy `errors` so the +current coordinator does not drop them before it migrates to structured results. +Removing the legacy shape and moving the coordinator off worker subprocess JSON +is a later migration step. ## Payload Postprocessors diff --git a/policyengine_us_data/build_outputs/worker_service.py b/policyengine_us_data/build_outputs/worker_service.py index d9c17e010..04112acf0 100644 --- a/policyengine_us_data/build_outputs/worker_service.py +++ b/policyengine_us_data/build_outputs/worker_service.py @@ -9,9 +9,7 @@ from policyengine_us_data.pipeline_metadata import pipeline_node -from .builder import LocalAreaDatasetBuilder from .requests import AreaBuildRequest -from .us_augmentations import default_us_postprocessors from .validation import AreaValidationService from .worker_session import WorkerSession from .writer import H5Writer @@ -196,7 +194,7 @@ def to_legacy_dict(self) -> dict[str, Any]: validation_summary[result.key] = dict(result.validation_summary) issue_dicts = [issue.to_dict() for issue in result.issues] structured_issues.extend(issue_dicts) - if result.status == "failed": + if result.status == "failed" or result.validation_status == "error": legacy_errors.extend(issue_dicts) return { @@ -228,11 +226,7 @@ def to_legacy_dict(self) -> dict[str, Any]: class LocalH5WorkerService: """Execute typed local H5 requests for one prepared worker session.""" - builder: Any = field( - default_factory=lambda: LocalAreaDatasetBuilder( - postprocessors=default_us_postprocessors() - ) - ) + builder: Any = field(default_factory=lambda: _default_builder()) writer: Any = field(default_factory=H5Writer) validation_service: AreaValidationService = field( default_factory=AreaValidationService @@ -277,9 +271,10 @@ def _execute_request( try: if request.area_type == "national": _validate_national_weight_scope(session) + source = session.load_source() build_result = self.builder.build( - source=session.source, - simulation=_source_simulation(session), + source=source, + simulation=_source_simulation(source), weights=session.weights, geography=session.geography, request=request, @@ -359,6 +354,13 @@ def _request_key(request: AreaBuildRequest) -> str: return f"{request.area_type}:{request.area_id}" +def _default_builder() -> Any: + from .builder import LocalAreaDatasetBuilder + from .us_augmentations import default_us_postprocessors + + return LocalAreaDatasetBuilder(postprocessors=default_us_postprocessors()) + + def _resolve_output_path(*, output_dir: Path, output_relative_path: str) -> Path: candidate_path = (Path(output_dir) / output_relative_path).resolve(strict=False) output_dir_path = Path(output_dir).resolve(strict=False) @@ -371,11 +373,11 @@ def _resolve_output_path(*, output_dir: Path, output_relative_path: str) -> Path return candidate_path -def _source_simulation(session: WorkerSession) -> Any: - provider = getattr(session.source, "variable_provider", None) +def _source_simulation(source: Any) -> Any: + provider = getattr(source, "variable_provider", None) simulation = getattr(provider, "simulation", None) if simulation is None: - raise ValueError("Worker session source does not expose a simulation") + raise ValueError("Worker source does not expose a simulation") return simulation diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py index 855fe62e9..52bb6d9d4 100644 --- a/policyengine_us_data/build_outputs/worker_session.py +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal +from typing import Any, Callable, Literal import numpy as np @@ -64,6 +64,14 @@ class WorkerSession: bootstrap_bundle: WorkerBootstrapBundle | None = None bootstrap_status: BootstrapStatus = "unavailable" caches: dict[str, Any] = field(default_factory=dict) + source_loader: Callable[[], SourceDatasetSnapshot] | None = None + + def load_source(self) -> SourceDatasetSnapshot: + """Load a request-scoped source snapshot for one H5 build.""" + + if self.source_loader is None: + return self.source + return self.source_loader() @pipeline_node( @@ -142,7 +150,7 @@ def create( ) if bootstrap_error is not None: bundle = None - source, bootstrap_status, source_error = self._load_source( + source, bootstrap_status, source_error, source_loader = self._load_source( inputs=inputs, bundle=bundle, ) @@ -183,6 +191,7 @@ def create( bootstrap_bundle=bundle if bootstrap_status == "used" else None, bootstrap_status=bootstrap_status, caches=caches, + source_loader=source_loader, ) def _load_bootstrap( @@ -269,7 +278,15 @@ def _load_source( *, inputs: PublishingInputBundle, bundle: WorkerBootstrapBundle | None, - ) -> tuple[SourceDatasetSnapshot, BootstrapStatus, Exception | None]: + ) -> tuple[ + SourceDatasetSnapshot, + BootstrapStatus, + Exception | None, + Callable[[], SourceDatasetSnapshot], + ]: + def raw_source_loader() -> SourceDatasetSnapshot: + return self._dataset_reader.load(inputs.source_dataset_path) + if bundle is not None: try: entity_graph = load_entity_graph(bundle.entity_graph_path) @@ -277,20 +294,25 @@ def _load_source( self._dataset_reader, "load_with_entity_graph", ) - return ( - load_with_entity_graph( + + def bootstrap_source_loader() -> SourceDatasetSnapshot: + return load_with_entity_graph( inputs.source_dataset_path, entity_graph, - ), + ) + + return ( + bootstrap_source_loader(), "used", None, + bootstrap_source_loader, ) except Exception as exc: - source = self._dataset_reader.load(inputs.source_dataset_path) - return source, "fallback", exc + source = raw_source_loader() + return source, "fallback", exc, raw_source_loader - source = self._dataset_reader.load(inputs.source_dataset_path) - return source, "unavailable", None + source = raw_source_loader() + return source, "unavailable", None, raw_source_loader def _load_weights( self, diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 24abb0179..0a3c591c7 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -97,6 +97,14 @@ def _run_worker( return json.loads(result.stdout) +def _assert_outputs_reload_with_policyengine(output_dir: Path, requests) -> None: + from policyengine_us import Microsimulation + + for request in requests: + h5_path = output_dir / request.output_relative_path + Microsimulation(dataset=str(h5_path)) + + def test_tiny_fixture_source_snapshot_matches_worker_artifacts(tmp_path): artifacts = seed_local_h5_artifacts(tmp_path / "source-snapshot") @@ -218,6 +226,7 @@ def test_worker_validation_runs_for_tiny_district_state_and_national_h5s(tmp_pat ) parsed = json.loads(result.stdout) + _assert_outputs_reload_with_policyengine(output_dir, requests) assert result.stderr.count("Worker session ready:") == 1 assert parsed["failed"] == [] assert parsed["errors"] == [] diff --git a/tests/unit/build_outputs/test_worker_service.py b/tests/unit/build_outputs/test_worker_service.py index 182e25f94..1316b7bda 100644 --- a/tests/unit/build_outputs/test_worker_service.py +++ b/tests/unit/build_outputs/test_worker_service.py @@ -46,6 +46,7 @@ def _session( validation_context=None, weight_clones: int = 2, geography_clones: int = 2, + source_loader=None, ) -> WorkerSession: source = SimpleNamespace( variable_provider=SimpleNamespace(simulation=SimpleNamespace()), @@ -61,6 +62,7 @@ def _session( ), geography=SimpleNamespace(n_records=1, n_clones=geography_clones), validation_context=validation_context, + source_loader=source_loader, ) @@ -204,6 +206,43 @@ def test_worker_service_does_not_apply_takeup_filter_to_national_request(tmp_pat assert builder.calls[0]["takeup_filter"] is None +def test_worker_service_loads_fresh_source_for_each_request(tmp_path): + builder = FakeBuilder() + requests = ( + _request("district", "NC-01", "districts/NC-01.h5"), + _request("state", "NC", "states/NC.h5"), + ) + simulations = ( + SimpleNamespace(name="first-simulation"), + SimpleNamespace(name="second-simulation"), + ) + sources = [ + SimpleNamespace(variable_provider=SimpleNamespace(simulation=simulation)) + for simulation in simulations + ] + source_loads = [] + + def load_source(): + source = sources[len(source_loads)] + source_loads.append(source) + return source + + result = LocalH5WorkerService( + builder=builder, + writer=FakeWriter(), + validation_service=FakeValidationService(), + ).execute( + session=_session(source_loader=load_source), + requests=requests, + config=WorkerExecutionConfig(output_dir=tmp_path, validate=False), + ) + + assert result.to_legacy_dict()["completed"] == ["district:NC-01", "state:NC"] + assert source_loads == sources + assert [call["source"] for call in builder.calls] == sources + assert [call["simulation"] for call in builder.calls] == list(simulations) + + def test_worker_service_reports_build_failures(tmp_path): request = _request() @@ -245,7 +284,8 @@ def test_worker_service_records_validation_errors_without_failing_by_default( assert payload["completed"] == ["district:NC-01"] assert payload["failed"] == [] - assert payload["errors"] == [] + assert payload["errors"][0]["phase"] == "validation" + assert payload["errors"][0]["error"] == "validation failed" assert payload["issues"][0]["phase"] == "validation" assert payload["results"][0]["validation_status"] == "error" From 4df0f1928e0b7b7bd3caae25384e9b7a2e11cf46 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 18 May 2026 18:54:06 +0200 Subject: [PATCH 3/3] Cover worker source and validation error seams --- .../unit/build_outputs/test_worker_session.py | 44 ++++++++++ tests/unit/test_modal_local_area.py | 85 +++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py index 7ff142ddb..0e23f1829 100644 --- a/tests/unit/build_outputs/test_worker_session.py +++ b/tests/unit/build_outputs/test_worker_session.py @@ -41,6 +41,21 @@ def load_with_entity_graph(self, dataset_path: Path, entity_graph): return self.snapshot +class SequenceDatasetReader(SessionDatasetReader): + """Dataset reader fake that returns a new snapshot for each load.""" + + def __init__(self, snapshots): + super().__init__(snapshots[0]) + self.snapshots = list(snapshots) + self.load_index = 0 + + def load(self, dataset_path: Path): + self.loaded_paths.append(Path(dataset_path)) + snapshot = self.snapshots[self.load_index] + self.load_index += 1 + return snapshot + + class SessionGeographyLoader(FakeGeographyLoader): """Geography loader fake that records load calls.""" @@ -124,6 +139,35 @@ def test_worker_session_factory_uses_raw_loaders_without_bootstrap(tmp_path): assert validation_service.calls[0]["inputs"] == artifacts.inputs +def test_worker_session_factory_raw_source_loader_returns_fresh_snapshots(tmp_path): + first = make_bootstrap_test_artifacts(tmp_path / "first") + second = make_bootstrap_test_artifacts(tmp_path / "second") + third = make_bootstrap_test_artifacts(tmp_path / "third") + dataset_reader = SequenceDatasetReader( + (first.snapshot, second.snapshot, third.snapshot) + ) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(first), + validation_service=FakeValidationService(), + ).create( + inputs=first.inputs, + scope="regional", + validation_policy=ValidationPolicy(enabled=False), + period=2024, + ) + + assert session.source is first.snapshot + assert session.load_source() is second.snapshot + assert session.load_source() is third.snapshot + assert dataset_reader.loaded_paths == [ + first.inputs.source_dataset_path, + first.inputs.source_dataset_path, + first.inputs.source_dataset_path, + ] + + def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") store = WorkerBootstrapStore(tmp_path / "artifacts") diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index aab68b432..54cd2858e 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -462,3 +462,88 @@ def fake_run(cmd, **kwargs): assert "Worker session ready: scope=regional, bootstrap=used" in captured.err assert "--scope-fingerprint" in captured_cmd["cmd"] assert "regional-fingerprint" in captured_cmd["cmd"] + + +def test_run_phase_collects_worker_validation_errors(monkeypatch, tmp_path): + local_area = load_local_area_module() + run_dir = tmp_path / "run-123" + run_dir.mkdir() + captured_spawns = [] + + monkeypatch.setattr( + local_area, + "partition_work", + lambda work_items, num_workers, completed: [work_items], + ) + monkeypatch.setattr( + local_area, + "get_completed_from_volume", + lambda path: {"district:NC-01"}, + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None), + ) + + class FakeHandle: + object_id = "fc-validation-error" + + def get(self): + return { + "completed": ["district:NC-01"], + "failed": [], + "errors": [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation failed", + } + ], + "validation_rows": [], + "issues": [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation failed", + } + ], + } + + def fake_spawn(**kwargs): + captured_spawns.append(kwargs) + return FakeHandle() + + monkeypatch.setattr( + local_area, + "build_areas_worker", + SimpleNamespace(spawn=fake_spawn), + ) + + completed, phase_errors, validation_rows = local_area.run_phase( + "All areas", + work_items=[{"type": "district", "id": "NC-01", "weight": 1}], + num_workers=1, + completed=set(), + branch="main", + run_id="run-123", + calibration_inputs={ + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + run_dir=run_dir, + validate=True, + scope_fingerprint="regional-fingerprint", + ) + + assert completed == {"district:NC-01"} + assert validation_rows == [] + assert phase_errors == [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation failed", + } + ] + assert captured_spawns[0]["scope_fingerprint"] == "regional-fingerprint"