diff --git a/changelog.d/985.added b/changelog.d/985.added new file mode 100644 index 000000000..b81988763 --- /dev/null +++ b/changelog.d/985.added @@ -0,0 +1 @@ +Coordinate local H5 publishing through typed catalog requests and normalize worker responses. diff --git a/docs/engineering/stages/build_outputs.md b/docs/engineering/stages/build_outputs.md index 9be5dc2e3..f67d9895b 100644 --- a/docs/engineering/stages/build_outputs.md +++ b/docs/engineering/stages/build_outputs.md @@ -34,6 +34,18 @@ a postprocessor. ## Worker Chunk Execution +The Modal coordinator builds canonical typed area requests before spawning +workers. The regional publish path reads the target congressional district +universe from the staged target database through `TargetUniverseReader`, then +asks `USAreaCatalog` to define the regional release shape: every configured +state, every target congressional district, and the explicitly supported city +outputs such as NYC. The coordinator wraps those requests in +`WeightedAreaRequest`, partitions them with +`partition_weighted_area_requests()`, and sends workers typed +`--requests-json` payloads. Completion is measured against the explicit request +keys rather than a raw file count, so stale or unrelated H5 files cannot +satisfy a missing expected area. + `LocalH5WorkerService` is the reusable Stage 4 boundary for executing one prepared local-H5 worker chunk. It consumes a `WorkerSession`, typed `AreaBuildRequest` objects, and a `WorkerExecutionConfig`, then returns a @@ -44,6 +56,10 @@ service. It may parse legacy `--work-items` and typed `--requests-json`, prepare the worker session, and print the legacy coordinator JSON shape, but it should not regain build-loop, write-loop, or validation-loop logic. +The legacy `--work-items` input path remains compatibility-only while older +tests and explicit override callers are retired. New coordinator work should +prefer typed `AreaBuildRequest` objects and typed worker payloads. + For now, `WorkerResult.to_legacy_dict()` preserves the existing coordinator contract with `completed`, `failed`, `errors`, `validation_rows`, and `validation_summary`.
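As a sketch of that contract, the legacy payload is a plain JSON-serializable dict keyed by those five fields; the values below are invented for illustration, and the `type:id` completion-key shape matches `work_item_key` elsewhere in this diff:

```python
# Hypothetical legacy coordinator payload preserved by
# WorkerResult.to_legacy_dict(); keys come from the contract above,
# values are invented.
legacy = {
    "completed": ["state:NC", "district:NC-01"],  # completion keys written to the volume
    "failed": [],                                 # completion keys that did not build
    "errors": [],                                 # error dicts, one per failure
    "validation_rows": [],                        # per-target validation result dicts
    "validation_summary": {},                     # aggregate validation statistics
}
```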
New code should prefer the structured `results` and diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 89ce90b09..8594fbb49 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -17,9 +17,10 @@ import sys import traceback from pathlib import Path -from typing import Dict, List, Mapping +from typing import Dict, List, Mapping, Sequence import modal +import numpy as np _baked = "/root/policyengine-us-data" _local = str(Path(__file__).resolve().parent.parent) @@ -32,13 +33,25 @@ from policyengine_us_data.build_outputs.bootstrap import ( # noqa: E402 WorkerBootstrapBuilder, ) +from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog # noqa: E402 from policyengine_us_data.build_outputs.fingerprinting import ( # noqa: E402 FingerprintingService, PublishingInputBundle, ) +from policyengine_us_data.build_outputs.geography_loader import ( # noqa: E402 + CalibrationGeographyLoader, +) from policyengine_us_data.build_outputs.partitioning import ( # noqa: E402 + WeightedAreaRequest, + partition_weighted_area_requests, partition_weighted_work_items, ) +from policyengine_us_data.build_outputs.target_universe import ( # noqa: E402 + TargetUniverseReader, +) +from policyengine_us_data.build_outputs.worker_responses import ( # noqa: E402 + normalize_worker_response, +) from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402 WorkerCalibrationInputs, ) @@ -517,6 +530,125 @@ def _build_worker_calibration_inputs( ) +def _existing_path(path: Path | None) -> Path | None: + if path is None: + return None + return path if Path(path).exists() else None + + +def _infer_weight_record_count(*, weights_path: Path, n_clones: int) -> int: + """Infer source-record count from a flat weight vector without loading it.""" + + if isinstance(n_clones, bool) or not isinstance(n_clones, int | np.integer): + raise TypeError("n_clones must be an integer") + normalized_clones = int(n_clones) + if normalized_clones <= 0: + raise ValueError("n_clones must be positive") + + weights = np.load(weights_path, mmap_mode="r") + if weights.ndim != 1: + raise ValueError("Weight vector must be one-dimensional") + if weights.size == 0: + raise ValueError("Weight vector must be non-empty") + if not np.issubdtype(weights.dtype, np.number): + raise TypeError("Weight vector must have a numeric dtype") + if np.issubdtype(weights.dtype, np.complexfloating): + raise TypeError("Weight vector must have a real numeric dtype") + if weights.size % normalized_clones != 0: + raise ValueError( + f"Weight vector length {weights.size} is not divisible by " + f"n_clones={normalized_clones}" + ) + return weights.size // normalized_clones + + +def _load_area_catalog_geography( + *, + weights_path: Path, + n_clones: int, + geography_path: Path | None, + calibration_package_path: Path | None = None, + legacy_blocks_path: Path | None = None, +): + """Load geography once for coordinator-side typed request construction.""" + + n_records = _infer_weight_record_count( + weights_path=weights_path, + n_clones=n_clones, + ) + return CalibrationGeographyLoader().load( + weights_path=weights_path, + n_records=n_records, + n_clones=n_clones, + geography_path=_existing_path(geography_path), + blocks_path=_existing_path(legacy_blocks_path), + calibration_package_path=_existing_path(calibration_package_path), + ) + + +def _build_regional_weighted_requests( + *, + geography, + target_cd_geoids: Sequence[str], + catalog: USAreaCatalog | None = None, +) -> tuple[WeightedAreaRequest, ...]: + """Build canonical 
weighted regional H5 requests from release targets.""" + + catalog = catalog or USAreaCatalog.default() + requests = catalog.build_expected_regional_requests( + target_cd_geoids=target_cd_geoids, + geography=geography, + ) + + from collections import Counter + + districts_by_state = Counter( + request.area_id.split("-", 1)[0] + for request in requests + if request.area_type == "district" + ) + city_weights = {"NYC": 11} + + weighted: list[WeightedAreaRequest] = [] + for request in requests: + if request.area_type == "state": + weight = districts_by_state.get(request.area_id, 1) + elif request.area_type == "city": + weight = city_weights.get(request.area_id, 3) + else: + weight = 1 + weighted.append( + WeightedAreaRequest( + request=request, + weight=weight, + ) + ) + return tuple(weighted) + + +def _build_weighted_requests_from_work_items( + *, + work_items: Sequence[Mapping[str, object]], + geography, + catalog: USAreaCatalog | None = None, +) -> tuple[WeightedAreaRequest, ...]: + """Convert legacy override work items into canonical weighted requests.""" + + catalog = catalog or USAreaCatalog.default() + weighted: list[WeightedAreaRequest] = [] + for item in work_items: + request = catalog.build_request_from_work_item(item, geography=geography) + if request is None: + continue + weighted.append( + WeightedAreaRequest( + request=request, + weight=item.get("weight", 1), + ) + ) + return tuple(weighted) + + @pipeline_node( PipelineNode( id="coordinate_work_partition", @@ -568,6 +700,25 @@ def get_completed_from_volume(run_dir: Path) -> set: return completed +def _measure_expected_completion( + *, + expected_keys: set[str], + initially_completed: set[str], + completed: set[str], +) -> tuple[set[str], dict[str, int]]: + """Measure completion against the explicit expected request set.""" + + missing_keys = expected_keys - completed + reused_outputs = initially_completed & completed & expected_keys + recomputed_outputs = (completed - initially_completed) & expected_keys + return missing_keys, { + "expected_outputs": len(expected_keys), + "valid_reused_outputs": len(reused_outputs), + "recomputed_outputs": len(recomputed_outputs), + "invalid_outputs": len(missing_keys), + } + + @pipeline_node( PipelineNode( id="run_local_h5_phase", @@ -584,7 +735,7 @@ def get_completed_from_volume(run_dir: Path) -> set: ) def run_phase( phase_name: str, - work_items: List[Dict], + weighted_requests: Sequence[WeightedAreaRequest] | None, num_workers: int, completed: set, branch: str, @@ -593,6 +744,7 @@ def run_phase( run_dir: Path, validate: bool = True, scope_fingerprint: str | None = None, + work_items: List[Dict] | None = None, ) -> tuple: """Run a single build phase, spawning workers and collecting results. @@ -602,7 +754,14 @@ def run_phase( and crashes, and validation_rows is a list of per-target validation result dicts. 
""" - work_chunks = partition_work(work_items, num_workers, completed) + if weighted_requests is not None: + work_chunks = partition_weighted_area_requests( + weighted_requests, + num_workers, + completed, + ) + else: + work_chunks = partition_work(work_items or [], num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) worker_input_payload = WorkerCalibrationInputs.from_wire_dict( calibration_inputs @@ -617,13 +776,21 @@ def run_phase( handles = [] for i, chunk in enumerate(work_chunks): - total_weight = sum(item["weight"] for item in chunk) + if weighted_requests is not None: + total_weight = sum(item.weight for item in chunk) + request_payloads = [item.to_worker_payload() for item in chunk] + legacy_work_items = None + else: + total_weight = sum(item["weight"] for item in chunk) + request_payloads = None + legacy_work_items = chunk print(f" Worker {i}: {len(chunk)} items, weight {total_weight}") handle = build_areas_worker.spawn( branch=branch, run_id=run_id, scope="regional", - work_items=chunk, + work_items=legacy_work_items, + request_payloads=request_payloads, calibration_inputs=worker_input_payload, validate=validate, scope_fingerprint=scope_fingerprint, @@ -639,30 +806,38 @@ def run_phase( for i, handle in enumerate(handles): try: result = handle.get() - if result is None: - all_errors.append({"worker": i, "error": "Worker returned None"}) - print(f" Worker {i}: returned None (no results)") - continue - all_results.append(result) + worker_result = normalize_worker_response( + worker_index=i, + result=result, + ) + all_results.append(worker_result) print( - f" Worker {i}: {len(result['completed'])} completed, " - f"{len(result['failed'])} failed" + f" Worker {i}: {len(worker_result.completed)} completed, " + f"{len(worker_result.failed)} failed" ) - if result["errors"]: - all_errors.extend(result["errors"]) - # Collect validation rows - v_rows = result.get("validation_rows", []) - if v_rows: - all_validation_rows.extend(v_rows) - print(f" Worker {i}: {len(v_rows)} validation rows") + if worker_result.fatal_errors: + all_errors.extend(worker_result.fatal_errors) + if worker_result.issues: + all_errors.extend(worker_result.issues) + if worker_result.validation_rows: + all_validation_rows.extend(worker_result.validation_rows) + print( + f" Worker {i}: {len(worker_result.validation_rows)} validation rows" + ) except Exception as e: all_errors.append( - {"worker": i, "error": str(e), "traceback": traceback.format_exc()} + { + "worker": i, + "phase": "transport", + "severity": "transport", + "error": str(e), + "traceback": traceback.format_exc(), + } ) print(f" Worker {i}: CRASHED - {e}") - total_completed = sum(len(r["completed"]) for r in all_results) - total_failed = sum(len(r["failed"]) for r in all_results) + total_completed = sum(len(result.completed) for result in all_results) + total_failed = sum(len(result.failed) for result in all_results) staging_volume.reload() volume_completed = get_completed_from_volume(run_dir) @@ -720,10 +895,11 @@ def build_areas_worker( branch: str, run_id: str, scope: str, - work_items: List[Dict], - calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], + work_items: List[Dict] | None = None, + calibration_inputs: WorkerCalibrationInputs | Mapping[str, object] | None = None, validate: bool = True, scope_fingerprint: str | None = None, + request_payloads: List[Dict] | None = None, ) -> Dict: """ Worker function that builds a subset of H5 files. 
@@ -737,13 +913,25 @@ def build_areas_worker( output_dir = Path(VOLUME_MOUNT) / run_id output_dir.mkdir(parents=True, exist_ok=True) - work_items_json = json.dumps(work_items) + if calibration_inputs is None: + raise ValueError("calibration_inputs must be provided") worker_inputs = WorkerCalibrationInputs.from_wire_dict(calibration_inputs) + if request_payloads is not None: + request_args = ["--requests-json", json.dumps(request_payloads)] + failed_items = [ + f"{item.get('area_type', '')}:" + f"{item.get('area_id', '')}" + for item in request_payloads + ] + elif work_items is not None: + request_args = ["--work-items", json.dumps(work_items)] + failed_items = [f"{item['type']}:{item['id']}" for item in work_items] + else: + raise ValueError("Either request_payloads or work_items must be provided") worker_cmd = [ *_python_cmd("-m", "modal_app.worker_script"), - "--work-items", - work_items_json, + *request_args, *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), @@ -786,7 +974,7 @@ def build_areas_worker( if result.returncode != 0: return { "completed": [], - "failed": [f"{item['type']}:{item['id']}" for item in work_items], + "failed": failed_items, "errors": [{"error": (result.stderr or "No stderr")[:2000]}], } @@ -832,12 +1020,6 @@ def validate_staging(branch: str, run_id: str, version: str = "") -> Dict: if not version: version = run_id.split("_", 1)[0] - # PR 9 migration note: - # The coordinator still enumerates states, districts, and cities inline - # and emits legacy work_items. This is intentionally temporary for the - # dual-path migration. The target cleanup is to delegate regional request - # enumeration to USAreaCatalog and send typed --requests-json payloads to - # workers so area construction no longer lives in the coordinator. 
result = subprocess.run( _python_cmd( "-c", @@ -1125,6 +1307,13 @@ def coordinate_publish( calibration_package_path=calibration_package_path, ) validate_artifacts(config_json_path, artifacts) + regional_geography = _load_area_catalog_geography( + weights_path=weights_path, + n_clones=n_clones, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + legacy_blocks_path=artifacts / "stacked_blocks.npy", + ) if validate: try: @@ -1177,60 +1366,21 @@ def coordinate_publish( pipeline_volume.commit() staging_volume.commit() if work_items_override is None: - result = subprocess.run( - _python_cmd( - "-c", - ( - "import json\n" - "from policyengine_us_data.calibration.calibration_utils " - "import get_all_cds_from_database, STATE_CODES\n" - "from policyengine_us_data.calibration.publish_local_area " - "import get_district_friendly_name\n" - f'db_uri = "sqlite:///{db_path}"\n' - "cds = get_all_cds_from_database(db_uri)\n" - "states = list(STATE_CODES.values())\n" - "districts = [get_district_friendly_name(cd) for cd in cds]\n" - 'print(json.dumps({"states": states, "districts": districts, ' - '"cities": ["NYC"], "cds": cds}))\n' - ), - ), - capture_output=True, - text=True, - env=os.environ.copy(), + target_universe = TargetUniverseReader.from_sqlite(db_path).regional() + weighted_requests = _build_regional_weighted_requests( + geography=regional_geography, + target_cd_geoids=target_universe.cd_geoids, ) - - if result.returncode != 0: - raise RuntimeError(f"Failed to get work items: {result.stderr}") - - work_info = json.loads(result.stdout) - states = work_info["states"] - districts = work_info["districts"] - cities = work_info["cities"] - - from collections import Counter - - cds_per_state = Counter(d.split("-")[0] for d in districts) - - CITY_WEIGHTS = {"NYC": 11} - - work_items = [] - for s in states: - work_items.append( - {"type": "state", "id": s, "weight": cds_per_state.get(s, 1)} - ) - for d in districts: - work_items.append({"type": "district", "id": d, "weight": 1}) - for c in cities: - work_items.append( - {"type": "city", "id": c, "weight": CITY_WEIGHTS.get(c, 3)} - ) else: - work_items = work_items_override - states = [item["id"] for item in work_items if item.get("type") == "state"] - districts = [ - item["id"] for item in work_items if item.get("type") == "district" - ] - cities = [item["id"] for item in work_items if item.get("type") == "city"] + weighted_requests = _build_weighted_requests_from_work_items( + work_items=work_items_override, + geography=regional_geography, + ) + if not weighted_requests: + raise RuntimeError("No regional H5 requests found for coordinator geography") + + expected_total = len(weighted_requests) + expected_keys = {item.key for item in weighted_requests} staging_volume.reload() completed = get_completed_from_volume(run_dir) @@ -1252,50 +1402,59 @@ def coordinate_publish( completed, phase_errors, v_rows = run_phase( "All areas", - work_items=work_items, + weighted_requests=weighted_requests, completed=completed, **phase_args, ) accumulated_errors.extend(phase_errors) accumulated_validation_rows.extend(v_rows) - expected_total = len(states) + len(districts) + len(cities) - # If workers crashed but all files landed on the volume, # treat as transient infrastructure errors (e.g. gRPC stream resets). 
+ missing_keys, reuse_measurement = _measure_expected_completion( + expected_keys=expected_keys, + initially_completed=initially_completed, + completed=completed, + ) if accumulated_errors: - crash_errors = [e for e in accumulated_errors if "worker" in e] - if crash_errors and len(completed) >= expected_total: + fatal_worker_errors = [ + error + for error in accumulated_errors + if error.get("severity") in {"protocol", "worker_failure"} + ] + transport_errors = [ + error + for error in accumulated_errors + if error.get("severity") == "transport" + ] + if fatal_worker_errors: + raise RuntimeError( + f"Build failed: {len(fatal_worker_errors)} fatal worker " + f"error(s) detected. Errors: {fatal_worker_errors[:3]}" + ) + if transport_errors and not missing_keys: print( - f"WARNING: {len(crash_errors)} worker error(s) occurred " + f"WARNING: {len(transport_errors)} worker transport error(s) occurred " f"but all {expected_total} files present on volume. " - f"Treating as transient. Errors: {crash_errors[:3]}" + f"Treating as transient. Errors: {transport_errors[:3]}" ) - elif crash_errors: + elif transport_errors: raise RuntimeError( - f"Build failed: {len(crash_errors)} worker " - f"crash(es) detected and only " - f"{len(completed)}/{expected_total} files on volume. " - f"Errors: {crash_errors[:3]}" + f"Build failed: {len(transport_errors)} worker " + f"transport error(s) detected and only " + f"{expected_total - len(missing_keys)}/{expected_total} " + f"expected files on volume. " + f"Errors: {transport_errors[:3]}" ) - if len(completed) < expected_total: - missing = expected_total - len(completed) + if missing_keys: raise RuntimeError( - f"Build incomplete: {missing} files missing from " - f"volume ({len(completed)}/{expected_total}). " + f"Build incomplete: {len(missing_keys)} expected files missing from " + f"volume ({expected_total - len(missing_keys)}/{expected_total}). " + f"Missing: {sorted(missing_keys)[:5]}. " f"Volume preserved for retry." 
) - reused_outputs = initially_completed & completed - recomputed_outputs = completed - initially_completed - reuse_measurement = { - "expected_outputs": expected_total, - "valid_reused_outputs": len(reused_outputs), - "recomputed_outputs": len(recomputed_outputs), - "invalid_outputs": max(expected_total - len(completed), 0), - } - if skip_upload: print("\nSkipping upload (--skip-upload flag set)") return { @@ -1308,7 +1467,6 @@ def coordinate_publish( print("\nValidating staging...") manifest = validate_staging.remote(branch=branch, run_id=run_id, version=version) - expected_total = len(states) + len(districts) + len(cities) actual_total = ( manifest["totals"]["states"] + manifest["totals"]["districts"] @@ -1470,26 +1628,27 @@ def coordinate_national_publish( pipeline_volume.commit() national_h5 = run_dir / "national" / "US.h5" - work_items = [{"type": "national", "id": "US"}] + national_request = USAreaCatalog.default().build_national_request() print("Spawning worker for national H5 build...") - worker_result = build_areas_worker.remote( + raw_worker_result = build_areas_worker.remote( branch=branch, run_id=run_id, scope="national", - work_items=work_items, + request_payloads=[national_request.to_dict()], calibration_inputs=calibration_inputs.to_wire_dict(), validate=validate, scope_fingerprint=fingerprint, ) + worker_result = normalize_worker_response(worker_index=0, result=raw_worker_result) print( f"Worker result: " - f"{len(worker_result['completed'])} completed, " - f"{len(worker_result['failed'])} failed" + f"{len(worker_result.completed)} completed, " + f"{len(worker_result.failed)} failed" ) - if worker_result["failed"]: - raise RuntimeError(f"National build failed: {worker_result['errors']}") + if worker_result.fatal_errors: + raise RuntimeError(f"National build failed: {worker_result.fatal_errors}") staging_volume.reload() national_h5 = run_dir / "national" / "US.h5" diff --git a/policyengine_us_data/build_outputs/__init__.py b/policyengine_us_data/build_outputs/__init__.py index 9f4e1eeb5..838534378 100644 --- a/policyengine_us_data/build_outputs/__init__.py +++ b/policyengine_us_data/build_outputs/__init__.py @@ -5,8 +5,10 @@ H5 output request construction, exact calibration geography loading, fingerprinting, clone-weight shape contracts, worker partitioning, source dataset snapshot contracts, worker input normalization, worker-bootstrap -artifacts, worker-scoped session and validation context setup, microsimulation -access helpers, clone selection, entity reindexing, source-variable cloning, -validated H5 payload contracts, ordered output postprocessing, one-area payload -building, H5 writing, and worker chunk execution. +artifacts, target-universe reading, worker-scoped session and validation context +setup, microsimulation access helpers, clone selection, entity reindexing, +source-variable cloning, validated H5 payload contracts, ordered output +postprocessing, one-area payload building, H5 writing, worker chunk execution, +and coordinator-side typed request partitioning and worker-response +normalization. 
""" diff --git a/policyengine_us_data/build_outputs/area_catalog.py b/policyengine_us_data/build_outputs/area_catalog.py index 4744c5b35..7f824ab99 100644 --- a/policyengine_us_data/build_outputs/area_catalog.py +++ b/policyengine_us_data/build_outputs/area_catalog.py @@ -153,6 +153,51 @@ def build_city_requests(self, geography: Any) -> AreaRequests: return () return (request,) + def build_expected_regional_requests( + self, + *, + target_cd_geoids: Sequence[Any], + geography: Any | None = None, + include_cities: Sequence[str] = ("NYC",), + ) -> AreaRequests: + """Enumerate the canonical regional H5 release shape. + + This method owns the release-output universe for regional local H5s: + every configured state, every target congressional district supplied by + the target adapter, and explicitly supported city outputs. Callers + should query targets elsewhere and pass the resulting CD GEOIDs in; the + catalog stays responsible for turning that target universe into output + requests. + + Args: + target_cd_geoids: Congressional district GEOIDs present in the + calibration target universe. + geography: Optional clone geography used only to derive validation + IDs for city outputs such as NYC. + include_cities: Supported city IDs to include in the release shape. + + Returns: + State, district, and city requests in production output order. + """ + + cd_geoids = self._unique_cd_geoids(target_cd_geoids) + requests: list[AreaBuildRequest] = [] + requests.extend( + self._build_release_state_request( + state_code=state_code, + state_fips=state_fips, + ) + for state_fips, state_code in sorted(self._state_codes.items()) + ) + requests.extend( + self._build_district_request(cd_geoid) for cd_geoid in cd_geoids + ) + requests.extend( + self._build_release_city_request(city_id, geography=geography) + for city_id in include_cities + ) + return tuple(requests) + def build_city_request( self, city_id: str, @@ -325,6 +370,28 @@ def _build_state_request( validation_geographic_ids=(str(state_fips),), ) + def _build_release_state_request( + self, + *, + state_code: str, + state_fips: int, + ) -> AreaBuildRequest: + return AreaBuildRequest( + area_type="state", + area_id=state_code, + display_name=state_code, + output_relative_path=f"states/{state_code}.h5", + filters=( + AreaFilter( + geography_field="state_fips", + op="eq", + value=state_fips, + ), + ), + validation_geo_level="state", + validation_geographic_ids=(str(state_fips),), + ) + def _build_district_request(self, cd_geoid: str) -> AreaBuildRequest: friendly_name = self.get_district_friendly_name(cd_geoid) return AreaBuildRequest( @@ -343,6 +410,33 @@ def _build_district_request(self, cd_geoid: str) -> AreaBuildRequest: validation_geographic_ids=(str(cd_geoid),), ) + def _build_release_city_request( + self, + city_id: str, + *, + geography: Any | None, + ) -> AreaBuildRequest: + if city_id != "NYC": + raise ValueError(f"Unknown city: {city_id}") + + return AreaBuildRequest( + area_type="city", + area_id="NYC", + display_name="NYC", + output_relative_path="cities/NYC.h5", + filters=( + AreaFilter( + geography_field="county_fips", + op="in", + value=self._nyc_county_fips, + ), + ), + validation_geo_level="district", + validation_geographic_ids=( + self._nyc_cd_geoids(geography) if geography is not None else () + ), + ) + def get_district_friendly_name(self, cd_geoid: str) -> str: """Convert a congressional district GEOID into its output name. 
diff --git a/policyengine_us_data/build_outputs/partitioning.py b/policyengine_us_data/build_outputs/partitioning.py index 9f1815a2d..275d60903 100644 --- a/policyengine_us_data/build_outputs/partitioning.py +++ b/policyengine_us_data/build_outputs/partitioning.py @@ -1,8 +1,9 @@ -"""Pure helpers for assigning weighted work items to worker chunks.""" +"""Pure helpers for assigning weighted local H5 requests to worker chunks.""" from __future__ import annotations import heapq +from dataclasses import dataclass from collections.abc import Mapping, Sequence from typing import Any @@ -11,15 +12,55 @@ WorkItem = Mapping[str, Any] WorkItems = Sequence[WorkItem] WorkChunks = list[list[WorkItem]] +WeightedAreaRequestChunks = list[list["WeightedAreaRequest"]] __all__ = [ + "WeightedAreaRequest", + "WeightedAreaRequestChunks", "WorkChunks", "WorkItem", "WorkItems", + "partition_weighted_area_requests", "partition_weighted_work_items", ] +@pipeline_node( + id="local_h5_weighted_area_request", + label="WeightedAreaRequest", + node_type="library", + description="Typed local H5 request with coordinator scheduling weight.", + source_file="policyengine_us_data/build_outputs/partitioning.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_partitioning.py"], +) +@dataclass(frozen=True) +class WeightedAreaRequest: + """Area build request plus scheduling weight for coordinator partitioning.""" + + request: Any + weight: int | float = 1 + + def __post_init__(self) -> None: + if isinstance(self.weight, bool) or not isinstance(self.weight, int | float): + raise TypeError("weight must be numeric") + if self.weight <= 0: + raise ValueError("weight must be positive") + + @property + def key(self) -> str: + """Return the stable completion key for this request.""" + + return f"{self.request.area_type}:{self.request.area_id}" + + def to_worker_payload(self) -> dict[str, Any]: + """Serialize the request for `modal_app.worker_script --requests-json`.""" + + return self.request.to_dict() + + def work_item_key(item: WorkItem) -> str: """Return the stable completion key used by the current H5 workers.""" @@ -37,6 +78,33 @@ def work_item_key(item: WorkItem) -> str: pathways=["local_h5"], validation_commands=["uv run pytest tests/unit/build_outputs/test_partitioning.py"], ) +def partition_weighted_area_requests( + requests: Sequence[WeightedAreaRequest], + num_workers: int, + completed: set[str] | None = None, +) -> WeightedAreaRequestChunks: + """Partition remaining typed H5 requests across worker chunks.""" + + return _partition_weighted_items( + items=tuple(requests), + num_workers=num_workers, + completed=completed, + key=lambda item: item.key, + weight=lambda item: item.weight, + ) + + +@pipeline_node( + id="local_h5_legacy_work_item_partition", + label="Partition Legacy Local H5 Work Items", + node_type="library", + description="Compatibility wrapper for assigning legacy local H5 work items to workers.", + source_file="policyengine_us_data/build_outputs/partitioning.py", + status="legacy", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_partitioning.py"], +) def partition_weighted_work_items( work_items: WorkItems, num_workers: int, @@ -61,23 +129,40 @@ def partition_weighted_work_items( every item is already completed. 
""" + return _partition_weighted_items( + items=tuple(work_items), + num_workers=num_workers, + completed=completed, + key=work_item_key, + weight=lambda item: item["weight"], + ) + + +def _partition_weighted_items( + *, + items: tuple[Any, ...], + num_workers: int, + completed: set[str] | None, + key, + weight, +): if num_workers <= 0: return [] completed = completed or set() - remaining = [item for item in work_items if work_item_key(item) not in completed] - remaining.sort(key=lambda item: -item["weight"]) + remaining = [item for item in items if key(item) not in completed] + remaining.sort(key=lambda item: -weight(item)) n_workers = min(num_workers, len(remaining)) if n_workers == 0: return [] heap: list[tuple[int | float, int]] = [(0, idx) for idx in range(n_workers)] - chunks: WorkChunks = [[] for _ in range(n_workers)] + chunks = [[] for _ in range(n_workers)] for item in remaining: load, idx = heapq.heappop(heap) chunks[idx].append(item) - heapq.heappush(heap, (load + item["weight"], idx)) + heapq.heappush(heap, (load + weight(item), idx)) return [chunk for chunk in chunks if chunk] diff --git a/policyengine_us_data/build_outputs/target_universe.py b/policyengine_us_data/build_outputs/target_universe.py new file mode 100644 index 000000000..90e5c0bcd --- /dev/null +++ b/policyengine_us_data/build_outputs/target_universe.py @@ -0,0 +1,80 @@ +"""Target-universe contracts for local H5 publication.""" + +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass +from pathlib import Path + +from policyengine_us_data.pipeline_metadata import pipeline_node + +__all__ = [ + "RegionalTargetUniverse", + "TargetUniverseReader", +] + + +@pipeline_node( + id="local_h5_regional_target_universe", + label="RegionalTargetUniverse", + node_type="library", + description="Target congressional district universe used to enumerate regional local H5 outputs.", + source_file="policyengine_us_data/build_outputs/target_universe.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_target_universe.py" + ], +) +@dataclass(frozen=True) +class RegionalTargetUniverse: + """Congressional district target universe for regional H5 outputs.""" + + cd_geoids: tuple[str, ...] 
+ + def __post_init__(self) -> None: + cd_geoids = tuple(str(item) for item in self.cd_geoids) + if not cd_geoids: + raise ValueError("Regional target universe must contain CD GEOIDs") + object.__setattr__(self, "cd_geoids", cd_geoids) + + +@pipeline_node( + id="local_h5_target_universe_reader", + label="TargetUniverseReader", + node_type="library", + description="Read local H5 target-universe contracts from the staged target database.", + source_file="policyengine_us_data/build_outputs/target_universe.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_target_universe.py" + ], +) +@dataclass(frozen=True) +class TargetUniverseReader: + """Adapter from the Stage 1 target database artifact to H5 target contracts.""" + + db_path: Path + + @classmethod + def from_sqlite(cls, db_path: Path | str) -> "TargetUniverseReader": + """Create a reader for a SQLite `policy_data.db` artifact.""" + + return cls(db_path=Path(db_path)) + + def regional(self) -> RegionalTargetUniverse: + """Read the regional congressional district target universe.""" + + with sqlite3.connect(self.db_path) as conn: + rows = conn.execute( + """ + SELECT DISTINCT value AS cd_geoid + FROM stratum_constraints + WHERE constraint_variable = 'congressional_district_geoid' + ORDER BY value + """ + ).fetchall() + return RegionalTargetUniverse(cd_geoids=tuple(str(row[0]) for row in rows)) diff --git a/policyengine_us_data/build_outputs/worker_responses.py b/policyengine_us_data/build_outputs/worker_responses.py new file mode 100644 index 000000000..8dbb1d734 --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_responses.py @@ -0,0 +1,234 @@ +"""Coordinator-side normalization for local H5 worker JSON responses.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping + +from policyengine_us_data.pipeline_metadata import pipeline_node + +__all__ = [ + "CoordinatorWorkerResult", + "normalize_worker_response", +] + + +@pipeline_node( + id="local_h5_coordinator_worker_result", + label="CoordinatorWorkerResult", + node_type="library", + description="Coordinator-normalized view of one local H5 worker response.", + source_file="policyengine_us_data/build_outputs/worker_responses.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_responses.py" + ], +) +@dataclass(frozen=True) +class CoordinatorWorkerResult: + """Normalized worker response with explicit fatal and nonfatal issue classes.""" + + completed: tuple[str, ...] = () + failed: tuple[str, ...] = () + fatal_errors: tuple[dict[str, Any], ...] = () + issues: tuple[dict[str, Any], ...] = () + validation_rows: tuple[dict[str, Any], ...] 
= () + + +def _coordinator_error( + error: Mapping[str, Any], + *, + worker_index: int, + severity: str, +) -> dict[str, Any]: + payload = dict(error) + payload.setdefault("worker", worker_index) + payload["severity"] = severity + return payload + + +def _string_tuple_field( + result: Mapping[str, Any], + *, + worker_index: int, + field_name: str, +) -> tuple[tuple[str, ...], tuple[dict[str, Any], ...]]: + value = result.get(field_name) + if not isinstance(value, list | tuple): + return (), ( + _coordinator_error( + { + "phase": "protocol", + "error": f"Worker result field {field_name!r} must be a list", + }, + worker_index=worker_index, + severity="protocol", + ), + ) + return tuple(str(item) for item in value), () + + +def _dict_tuple_field( + result: Mapping[str, Any], + *, + worker_index: int, + field_name: str, +) -> tuple[tuple[dict[str, Any], ...], tuple[dict[str, Any], ...]]: + value = result.get(field_name, []) + if not isinstance(value, list | tuple): + return (), ( + _coordinator_error( + { + "phase": "protocol", + "error": f"Worker result field {field_name!r} must be a list", + }, + worker_index=worker_index, + severity="protocol", + ), + ) + + items: list[dict[str, Any]] = [] + protocol_errors: list[dict[str, Any]] = [] + for item in value: + if isinstance(item, dict): + items.append(dict(item)) + else: + protocol_errors.append( + _coordinator_error( + { + "phase": "protocol", + "error": ( + f"Worker result field {field_name!r} contained " + "a non-object item" + ), + }, + worker_index=worker_index, + severity="protocol", + ) + ) + return tuple(items), tuple(protocol_errors) + + +def _issue_identity(issue: Mapping[str, Any]) -> tuple[Any, Any, Any]: + return issue.get("item"), issue.get("phase"), issue.get("error") + + +@pipeline_node( + id="normalize_local_h5_worker_response", + label="Normalize Local H5 Worker Response", + node_type="library", + description="Normalize legacy worker JSON into explicit coordinator severity classes.", + source_file="policyengine_us_data/build_outputs/worker_responses.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_responses.py" + ], +) +def normalize_worker_response( + *, + worker_index: int, + result: object, +) -> CoordinatorWorkerResult: + """Normalize worker JSON into explicit fatal and nonfatal coordinator issues.""" + + if result is None: + return CoordinatorWorkerResult( + fatal_errors=( + _coordinator_error( + {"phase": "protocol", "error": "Worker returned None"}, + worker_index=worker_index, + severity="protocol", + ), + ) + ) + if not isinstance(result, dict): + return CoordinatorWorkerResult( + fatal_errors=( + _coordinator_error( + { + "phase": "protocol", + "error": f"Worker returned non-object result: {type(result)!r}", + }, + worker_index=worker_index, + severity="protocol", + ), + ) + ) + + completed, completed_errors = _string_tuple_field( + result, + worker_index=worker_index, + field_name="completed", + ) + failed, failed_errors = _string_tuple_field( + result, + worker_index=worker_index, + field_name="failed", + ) + worker_errors, worker_error_protocol_errors = _dict_tuple_field( + result, + worker_index=worker_index, + field_name="errors", + ) + worker_issues, worker_issue_protocol_errors = _dict_tuple_field( + result, + worker_index=worker_index, + field_name="issues", + ) + validation_rows, validation_row_protocol_errors = _dict_tuple_field( + result, + worker_index=worker_index, + field_name="validation_rows", + ) + + 
fatal_errors = [ + *completed_errors, + *failed_errors, + *worker_error_protocol_errors, + *worker_issue_protocol_errors, + *validation_row_protocol_errors, + ] + fatal_errors.extend( + _coordinator_error( + error, + worker_index=worker_index, + severity="worker_failure", + ) + for error in worker_errors + ) + + fatal_issue_keys = {_issue_identity(error) for error in worker_errors} + nonfatal_issues = tuple( + issue + for issue in worker_issues + if _issue_identity(issue) not in fatal_issue_keys + ) + + error_items = { + str(error.get("item")) for error in worker_errors if error.get("item") + } + fatal_errors.extend( + _coordinator_error( + { + "item": item, + "phase": "worker", + "error": "Worker reported failed item without a matching error", + }, + worker_index=worker_index, + severity="worker_failure", + ) + for item in failed + if item not in error_items + ) + + return CoordinatorWorkerResult( + completed=completed, + failed=failed, + fatal_errors=tuple(fatal_errors), + issues=nonfatal_issues, + validation_rows=validation_rows, + ) diff --git a/tests/integration/test_tiny_h5_pipeline.py b/tests/integration/test_tiny_h5_pipeline.py index 372d65553..63b19cf0c 100644 --- a/tests/integration/test_tiny_h5_pipeline.py +++ b/tests/integration/test_tiny_h5_pipeline.py @@ -129,6 +129,44 @@ def test_saved_geography_h5_pipeline_builds_regional_and_national_outputs(): cleanup.remote(run_id) +def test_deployed_regional_coordinator_builds_from_seeded_artifacts(): + _require_modal_tokens() + + run_id = _run_id("h5-coordinator") + seed = _function(HARNESS_APP_NAME, "seed_h5_case") + inspect = _function(HARNESS_APP_NAME, "inspect_h5_outputs") + cleanup = _function(HARNESS_APP_NAME, "cleanup_h5_case") + coordinate = _function(LOCAL_AREA_APP_NAME, "coordinate_publish") + + try: + seeded = seed.remote(run_id, "saved_geography_success") + + result = coordinate.remote( + branch="main", + num_workers=1, + skip_upload=True, + n_clones=seeded["n_clones"], + validate=False, + run_id=run_id, + work_items_override=_work_items("district", "state"), + ) + + assert result["message"].endswith("Upload skipped.") + assert result["reuse_measurement"]["expected_outputs"] == 2 + assert result["reuse_measurement"]["invalid_outputs"] == 0 + + inspection = inspect.remote( + run_id, + ["districts/NC-01.h5", "states/NC.h5"], + ) + _assert_output_contract( + inspection, + ("districts/NC-01.h5", "states/NC.h5"), + ) + finally: + cleanup.remote(run_id) + + def test_package_fallback_h5_pipeline_builds_district_output(): _require_modal_tokens() diff --git a/tests/support/build_outputs/partitioning.py b/tests/support/build_outputs/partitioning.py index 9250f367e..7c0da7e4b 100644 --- a/tests/support/build_outputs/partitioning.py +++ b/tests/support/build_outputs/partitioning.py @@ -41,6 +41,8 @@ def load_partitioning_exports(): return { "module": module, "flatten_chunks": flatten_chunks, + "partition_weighted_area_requests": module.partition_weighted_area_requests, "partition_weighted_work_items": module.partition_weighted_work_items, + "WeightedAreaRequest": module.WeightedAreaRequest, "work_item_key": module.work_item_key, } diff --git a/tests/support/modal_local_area.py b/tests/support/modal_local_area.py index 0e5f2a426..728146c63 100644 --- a/tests/support/modal_local_area.py +++ b/tests/support/modal_local_area.py @@ -76,6 +76,12 @@ def decorator(func): fake_pipeline_schema = ModuleType("policyengine_us_data.pipeline_schema") fake_utils = ModuleType("policyengine_us_data.utils") fake_run_context = 
ModuleType("policyengine_us_data.utils.run_context") + fake_area_catalog = ModuleType( + "policyengine_us_data.build_outputs.area_catalog" + ) + fake_geography_loader = ModuleType( + "policyengine_us_data.build_outputs.geography_loader" + ) fake_partitioning = ModuleType( "policyengine_us_data.build_outputs.partitioning" ) @@ -86,6 +92,12 @@ def decorator(func): fake_worker_inputs = ModuleType( "policyengine_us_data.build_outputs.worker_inputs" ) + fake_worker_responses = ModuleType( + "policyengine_us_data.build_outputs.worker_responses" + ) + fake_target_universe = ModuleType( + "policyengine_us_data.build_outputs.target_universe" + ) fake_policyengine.__path__ = [] fake_calibration.__path__ = [] fake_build_outputs.__path__ = [] @@ -104,8 +116,118 @@ def decorator(func): fake_pipeline_metadata.pipeline_node = _fake_pipeline_node fake_pipeline_schema.PipelineNode = _FakePipelineNode fake_run_context.resolve_run_id = lambda explicit="", **kwargs: explicit + + class _FakeAreaRequest: + def __init__(self, *, area_type, area_id): + self.area_type = area_type + self.area_id = area_id + + def to_dict(self): + return { + "area_type": self.area_type, + "area_id": self.area_id, + } + + class _FakeUSAreaCatalog: + @classmethod + def default(cls): + return cls() + + def build_state_requests(self, geography): + return () + + def build_district_requests(self, geography): + return () + + def build_city_requests(self, geography): + return () + + def build_expected_regional_requests(self, *, target_cd_geoids, **kwargs): + return tuple( + _FakeAreaRequest(area_type="district", area_id=str(cd_geoid)) + for cd_geoid in target_cd_geoids + ) + + def build_national_request(self): + return _FakeAreaRequest(area_type="national", area_id="US") + + def build_request_from_work_item(self, item, *, geography): + return _FakeAreaRequest(area_type=item["type"], area_id=item["id"]) + + class _FakeCalibrationGeographyLoader: + def load(self, **kwargs): + return SimpleNamespace() + + class _FakeWeightedAreaRequest: + def __init__(self, request, weight=1): + self.request = request + self.weight = weight + + @property + def key(self): + return f"{self.request.area_type}:{self.request.area_id}" + + def to_worker_payload(self): + return self.request.to_dict() + + def _fake_partition_typed(requests, num_workers, completed=None): + completed = completed or set() + remaining = [item for item in requests if item.key not in completed] + return [remaining] if remaining else [] + + fake_area_catalog.USAreaCatalog = _FakeUSAreaCatalog + fake_geography_loader.CalibrationGeographyLoader = ( + _FakeCalibrationGeographyLoader + ) + fake_partitioning.WeightedAreaRequest = _FakeWeightedAreaRequest + fake_partitioning.partition_weighted_area_requests = _fake_partition_typed fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] + def _fake_normalize_worker_response(*, worker_index, result): + if result is None: + return SimpleNamespace( + completed=(), + failed=(), + fatal_errors=( + { + "worker": worker_index, + "severity": "protocol", + "error": "Worker returned None", + }, + ), + issues=(), + validation_rows=(), + ) + errors = tuple( + { + **error, + "worker": worker_index, + "severity": "worker_failure", + } + for error in result.get("errors", ()) + ) + return SimpleNamespace( + completed=tuple(result.get("completed", ())), + failed=tuple(result.get("failed", ())), + fatal_errors=errors, + issues=tuple(result.get("issues", ())), + validation_rows=tuple(result.get("validation_rows", ())), + ) + + 
fake_worker_responses.normalize_worker_response = ( + _fake_normalize_worker_response + ) + + class _FakeTargetUniverseReader: + @classmethod + def from_sqlite(cls, db_path): + return cls() + + def regional(self): + return SimpleNamespace(cd_geoids=("3701",)) + + fake_target_universe.TargetUniverseReader = _FakeTargetUniverseReader + class _FakeWorkerBootstrapBuilder: def build(self, *args, **kwargs): return SimpleNamespace( @@ -227,14 +349,24 @@ def compute_scope_fingerprint(self, *args, **kwargs): "policyengine_us_data.utils": fake_utils, "policyengine_us_data.utils.run_context": fake_run_context, "policyengine_us_data.build_outputs": fake_build_outputs, + "policyengine_us_data.build_outputs.area_catalog": fake_area_catalog, "policyengine_us_data.build_outputs.bootstrap": fake_bootstrap, "policyengine_us_data.build_outputs.fingerprinting": ( fake_fingerprinting ), + "policyengine_us_data.build_outputs.geography_loader": ( + fake_geography_loader + ), "policyengine_us_data.build_outputs.partitioning": (fake_partitioning), "policyengine_us_data.build_outputs.worker_inputs": ( fake_worker_inputs ), + "policyengine_us_data.build_outputs.worker_responses": ( + fake_worker_responses + ), + "policyengine_us_data.build_outputs.target_universe": ( + fake_target_universe + ), } ) diff --git a/tests/unit/build_outputs/test_area_catalog.py b/tests/unit/build_outputs/test_area_catalog.py index 30523faa4..b0f714b5f 100644 --- a/tests/unit/build_outputs/test_area_catalog.py +++ b/tests/unit/build_outputs/test_area_catalog.py @@ -57,6 +57,61 @@ def test_build_city_requests_emits_nyc_request_with_district_validation_ids(): assert requests[0].validation_geographic_ids == ("3601", "3603") +def test_catalog_requests_all_states_districts_and_nyc_present_in_geography(): + catalog = make_catalog() + geography = make_geography( + cd_geoids=["101", "102", "298", "3601", "3603"], + county_fips=["01001", "01003", "02020", "36061", "36081"], + ) + + state_requests = catalog.build_state_requests(geography) + district_requests = catalog.build_district_requests(geography) + city_requests = catalog.build_city_requests(geography) + + assert [request.area_id for request in state_requests] == ["AL", "AK", "NY"] + assert [request.area_id for request in district_requests] == [ + "AL-01", + "AL-02", + "AK-01", + "NY-01", + "NY-03", + ] + assert [request.area_id for request in city_requests] == ["NYC"] + + +def test_build_expected_regional_requests_defines_release_shape(): + catalog = make_catalog() + geography = make_geography( + cd_geoids=["3601", "3603", "101"], + county_fips=["36061", "36081", "01001"], + ) + + requests = catalog.build_expected_regional_requests( + target_cd_geoids=["101", "102", "298", "3601", "3603"], + geography=geography, + ) + + assert [request.area_id for request in requests] == [ + "AL", + "AK", + "NY", + "AL-01", + "AL-02", + "AK-01", + "NY-01", + "NY-03", + "NYC", + ] + state_requests = requests[:3] + assert [request.filters[0].geography_field for request in state_requests] == [ + "state_fips", + "state_fips", + "state_fips", + ] + assert [request.filters[0].value for request in state_requests] == [1, 2, 36] + assert requests[-1].validation_geographic_ids == ("3601", "3603") + + def test_build_national_request_returns_canonical_us_request(): catalog = make_catalog() diff --git a/tests/unit/build_outputs/test_partitioning.py b/tests/unit/build_outputs/test_partitioning.py index 7627398a4..5a14c985e 100644 --- a/tests/unit/build_outputs/test_partitioning.py +++ 
b/tests/unit/build_outputs/test_partitioning.py @@ -5,15 +5,62 @@ partitioning = load_partitioning_exports() flatten_chunks = partitioning["flatten_chunks"] +partition_weighted_area_requests = partitioning["partition_weighted_area_requests"] partition_weighted_work_items = partitioning["partition_weighted_work_items"] +WeightedAreaRequest = partitioning["WeightedAreaRequest"] work_item_key = partitioning["work_item_key"] +class FakeAreaRequest: + def __init__(self, *, area_type: str, area_id: str): + self.area_type = area_type + self.area_id = area_id + + def to_dict(self): + return { + "area_type": self.area_type, + "area_id": self.area_id, + } + + def test_work_item_key_uses_existing_completion_shape(): item = {"type": "district", "id": "CA-12", "weight": 1} assert work_item_key(item) == "district:CA-12" +def test_weighted_area_request_uses_existing_completion_shape(): + item = WeightedAreaRequest( + request=FakeAreaRequest(area_type="district", area_id="CA-12"), + weight=1, + ) + + assert item.key == "district:CA-12" + assert item.to_worker_payload() == { + "area_type": "district", + "area_id": "CA-12", + } + + +def test_partition_typed_requests_filters_completed_items(): + requests = ( + WeightedAreaRequest(FakeAreaRequest(area_type="state", area_id="CA"), weight=3), + WeightedAreaRequest( + FakeAreaRequest(area_type="district", area_id="CA-12"), + weight=1, + ), + WeightedAreaRequest(FakeAreaRequest(area_type="city", area_id="NYC"), weight=2), + ) + + chunks = partition_weighted_area_requests( + requests, + num_workers=2, + completed={"district:CA-12"}, + ) + + flattened = flatten_chunks(chunks) + assert [item.key for item in flattened] == ["state:CA", "city:NYC"] + + def test_partition_filters_completed_items(): work_items = [ {"type": "state", "id": "CA", "weight": 3}, diff --git a/tests/unit/build_outputs/test_target_universe.py b/tests/unit/build_outputs/test_target_universe.py new file mode 100644 index 000000000..d51d9cdaf --- /dev/null +++ b/tests/unit/build_outputs/test_target_universe.py @@ -0,0 +1,38 @@ +import sqlite3 + +import pytest + +from policyengine_us_data.build_outputs.target_universe import ( + RegionalTargetUniverse, + TargetUniverseReader, +) + + +def _write_target_cd_db(db_path, cd_geoids: tuple[str, ...]) -> None: + with sqlite3.connect(db_path) as conn: + conn.execute( + "CREATE TABLE stratum_constraints " + "(constraint_variable TEXT NOT NULL, value TEXT NOT NULL)" + ) + conn.executemany( + "INSERT INTO stratum_constraints VALUES (?, ?)", + [("congressional_district_geoid", cd_geoid) for cd_geoid in cd_geoids], + ) + conn.execute( + "INSERT INTO stratum_constraints VALUES (?, ?)", + ("other_constraint", "9999"), + ) + + +def test_target_universe_reader_loads_sorted_regional_cd_geoids(tmp_path): + db_path = tmp_path / "policy_data.db" + _write_target_cd_db(db_path, ("102", "101")) + + universe = TargetUniverseReader.from_sqlite(db_path).regional() + + assert universe == RegionalTargetUniverse(cd_geoids=("101", "102")) + + +def test_regional_target_universe_rejects_empty_cd_geoids(): + with pytest.raises(ValueError, match="must contain CD GEOIDs"): + RegionalTargetUniverse(cd_geoids=()) diff --git a/tests/unit/build_outputs/test_worker_responses.py b/tests/unit/build_outputs/test_worker_responses.py new file mode 100644 index 000000000..caa7413a6 --- /dev/null +++ b/tests/unit/build_outputs/test_worker_responses.py @@ -0,0 +1,84 @@ +from policyengine_us_data.build_outputs.worker_responses import ( + normalize_worker_response, +) + + +def 
test_normalize_worker_response_marks_fatal_and_nonfatal_issues(): + result = normalize_worker_response( + worker_index=2, + result={ + "completed": ["district:NC-01"], + "failed": [], + "errors": [{"error": "Failed to parse worker output"}], + "issues": [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation warning", + } + ], + "validation_rows": [], + }, + ) + + assert result.completed == ("district:NC-01",) + assert result.failed == () + assert result.fatal_errors == ( + { + "error": "Failed to parse worker output", + "worker": 2, + "severity": "worker_failure", + }, + ) + assert result.issues == ( + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation warning", + }, + ) + + +def test_normalize_worker_response_marks_malformed_fields_as_protocol_errors(): + result = normalize_worker_response( + worker_index=1, + result={ + "completed": "district:NC-01", + "failed": [], + "errors": [], + "validation_rows": [], + }, + ) + + assert result.completed == () + assert result.fatal_errors == ( + { + "phase": "protocol", + "error": "Worker result field 'completed' must be a list", + "worker": 1, + "severity": "protocol", + }, + ) + + +def test_normalize_worker_response_marks_failed_items_without_errors(): + result = normalize_worker_response( + worker_index=0, + result={ + "completed": [], + "failed": ["district:NC-01"], + "errors": [], + "issues": [], + "validation_rows": [], + }, + ) + + assert result.fatal_errors == ( + { + "item": "district:NC-01", + "phase": "worker", + "error": "Worker reported failed item without a matching error", + "worker": 0, + "severity": "worker_failure", + }, + ) diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index da847fd24..913ab6d98 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -1,6 +1,7 @@ from pathlib import Path from types import SimpleNamespace +from tests.support.build_outputs.area_catalog import make_geography from tests.support.modal_local_area import load_local_area_module @@ -377,6 +378,378 @@ def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): assert "calibration_package" not in inputs.to_wire_dict() +def test_load_area_catalog_geography_uses_mmap_for_weight_shape( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module() + weights_path = tmp_path / "calibration_weights.npy" + geography_path = tmp_path / "geography_assignment.npz" + geography_path.write_text("exists") + captured = {} + + class FakeWeights: + ndim = 1 + size = 6 + dtype = local_area.np.dtype("float64") + + def fake_load(path, *, mmap_mode=None): + captured["load"] = { + "path": path, + "mmap_mode": mmap_mode, + } + return FakeWeights() + + class FakeLoader: + def load(self, **kwargs): + captured["loader"] = kwargs + return "loaded-geography" + + monkeypatch.setattr(local_area.np, "load", fake_load) + monkeypatch.setattr( + local_area, + "CalibrationGeographyLoader", + lambda: FakeLoader(), + ) + + result = local_area._load_area_catalog_geography( + weights_path=weights_path, + n_clones=3, + geography_path=geography_path, + ) + + assert result == "loaded-geography" + assert captured["load"] == { + "path": weights_path, + "mmap_mode": "r", + } + assert captured["loader"]["n_records"] == 2 + assert captured["loader"]["n_clones"] == 3 + assert captured["loader"]["geography_path"] == geography_path + + +def test_build_regional_weighted_requests_uses_catalog_geography(): + local_area = 
load_local_area_module(stub_policyengine=False) + catalog = local_area.USAreaCatalog( + state_codes={1: "AL", 36: "NY"}, + nyc_county_fips={"36061"}, + at_large_districts={0, 98}, + ) + geography = make_geography( + cd_geoids=["101", "102", "3601"], + county_fips=["01001", "01003", "36061"], + ) + + weighted = local_area._build_regional_weighted_requests( + geography=geography, + target_cd_geoids=("101", "102", "3601"), + catalog=catalog, + ) + + assert [item.key for item in weighted] == [ + "state:AL", + "state:NY", + "district:AL-01", + "district:AL-02", + "district:NY-01", + "city:NYC", + ] + assert [item.weight for item in weighted] == [2, 1, 1, 1, 1, 11] + + +def test_build_weighted_requests_from_work_items_keeps_override_weights(): + local_area = load_local_area_module(stub_policyengine=False) + catalog = local_area.USAreaCatalog( + state_codes={1: "AL", 36: "NY"}, + nyc_county_fips={"36061"}, + at_large_districts={0, 98}, + ) + geography = make_geography( + cd_geoids=["101", "3601"], + county_fips=["01001", "36061"], + ) + + weighted = local_area._build_weighted_requests_from_work_items( + work_items=( + {"type": "district", "id": "AL-01", "weight": 7}, + {"type": "city", "id": "NYC", "weight": 5}, + ), + geography=geography, + catalog=catalog, + ) + + assert [item.key for item in weighted] == ["district:AL-01", "city:NYC"] + assert [item.weight for item in weighted] == [7, 5] + + +def test_measure_expected_completion_ignores_unexpected_stale_files(): + local_area = load_local_area_module() + + missing, measurement = local_area._measure_expected_completion( + expected_keys={"state:AL", "district:AL-01"}, + initially_completed={"state:AL", "district:OLD"}, + completed={"state:AL", "district:OLD"}, + ) + + assert missing == {"district:AL-01"} + assert measurement == { + "expected_outputs": 2, + "valid_reused_outputs": 1, + "recomputed_outputs": 0, + "invalid_outputs": 1, + } + + +def test_coordinate_publish_happy_path_with_fake_volumes_and_artifacts( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module() + run_id = "run-123" + pipeline_root = tmp_path / "pipeline" + artifact_dir = pipeline_root / "artifacts" / run_id + artifact_dir.mkdir(parents=True) + staging_root = tmp_path / "staging" + staging_root.mkdir() + for filename in ( + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "policy_data.db", + "unified_run_config.json", + ): + (artifact_dir / filename).write_text("artifact") + + real_path = Path + + def remapped_path(value=".", *args): + text = str(value) + if text.startswith("/pipeline"): + return real_path(str(pipeline_root) + text[len("/pipeline") :]) + return real_path(value, *args) + + monkeypatch.setattr(local_area, "Path", remapped_path) + monkeypatch.setattr(local_area, "VOLUME_MOUNT", str(staging_root)) + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "0.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *args, **kwargs: None) + monkeypatch.setattr( + local_area, "_load_area_catalog_geography", lambda **kwargs: object() + ) + monkeypatch.setattr( + local_area, "_build_publishing_input_bundle", lambda **kwargs: object() + ) + monkeypatch.setattr( + local_area, "_resolve_scope_fingerprint", lambda **kwargs: "fingerprint" + ) + monkeypatch.setattr( + local_area, "reconcile_run_dir_fingerprint", lambda *args, **kwargs: "fresh" + ) + monkeypatch.setattr( + 
local_area, "_build_worker_bootstrap", lambda **kwargs: object() + ) + monkeypatch.setattr( + local_area, + "pipeline_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + requests = ( + local_area.WeightedAreaRequest( + request=SimpleNamespace( + area_type="state", + area_id="NC", + to_dict=lambda: {"area_type": "state", "area_id": "NC"}, + ), + weight=1, + ), + local_area.WeightedAreaRequest( + request=SimpleNamespace( + area_type="district", + area_id="NC-01", + to_dict=lambda: {"area_type": "district", "area_id": "NC-01"}, + ), + weight=1, + ), + ) + monkeypatch.setattr( + local_area, + "_build_regional_weighted_requests", + lambda **kwargs: requests, + ) + monkeypatch.setattr( + local_area, + "TargetUniverseReader", + SimpleNamespace( + from_sqlite=lambda db_path: SimpleNamespace( + regional=lambda: SimpleNamespace(cd_geoids=("3701",)) + ) + ), + ) + captured = {} + + def fake_run_phase(phase_name, *, weighted_requests, completed, **kwargs): + captured["phase_name"] = phase_name + captured["weighted_keys"] = [item.key for item in weighted_requests] + captured["completed_before"] = set(completed) + return {"state:NC", "district:NC-01"}, [], [{"variable": "household_count"}] + + monkeypatch.setattr(local_area, "run_phase", fake_run_phase) + + result = local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + n_clones=1, + validate=False, + run_id=run_id, + ) + + assert result["message"] == "Build complete for version 0.0.0. Upload skipped." + assert result["fingerprint"] == "fingerprint" + assert result["validation_rows"] == [{"variable": "household_count"}] + assert result["reuse_measurement"] == { + "expected_outputs": 2, + "valid_reused_outputs": 0, + "recomputed_outputs": 2, + "invalid_outputs": 0, + } + assert captured == { + "phase_name": "All areas", + "weighted_keys": ["state:NC", "district:NC-01"], + "completed_before": set(), + } + + +def test_coordinate_publish_default_path_uses_target_db_and_catalog( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module(stub_policyengine=False) + run_id = "run-123" + pipeline_root = tmp_path / "pipeline" + artifact_dir = pipeline_root / "artifacts" / run_id + artifact_dir.mkdir(parents=True) + staging_root = tmp_path / "staging" + staging_root.mkdir() + for filename in ( + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "unified_run_config.json", + ): + (artifact_dir / filename).write_text("artifact") + + (artifact_dir / "policy_data.db").write_text("artifact") + + real_path = Path + + def remapped_path(value=".", *args): + text = str(value) + if text.startswith("/pipeline"): + return real_path(str(pipeline_root) + text[len("/pipeline") :]) + return real_path(value, *args) + + tiny_catalog = local_area.USAreaCatalog( + state_codes={1: "AL", 36: "NY"}, + nyc_county_fips={"36061"}, + at_large_districts={0, 98}, + ) + geography = make_geography( + cd_geoids=["101", "102", "3601"], + county_fips=["01001", "01003", "36061"], + ) + + monkeypatch.setattr(local_area, "Path", remapped_path) + monkeypatch.setattr(local_area, "VOLUME_MOUNT", str(staging_root)) + monkeypatch.setattr( + local_area.USAreaCatalog, "default", classmethod(lambda cls: tiny_catalog) + ) + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda branch: None) + monkeypatch.setattr(local_area, 
"get_version", lambda: "0.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *args, **kwargs: None) + monkeypatch.setattr( + local_area, "_load_area_catalog_geography", lambda **kwargs: geography + ) + monkeypatch.setattr( + local_area, "_build_publishing_input_bundle", lambda **kwargs: object() + ) + monkeypatch.setattr( + local_area, "_resolve_scope_fingerprint", lambda **kwargs: "fingerprint" + ) + monkeypatch.setattr( + local_area, "reconcile_run_dir_fingerprint", lambda *args, **kwargs: "fresh" + ) + monkeypatch.setattr( + local_area, "_build_worker_bootstrap", lambda **kwargs: object() + ) + monkeypatch.setattr( + local_area, + "pipeline_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + monkeypatch.setattr( + local_area, + "TargetUniverseReader", + SimpleNamespace( + from_sqlite=lambda db_path: SimpleNamespace( + regional=lambda: SimpleNamespace(cd_geoids=("101", "102", "3601")) + ) + ), + ) + captured = {} + + def fake_run_phase(phase_name, *, weighted_requests, completed, **kwargs): + captured["phase_name"] = phase_name + captured["weighted_keys"] = [item.key for item in weighted_requests] + captured["weights"] = [item.weight for item in weighted_requests] + captured["request_payloads"] = [ + item.to_worker_payload() for item in weighted_requests + ] + return set(captured["weighted_keys"]), [], [] + + monkeypatch.setattr(local_area, "run_phase", fake_run_phase) + + result = local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + n_clones=1, + validate=False, + run_id=run_id, + ) + + assert result["reuse_measurement"] == { + "expected_outputs": 6, + "valid_reused_outputs": 0, + "recomputed_outputs": 6, + "invalid_outputs": 0, + } + assert captured["phase_name"] == "All areas" + assert captured["weighted_keys"] == [ + "state:AL", + "state:NY", + "district:AL-01", + "district:AL-02", + "district:NY-01", + "city:NYC", + ] + assert captured["weights"] == [2, 1, 1, 1, 1, 11] + assert captured["request_payloads"][0]["filters"] == [ + {"geography_field": "state_fips", "op": "eq", "value": 1} + ] + assert captured["request_payloads"][-1]["validation_geographic_ids"] == ["3601"] + + def test_build_areas_worker_surfaces_successful_worker_stderr( monkeypatch, capsys, @@ -408,22 +781,211 @@ def fake_run(cmd, **kwargs): monkeypatch.setattr(local_area.subprocess, "run", fake_run) + result = local_area.build_areas_worker( + "main", + "run-123", + "regional", + [{"type": "district", "id": "NC-01"}], + { + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + False, + "regional-fingerprint", + ) + + captured = capsys.readouterr() + assert result["completed"] == ["district:NC-01"] + assert "Worker session ready: scope=regional, bootstrap=used" in captured.err + assert "--work-items" in captured_cmd["cmd"] + assert "--requests-json" not in captured_cmd["cmd"] + assert "--scope-fingerprint" in captured_cmd["cmd"] + assert "regional-fingerprint" in captured_cmd["cmd"] + + +def test_build_areas_worker_prefers_typed_request_payloads( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module() + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda branch: None) + monkeypatch.setattr(local_area, "VOLUME_MOUNT", str(tmp_path / "staging")) + monkeypatch.setattr( + 
local_area, + "pipeline_volume", + SimpleNamespace(reload=lambda: None), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + captured_cmd = {} + + def fake_run(cmd, **kwargs): + captured_cmd["cmd"] = cmd + return SimpleNamespace( + returncode=0, + stdout='{"completed": ["district:NC-01"], "failed": [], "errors": []}', + stderr="", + ) + + monkeypatch.setattr(local_area.subprocess, "run", fake_run) + result = local_area.build_areas_worker( branch="main", run_id="run-123", scope="regional", - work_items=[{"type": "district", "id": "NC-01"}], + request_payloads=[ + { + "area_type": "district", + "area_id": "NC-01", + "display_name": "NC-01", + "output_relative_path": "districts/NC-01.h5", + } + ], calibration_inputs={ "weights": "/tmp/calibration_weights.npy", "dataset": "/tmp/source.h5", "database": "/tmp/policy_data.db", }, validate=False, - scope_fingerprint="regional-fingerprint", ) - captured = capsys.readouterr() assert result["completed"] == ["district:NC-01"] - assert "Worker session ready: scope=regional, bootstrap=used" in captured.err - assert "--scope-fingerprint" in captured_cmd["cmd"] - assert "regional-fingerprint" in captured_cmd["cmd"] + assert "--requests-json" in captured_cmd["cmd"] + assert "--work-items" not in captured_cmd["cmd"] + + +def test_run_phase_partitions_typed_requests_and_aggregates_issues( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module() + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + to_dict=lambda: {"area_type": "district", "area_id": "NC-01"}, + ) + weighted = (local_area.WeightedAreaRequest(request=request, weight=1),) + run_dir = tmp_path / "run-123" + (run_dir / "districts").mkdir(parents=True) + (run_dir / "districts" / "NC-01.h5").write_text("h5") + captured = {} + + class FakeHandle: + object_id = "fc-123" + + def get(self): + return { + "completed": ["district:NC-01"], + "failed": [], + "errors": [], + "issues": [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation warning", + } + ], + "validation_rows": [{"variable": "household_count"}], + } + + def fake_spawn(**kwargs): + captured.update(kwargs) + return FakeHandle() + + monkeypatch.setattr( + local_area, + "build_areas_worker", + SimpleNamespace(spawn=fake_spawn), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None), + ) + + completed, errors, validation_rows = local_area.run_phase( + "Typed requests", + weighted_requests=weighted, + num_workers=1, + completed=set(), + branch="main", + run_id="run-123", + calibration_inputs={ + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + run_dir=run_dir, + validate=True, + scope_fingerprint="fingerprint", + ) + + assert captured["request_payloads"] == [request.to_dict()] + assert captured["work_items"] is None + assert completed == {"district:NC-01"} + assert errors == [ + { + "item": "district:NC-01", + "phase": "validation", + "error": "validation warning", + } + ] + assert validation_rows == [{"variable": "household_count"}] + + +def test_run_phase_records_worker_transport_failure_separately( + monkeypatch, + tmp_path, +): + local_area = load_local_area_module() + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + to_dict=lambda: {"area_type": "district", "area_id": "NC-01"}, + ) + weighted = (local_area.WeightedAreaRequest(request=request, weight=1),) + run_dir = 
tmp_path / "run-123" + run_dir.mkdir() + + class FakeHandle: + object_id = "fc-123" + + def get(self): + raise RuntimeError("modal transport reset") + + monkeypatch.setattr( + local_area, + "build_areas_worker", + SimpleNamespace(spawn=lambda **kwargs: FakeHandle()), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None), + ) + + completed, errors, validation_rows = local_area.run_phase( + "Typed requests", + weighted_requests=weighted, + num_workers=1, + completed=set(), + branch="main", + run_id="run-123", + calibration_inputs={ + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + run_dir=run_dir, + validate=True, + ) + + assert completed == set() + assert validation_rows == [] + assert errors[0]["worker"] == 0 + assert errors[0]["error"] == "modal transport reset"
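
# --- Reviewer sketch (not part of the patch) --------------------------------
# A minimal model of the worker-response contract that the three
# normalize_worker_response tests above pin down: list-valued fields are
# coerced to tuples, raw worker errors gain the worker index plus a
# "worker_failure" severity, a malformed field becomes a "protocol" fatal
# error, and a failed item with no matching error entry synthesizes one.
# NormalizedResponse and the helper name are assumptions for illustration,
# not the module's actual types.
from dataclasses import dataclass


@dataclass(frozen=True)
class NormalizedResponse:
    completed: tuple = ()
    failed: tuple = ()
    fatal_errors: tuple = ()
    issues: tuple = ()


def normalize_worker_response_sketch(*, worker_index, result):
    # Protocol check: completed/failed must arrive as lists.
    for name in ("completed", "failed"):
        if not isinstance(result.get(name, []), list):
            return NormalizedResponse(
                fatal_errors=(
                    {
                        "phase": "protocol",
                        "error": f"Worker result field {name!r} must be a list",
                        "worker": worker_index,
                        "severity": "protocol",
                    },
                )
            )
    completed = tuple(result.get("completed", []))
    failed = tuple(result.get("failed", []))
    # Every raw worker error is fatal and annotated with its origin.
    fatal = [
        {**error, "worker": worker_index, "severity": "worker_failure"}
        for error in result.get("errors", [])
    ]
    # A failed item with no matching error entry gets one synthesized.
    reported_items = {error.get("item") for error in result.get("errors", [])}
    for item in failed:
        if item not in reported_items:
            fatal.append(
                {
                    "item": item,
                    "phase": "worker",
                    "error": "Worker reported failed item without a matching error",
                    "worker": worker_index,
                    "severity": "worker_failure",
                }
            )
    # Issues pass through unchanged as the non-fatal channel.
    return NormalizedResponse(
        completed=completed,
        failed=failed,
        fatal_errors=tuple(fatal),
        issues=tuple(result.get("issues", ())),
    )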
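
# --- Reviewer sketch (not part of the patch) --------------------------------
# The completion measurement asserted in
# test_measure_expected_completion_ignores_unexpected_stale_files reduces to
# set arithmetic over the explicit request keys; the function name here is an
# assumption mirroring the private helper the test exercises.
def measure_expected_completion_sketch(
    *, expected_keys, initially_completed, completed
):
    expected = set(expected_keys)
    missing = expected - set(completed)
    measurement = {
        "expected_outputs": len(expected),
        # Pre-existing outputs only count when they match an expected key.
        "valid_reused_outputs": len(set(initially_completed) & expected),
        "recomputed_outputs": len(
            (set(completed) - set(initially_completed)) & expected
        ),
        # Stale files such as district:OLD are surfaced, never credited.
        "invalid_outputs": len(set(completed) - expected),
    }
    return missing, measurement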
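
# --- Reviewer sketch (not part of the patch) --------------------------------
# The two build_areas_worker tests above pin a simple CLI rule: typed request
# payloads serialize to --requests-json and suppress the legacy --work-items
# flag, which stays available only for compatibility callers. This helper is
# illustrative; the module builds its worker command line differently.
import json


def select_worker_item_args_sketch(*, work_items=None, request_payloads=None):
    if request_payloads is not None:
        return ["--requests-json", json.dumps(list(request_payloads))]
    if work_items is not None:
        return ["--work-items", json.dumps(list(work_items))]
    raise ValueError("either work_items or request_payloads is required")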
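
# --- Reviewer sketch (not part of the patch) --------------------------------
# The final test pins run_phase's transport handling: a spawn handle whose
# .get() raises is recorded as a coordinator-side error carrying the worker
# index and the exception text, and contributes nothing to completion. The
# loop below sketches that behavior; it is not the coordinator's actual code.
def collect_worker_results_sketch(handles):
    responses, transport_errors = [], []
    for worker_index, handle in enumerate(handles):
        try:
            responses.append((worker_index, handle.get()))
        except Exception as exc:  # e.g. RuntimeError("modal transport reset")
            transport_errors.append(
                {"worker": worker_index, "error": str(exc)}
            )
    return responses, transport_errors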