From 92948f8ab39ad2d18e05cafaa012286187fafa3e Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 17 May 2026 03:38:53 +0000 Subject: [PATCH 1/3] Phase 4: SMAP soil moisture, wired into the spine as an encoder feature Consolidates nasa_moisture into a single main(lat, lon, start, end) -> DataFrame[Date, soil_moisture] and deletes the three competing (and unused) implementations (appeears.py, soilmoisture.py, soilmoisture2.py). Wiring: - combine_data fetches SMAP per station alongside SWE; lazy-imports nasa_moisture so the training env doesn't fail to start when earthaccess/h5py aren't installed - merge_dataframes treats soil_moisture like SWE: slow-varying interior interpolation, missing rows default to 0 (never drop a row) - normalize_data adds soil_moisture to OPTIONAL_NUMERIC with the same z-score path - windowing adds soil_moisture to ENCODER_FEATURES only; decoder window stays flow- / SWE- / SM-free (no skillful 14-day SM forecast exists) - train.py sanity-check includes the new column Failure modes (HUC8 lookup, Earthdata auth, search, download, extraction) all degrade gracefully to an empty SMAP series so the spine still produces rows. OPENFLOW_DISABLE_SMAP=1 short-circuits the fetch for the ablation baseline. 
CI: - ml_training.yml installs earthaccess/h5py/shapely after tensorflow so pip resolves shared transitive deps against the tensorflow pin; passes EARTHDATA_USERNAME / EARTHDATA_PASSWORD to the training step - tests/test_nasa_moisture.py covers extraction, granule date parsing, auth/search/download failure paths, and the end-to-end main() flow with mocked earthaccess + a tiny real HDF5 fixture - test_combine_data, test_normalize_data, test_windowing updated for the new soil_moisture column --- .github/workflows/ml_training.yml | 10 + openFlowML/combine_data.py | 84 +++- openFlowML/data/appeears.py | 415 -------------------- openFlowML/data/nasa_moisture.py | 568 +++++++++++++--------------- openFlowML/data/soilmoisture.py | 26 -- openFlowML/data/soilmoisture2.py | 98 ----- openFlowML/data/utils/data_utils.py | 135 +------ openFlowML/normalize_data.py | 16 +- openFlowML/train.py | 3 +- openFlowML/windowing.py | 13 +- tests/test_combine_data.py | 31 ++ tests/test_nasa_moisture.py | 232 ++++++++++++ tests/test_normalize_data.py | 24 ++ tests/test_windowing.py | 10 + 14 files changed, 671 insertions(+), 994 deletions(-) delete mode 100644 openFlowML/data/appeears.py delete mode 100644 openFlowML/data/soilmoisture.py delete mode 100644 openFlowML/data/soilmoisture2.py create mode 100644 tests/test_nasa_moisture.py diff --git a/.github/workflows/ml_training.yml b/.github/workflows/ml_training.yml index 1fd5fd3..05fd329 100644 --- a/.github/workflows/ml_training.yml +++ b/.github/workflows/ml_training.yml @@ -19,11 +19,21 @@ jobs: python-version: '3.11' - name: Install dependencies + # earthaccess + h5py + shapely are needed for the SMAP soil-moisture + # fetch wired into combine_data. They are installed AFTER tensorflow so + # pip resolves the shared transitive deps (notably typing-extensions) + # against the tensorflow pin. 
A failed install fails this step (and + # the job). The "no data" fallback (soil_moisture defaults to 0) applies + # only where these packages are absent and combine_data's lazy import fails. run: | python -m pip install --upgrade pip pip install -r openFlowML/requirements.txt + pip install earthaccess h5py shapely - name: Train model + env: + EARTHDATA_USERNAME: ${{ secrets.EARTHDATA_USERNAME }} + EARTHDATA_PASSWORD: ${{ secrets.EARTHDATA_PASSWORD }} run: python openFlowML/train.py # The model is not usable for inference without the scaler parameters diff --git a/openFlowML/combine_data.py b/openFlowML/combine_data.py index 5ed5fe3..78eb0c8 100644 --- a/openFlowML/combine_data.py +++ b/openFlowML/combine_data.py @@ -20,7 +20,12 @@ gaps are dropped rather than filled with a pooled (cross-station) mean - the per-station frames carry a clean 'site_id' column -TODO (later in Phase 2): wire in SWE as a history-window feature. +Phase 4 (SMAP soil moisture): + - Soil moisture is fetched per station from NASA SMAP L3 enhanced via + nasa_moisture (lazy import; failure degrades to "no SMAP" for that site). + - Treated like SWE in the spine: slow-varying, longer interior interpolation + limit, missing rows default to 0 rather than being dropped. + - Set OPENFLOW_DISABLE_SMAP=1 to run the ablation baseline without SMAP. """ if not logging.getLogger().hasHandlers(): @@ -36,6 +41,10 @@ # SWE changes slowly (snowpack accumulates/melts over weeks), so we tolerate # longer interior gaps in the SWE series before giving up on a value. MAX_SWE_GAP_DAYS = 30 +# SMAP has a ~1-3 day revisit cadence per pass; gaps come from RFI / dense +# vegetation / frozen ground. Soil moisture itself is slow-varying so the same +# generous interior interpolation limit as SWE is appropriate. 
+MAX_SM_GAP_DAYS = 30 def _to_daily_series(df, value_columns, daily_index): @@ -52,15 +61,16 @@ def _to_daily_series(df, value_columns, daily_index): def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, - swe_data=None, huc8=None): + swe_data=None, sm_data=None, huc8=None): """ - Merge a site's NOAA temperature, flow, and SWE data onto one regular daily - index. + Merge a site's NOAA temperature, flow, SWE and soil-moisture data onto one + regular daily index. Short interior flow/temp gaps (<= MAX_GAP_DAYS) are interpolated time-aware; rows still missing a core value afterwards are dropped (no pooled-mean - fill). SWE is interpolated with a longer limit (it's slow-varying) and any - remaining missing values default to 0 -- they don't drop the row. + fill). SWE and soil_moisture are interpolated with longer limits (slow- + varying) and any remaining missing values default to 0 -- they don't drop + the row. """ if 'Date' not in noaa_data or 'Date' not in flow_data: raise ValueError("'Date' column missing in one of the dataframes") @@ -96,12 +106,24 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, else: combined['SWE'] = float('nan') - # Drop rows still missing any core flow/temp value. SWE is NOT in - # CORE_COLUMNS, so a missing SWE never drops a row. + # SMAP soil moisture: slow-varying surface state; same shape as SWE + # handling -- generous interior interpolation limit, missing values + # default to 0 instead of dropping the row. + if sm_data is not None and not sm_data.empty: + sm_daily = _to_daily_series(sm_data, ['soil_moisture'], daily_index) + sm_daily['soil_moisture'] = pd.to_numeric(sm_daily['soil_moisture'], errors='coerce') + sm_daily = sm_daily.interpolate(method='time', limit=MAX_SM_GAP_DAYS, limit_area='inside') + combined['soil_moisture'] = sm_daily['soil_moisture'] + else: + combined['soil_moisture'] = float('nan') + + # Drop rows still missing any core flow/temp value. 
SWE / soil_moisture + # are NOT in CORE_COLUMNS, so a missing one never drops a row. before = len(combined) combined = combined.dropna(subset=CORE_COLUMNS) - # Any remaining SWE gaps default to 0 ("no snow data" / out-of-season). + # Any remaining SWE/SM gaps default to 0 ("no data" / out-of-season). combined['SWE'] = combined['SWE'].fillna(0.0) + combined['soil_moisture'] = combined['soil_moisture'].fillna(0.0) logger.info( "Site %s: %d/%d daily rows usable after gap handling", site_id, len(combined), before, @@ -121,12 +143,13 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): """ - Resolve a site's coordinates and fetch its NOAA temperature + HUC SWE - series, and its HUC8 basin id (used as a Phase 3 basin embedding key). + Resolve a site's coordinates and fetch its NOAA temperature, HUC SWE and + SMAP soil-moisture series, and its HUC8 basin id (used as a Phase 3 basin + embedding key). - Returns (noaa_data, swe_data, huc8). Either dataframe may be empty (SWE - degrades gracefully; missing NOAA causes the caller to skip the site). - huc8 may be None when the lookup fails. + Returns (noaa_data, swe_data, sm_data, huc8). Any of the dataframes may be + empty (SWE and SM degrade gracefully; missing NOAA causes the caller to + skip the site). huc8 may be None when the lookup fails. """ if prefix == "USGS": coords_dict = get_usgs_coordinates(site_id) @@ -137,7 +160,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if not coords_dict: logger.error(f"Could not resolve coordinates for {prefix}:{site_id}. 
Skipping...") - return None, None, None + return None, None, None, None latitude = float(coords_dict['latitude']) longitude = float(coords_dict['longitude']) @@ -149,7 +172,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if noaa_data is None or noaa_data.empty: logger.warning(f"No NOAA data available for site ID {site_id}. Skipping...") - return None, None, None + return None, None, None, None noaa_data = noaa_data.copy() noaa_data['USGS_site_ID'] = site_id @@ -173,9 +196,32 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): logger.warning("SWE fetch failed for %s: %s", site_id, e) swe_data = pd.DataFrame(columns=['Date', 'SWE']) + # SMAP soil moisture: also a history-window feature, fetched per HUC8. + # Lazy import keeps the heavy earthaccess/h5py stack out of the core + # training-env import path; if it's not installed or any step fails + # (auth, search, download, extraction), sm_data ends up empty and + # merge_dataframes treats it as "no data" (defaults to 0, doesn't drop + # the row), mirroring the SWE handling. + # + # OPENFLOW_DISABLE_SMAP=1 short-circuits to empty -- this is the lever + # for the ablation run (train without SMAP and compare on the held-out + # test set). + sm_data = pd.DataFrame(columns=['Date', 'soil_moisture']) + if os.getenv('OPENFLOW_DISABLE_SMAP', '').strip() in ('1', 'true', 'True'): + logger.info("OPENFLOW_DISABLE_SMAP set -- skipping SMAP for %s", site_id) + else: + try: + from data import nasa_moisture + sm_data = nasa_moisture.main(latitude, longitude, start_date, end_date) + except ImportError as e: + logger.warning("Soil-moisture deps unavailable (%s); skipping SMAP for %s", + e, site_id) + except Exception as e: + logger.warning("SMAP fetch failed for %s: %s", site_id, e) + # Gap handling and numeric coercion happen in merge_dataframes (per station, # on the regular daily index) -- not here, and never via a pooled mean. 
- return noaa_data, swe_data, huc8 + return noaa_data, swe_data, sm_data, huc8 def get_site_ids(filename=None): @@ -231,7 +277,7 @@ def main(training_num_years=7): logger.warning(f"Unrecognized prefix for site ID {site_id}. Skipping...") continue - noaa_dataframe, swe_dataframe, huc8 = fetch_and_process_data( + noaa_dataframe, swe_dataframe, sm_dataframe, huc8 = fetch_and_process_data( prefix, id, start_date, end_date, flow_dataframe) if noaa_dataframe is None or noaa_dataframe.empty or flow_dataframe.empty: logger.warning(f"No usable data for site ID {site_id}. Skipping...") @@ -239,7 +285,7 @@ def main(training_num_years=7): merged = merge_dataframes( noaa_dataframe, flow_dataframe, site_id, start_date, end_date, - swe_data=swe_dataframe, huc8=huc8) + swe_data=swe_dataframe, sm_data=sm_dataframe, huc8=huc8) if merged.empty: logger.warning(f"No usable merged data for site ID {site_id}. Skipping...") continue diff --git a/openFlowML/data/appeears.py b/openFlowML/data/appeears.py deleted file mode 100644 index 6db296c..0000000 --- a/openFlowML/data/appeears.py +++ /dev/null @@ -1,415 +0,0 @@ -import shutil -import logging -from data.utils.get_poly import simplify_polygon, validate_polygon, get_huc_polygon -import os -import datetime -import rasterio -from rasterio.plot import show -import numpy as np -from rasterio.mask import mask -from shapely.geometry import box -from data.utils.get_poly import check_polygon_intersection, get_huc_polygon, validate_polygon, simplify_polygon -from data.utils.data_utils import appeears_login, appeears_logout, load_vars, get_earthdata_auth, get_smap_data_bounds -import time -import tempfile -import argparse -import requests -import json -from mpl_toolkits.axes_grid1 import make_axes_locatable -from shapely import Polygon -import importlib.util -import matplotlib.pyplot as plt -from matplotlib.patches import Polygon as mplPolygon -matplotlib_spec = importlib.util.find_spec("matplotlib") -matplotlib_available = matplotlib_spec is not 
None - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - -load_vars() - -def check_appeears_product(product_id, layer_name): - """ - Check if the specified product and layer are available in AppEEARS. - """ - # First, get the list of all products - products_url = "https://appeears.earthdatacloud.nasa.gov/api/product" - - try: - response = requests.get(products_url, headers=None, timeout=30) - response.raise_for_status() - - # Print raw response for debugging - print("Raw API Response:") - print(response.text[:1000]) # Print first 1000 characters - - products = response.json() - - print(f"Type of products: {type(products)}") - - if isinstance(products, list): - print("Products is a list. Printing first item:") - print(products[0] if products else "Empty list") - elif isinstance(products, dict): - print("Products is a dictionary. Printing keys:") - print(products.keys()) - else: - print(f"Unexpected type for products: {type(products)}") - - print("Available Soil Moisture Products:") - if isinstance(products, list): - for product in products: - if isinstance(product, dict) and 'ProductAndVersion' in product: - # Check if the product is related to soil moisture - if 'soil moisture' in product.get('Description', '').lower(): - print(f"- {product['ProductAndVersion']}: {product.get('Description', 'N/A')}") - print(f" Available: {product.get('Available', 'N/A')}") - print(f" Temporal Extent: {product.get('TemporalExtentStart', 'N/A')} to {product.get('TemporalExtentEnd', 'N/A')}") - print(f" Resolution: {product.get('Resolution', 'N/A')}") - print(f" Source: {product.get('Source', 'N/A')}") - print(" ---") - - except requests.Timeout: - logging.error("Request timed out while checking product availability") - return False - except requests.RequestException as e: - logging.error(f"Error checking product availability: {e}") - return False - except Exception as e: - logging.error(f"Unexpected error: {e}") - 
return False - - return True # Return True for now, adjust based on actual product availability check - -def get_product_layers(token, product_id): - """ - Get available layers for a specific product from AppEEARS API. - """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/product/{product_id}" - headers = {"Authorization": f"Bearer {token}"} - - try: - response = requests.get(url, headers=headers) - response.raise_for_status() - product_info = response.json() - - logging.info(f"Available layers for {product_id}:") - for layer_name, layer_info in product_info.items(): - logging.info(f"- {layer_name}: {layer_info.get('Description', 'No description available')}") - - return product_info - except requests.RequestException as e: - logging.error(f"Error getting product layers: {e}") - return {} - -def submit_appears_task(token, polygon, start_date, end_date): - """ - Submit a task to AppEEARS API for SMAP data retrieval. - """ - product_id = "SPL3SMP_E.006" - layers = get_product_layers(token, product_id) - - if not layers: - logging.error(f"No layers found for product {product_id}") - return None - - # Choose the first available layer related to soil moisture - soil_moisture_layer = next((layer_name for layer_name, layer_info in layers.items() - if 'soil_moisture' in layer_name.lower() or 'soil_moisture' in layer_info.get('Description', '').lower()), - None) - - if not soil_moisture_layer: - logging.error(f"No soil moisture layer found for product {product_id}") - return None - - logging.info(f"Selected layer: {soil_moisture_layer}") - - url = "https://appeears.earthdatacloud.nasa.gov/api/task" - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - } - - task_payload = { - "task_type": "area", - "task_name": "SMAP_Soil_Moisture_Extraction", - "params": { - "dates": [ - { - "startDate": start_date.strftime("%m-%d-%Y"), - "endDate": end_date.strftime("%m-%d-%Y") - } - ], - "layers": [ - { - "product": product_id, - "layer": 
soil_moisture_layer - } - ], - "output": { - "format": { - "type": "geotiff" - }, - "projection": "geographic" - }, - "geo": { - "type": "FeatureCollection", - "features": [ - { - "type": "Feature", - "properties": {}, - "geometry": { - "type": "Polygon", - "coordinates": [polygon] - } - } - ] - } - } - } - - try: - logging.info(f"Submitting task to URL: {url}") - logging.info(f"Headers: {headers}") - logging.info(f"Payload: {json.dumps(task_payload, indent=2)}") - - response = requests.post(url, headers=headers, data=json.dumps(task_payload), timeout=30) - - logging.info(f"Response status code: {response.status_code}") - logging.info(f"Response headers: {response.headers}") - logging.info(f"Response content: {response.text}") - - if response.status_code == 202: - return response.json()["task_id"] - else: - logging.error(f"Failed to submit task. Status code: {response.status_code}") - logging.error(f"Response: {response.text}") - return None - except requests.Timeout: - logging.error("Request timed out while submitting task") - return None - except requests.RequestException as e: - logging.error(f"Error submitting task: {e}") - return None - -def check_task_status(token, task_id): - """ - Check the status of an AppEEARS task. - """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/task/{task_id}" - headers = {"Authorization": f"Bearer {token}"} - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - return response.json()["status"] - else: - logging.error(f"Failed to check task status. Status code: {response.status_code}") - return None - -def download_task_results(token, task_id, output_dir): - """ - Download the results of a completed AppEEARS task. 
- """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/bundle/{task_id}" - headers = {"Authorization": f"Bearer {token}"} - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - bundle_info = response.json() - for file_info in bundle_info["files"]: - file_id = file_info["file_id"] - file_name = file_info["file_name"] - download_url = f"https://appeears.earthdatacloud.nasa.gov/api/bundle/{task_id}/{file_id}" - - file_response = requests.get(download_url, headers=headers, allow_redirects=True) - if file_response.status_code == 200: - with open(f"{output_dir}/{file_name}", "wb") as f: - f.write(file_response.content) - logging.info(f"Downloaded: {file_name}") - else: - logging.error(f"Failed to download {file_name}. Status code: {file_response.status_code}") - else: - logging.error(f"Failed to get bundle info. Status code: {response.status_code}") - -def extract_soil_moisture_from_geotiff(geotiff_path, polygon): - """ - Extract soil moisture data for the given polygon from the GeoTIFF file using rasterio. - """ - try: - with rasterio.open(geotiff_path) as src: - # Create a GeoJSON-like geometry object - geom = {"type": "Polygon", "coordinates": [polygon]} - - # Mask the raster with the polygon - out_image, out_transform = mask(src, [geom], crop=True) - - # Get the data from the masked raster - data = out_image[0] # Assuming it's a single-band raster - - # Calculate average soil moisture - valid_data = data[data != src.nodata] - if len(valid_data) > 0: - average_moisture = np.mean(valid_data) - return average_moisture - else: - logging.warning("No valid data found within the polygon") - return None - - except Exception as e: - logging.error(f"Error extracting soil moisture data: {e}") - logging.error("Traceback: ", exc_info=True) - return None - -def visualize_smap_data(geotiff_path, polygon, average_moisture): - """ - Create a detailed visualization of SMAP data with the polygon overlay. 
- """ - try: - with rasterio.open(geotiff_path) as src: - fig, ax = plt.subplots(figsize=(12, 8)) - - # Plot the SMAP data - show(src, ax=ax, cmap='YlGnBu') - - # Plot the polygon - poly = mplPolygon(polygon, facecolor='none', edgecolor='red', linewidth=2) - ax.add_patch(poly) - - # Add colorbar - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(ax.images[0], cax=cax, label='Soil Moisture') - - # Set the extent to focus on the polygon - min_lon, min_lat, max_lon, max_lat = Polygon(polygon).bounds - ax.set_xlim(min_lon - 0.5, max_lon + 0.5) - ax.set_ylim(min_lat - 0.5, max_lat + 0.5) - - # Add title and labels - plt.title(f'SMAP Soil Moisture Data with HUC8 Polygon\nAverage Soil Moisture: {average_moisture:.4f}') - plt.xlabel('Longitude') - plt.ylabel('Latitude') - - # Add grid - ax.grid(True, linestyle='--', alpha=0.5) - - # Show the plot - plt.tight_layout() - plt.show() - - except Exception as e: - logging.error(f"Error visualizing SMAP data: {e}") - logging.error("Traceback: ", exc_info=True) - -# Add a function to verify the token: -def verify_token(token): - url = "https://appeears.earthdatacloud.nasa.gov/api/user" - headers = {"Authorization": f"Bearer {token}"} - try: - response = requests.get(url, headers=headers) - if response.status_code == 200: - logging.info("Token verified successfully") - return True - else: - logging.error(f"Failed to verify token. Status code: {response.status_code}") - logging.error(f"Response: {response.text}") - return False - except requests.RequestException as e: - logging.error(f"Error verifying token: {e}") - return False - -def main(start_date, end_date, lat, lon, visual): - # Login to AppEEARS - token = appeears_login() - if not token: - logging.error("Failed to login to AppEEARS. 
Please check your credentials.") - return - - # Check if the product and layer are available - product_id = "SPL3SMP_E.003" - layer_name = "soil_moisture" - if not check_appeears_product(product_id, layer_name): - logging.error("Required product or layer is not available. Exiting.") - appeears_logout() - return - - # Get the HUC8 polygon - huc8_polygon = get_huc_polygon(lat, lon, 8) - if not huc8_polygon: - logging.error("Failed to retrieve HUC8 polygon") - appeears_logout() - return - - # Simplify & validate the polygon - simplified_polygon = simplify_polygon(huc8_polygon) - simplified_polygon = validate_polygon(simplified_polygon) - logging.debug(f"Validated polygon coordinates: {simplified_polygon}") - - # Submit AppEEARS task - task_id = submit_appears_task(token, simplified_polygon, start_date, end_date) - if not task_id: - logging.error("Failed to submit AppEEARS task") - appeears_logout() - return - - # Wait for task to complete - max_retries = 30 # 30 minutes maximum wait time - retries = 0 - while retries < max_retries: - status = check_task_status(token, task_id) - if status == "done": - logging.info("Task completed successfully") - break - elif status == "error": - logging.error("Task failed") - appeears_logout() - return - elif status is None: - logging.error("Failed to check task status") - appeears_logout() - return - time.sleep(60) # Wait for 60 seconds before checking again - retries += 1 - else: - logging.error("Task did not complete within the maximum wait time") - appeears_logout() - return - - # Download results - output_dir = tempfile.mkdtemp() - download_task_results(token, task_id, output_dir) - - # Process and visualize results - geotiff_file = next((f for f in os.listdir(output_dir) if f.endswith('.tif')), None) - if geotiff_file: - geotiff_path = os.path.join(output_dir, geotiff_file) - average_soil_moisture = extract_soil_moisture_from_geotiff(geotiff_path, simplified_polygon) - - if average_soil_moisture is not None: - 
logging.info(f"Average soil moisture: {average_soil_moisture:.4f}") - - if visual: - visualize_smap_data(geotiff_path, simplified_polygon, average_soil_moisture) - else: - logging.error("Failed to calculate average soil moisture") - else: - logging.error("No GeoTIFF file found in the downloaded results") - - # Clean up - shutil.rmtree(output_dir) - - # Logout from AppEEARS - appeears_logout() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Calculate average soil moisture for a HUC8 polygon from SMAP L3 data.') - parser.add_argument('--start-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='Start Date in YYYY-MM-DD format') - parser.add_argument('--end-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='End Date in YYYY-MM-DD format') - parser.add_argument('--lat', type=float, required=True, help='Latitude of the point within the desired HUC8 polygon') - parser.add_argument('--lon', type=float, required=True, help='Longitude of the point within the desired HUC8 polygon') - parser.add_argument('--visual', action='store_true', help='Enable matplotlib visualization') - args = parser.parse_args() - - main(args.start_date, args.end_date, args.lat, args.lon, args.visual) \ No newline at end of file diff --git a/openFlowML/data/nasa_moisture.py b/openFlowML/data/nasa_moisture.py index 84eda14..1b920c4 100644 --- a/openFlowML/data/nasa_moisture.py +++ b/openFlowML/data/nasa_moisture.py @@ -1,330 +1,294 @@ -import datetime -import numpy as np -import h5py -import shutil -from scipy.spatial import cKDTree -from shapely.ops import transform -from shapely.geometry import Polygon -import logging +""" +SMAP L3 enhanced soil moisture (SPL3SMP_E v006) timeseries by HUC8. 
+ +Given a station's lat/lon and a date window, this module looks up the enclosing +HUC8 polygon, searches NASA Earthdata for every SMAP granule whose footprint +intersects that polygon over the window, downloads them one at a time, +extracts the average volumetric soil moisture inside the polygon for each +granule, and returns a single daily series. + +Public API: + main(lat, lon, start_date, end_date) -> DataFrame[Date, soil_moisture] + +A failure at any step (HUC8 lookup, Earthdata auth, search, download, +extraction) degrades to an empty DataFrame -- combine_data treats missing +soil moisture as "no data" (defaults to 0 in the spine) rather than dropping +the row, mirroring the SWE handling. +""" + import argparse -from earthaccess import * -from data.utils.get_poly import check_polygon_intersection, get_huc_polygon, validate_polygon, simplify_polygon -from data.utils.data_utils import load_vars, get_earthdata_auth, get_smap_data_bounds +import logging import os import tempfile -# Conditionally import matplotlib -import importlib.util -from shapely.ops import transform -import earthaccess -import matplotlib.pyplot as plt -from matplotlib.patches import Polygon as mplPolygon -matplotlib_spec = importlib.util.find_spec("matplotlib") -matplotlib_available = matplotlib_spec is not None +from datetime import date, datetime +from typing import Optional + +import pandas as pd + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +# NSIDC SMAP L3 enhanced (9 km) passive radiometer soil moisture. +SMAP_SHORT_NAME = "SPL3SMP_E" +SMAP_VERSION = "006" +# HDF5 fill value for missing pixels in the SMAP product. +FILL_VALUE = -9999.0 +# Valid SMAP retrieval range for volumetric soil moisture (m^3/m^3). 
+VALID_MIN = 0.0 +VALID_MAX = 1.0 -load_vars() -def search_and_download_smap_data(start_date, end_date, auth, simplified_polygon): +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'soil_moisture']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], "%Y-%m-%d").date() + + +def _login_earthdata(): """ - Search for SMAP L3 data between the given dates that intersect with the given polygon, - and download the smallest intersecting granule. + Authenticate with NASA Earthdata via earthaccess. Returns the Auth on + success, None on any failure -- the caller short-circuits to an empty + series so a missing credential doesn't bring down combine_data. """ try: - # Ensure we're authenticated - if not auth.authenticated: - logging.info("Not logged in, attempting to log in...") - if not auth.login(strategy="environment"): - raise RuntimeError("Failed to authenticate with NASA Earthdata Login") - else: - logging.info("Already authenticated, proceeding with search and download") - - # Search for SPL3SMP_E collection - collection_query = earthaccess.DataCollections().short_name("SPL3SMP_E").version("006") - collections = collection_query.get() - - if not collections: - logging.error("SMAP L3 SM_P_E collection not found") - return None - - collection = collections[0] - concept_id = collection.concept_id() - logging.info(f"Found SMAP_L3_SM_P_E collection with concept_id: {concept_id}") - - - # Calculate bounding box from simplified_polygon - lons, lats = zip(*simplified_polygon) - min_lon, max_lon = min(lons), max(lons) - min_lat, max_lat = min(lats), max(lats) - - # Now search for granules using DataGranules - granule_query = (earthaccess.DataGranules() - .concept_id(concept_id) - .temporal(start_date, end_date) - .bounding_box(min_lon, min_lat, max_lon, max_lat)) - - granule_hits = granule_query.hits() - logging.info(f"Number of granules found: 
{granule_hits}") - - if granule_hits == 0: - logging.warning(f"No SMAP data found from {start_date} to {end_date}") - return None - - granules = granule_query.get_all() - - if not granules: - logging.warning(f"No granules retrieved despite positive hit count") - return None - - # Log the number of granules found - logging.info(f"Retrieved {len(granules)} granules") - - # Find the smallest granule - smallest_granule = min(granules, key=lambda g: g.size()) - logging.info(f"Smallest intersecting granule size: {smallest_granule.size()} MB") - - # Create a temporary directory that won't be automatically deleted - temp_dir = tempfile.mkdtemp() - logging.info(f"Created temporary directory: {temp_dir}") - - # Download only the smallest granule - try: - downloaded_files = earthaccess.download(smallest_granule, local_path=temp_dir) - except Exception as e: - logging.error(f"Error during download: {str(e)}") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - - if downloaded_files: - downloaded_file = downloaded_files[0] - logging.info(f"Successfully downloaded: {downloaded_file}") - - # Verify that the file exists - if os.path.exists(downloaded_file): - logging.info(f"File exists at {downloaded_file}") - file_size = os.path.getsize(downloaded_file) - logging.info(f"File size: {file_size} bytes") - else: - logging.error(f"File does not exist at {downloaded_file}") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - - return downloaded_file, temp_dir - else: - logging.error("Failed to download SMAP data") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - + import earthaccess + except ImportError as e: + logger.warning("earthaccess not installed: %s", e) + return None + try: + auth = earthaccess.login(strategy="environment") + if auth is not None and getattr(auth, 'authenticated', False): + return auth + logger.warning("Earthdata authentication did not succeed") + return None except Exception as e: - logging.error(f"Error searching or 
downloading SMAP data: {e}") - logging.error("Traceback: ", exc_info=True) - if temp_dir: - shutil.rmtree(temp_dir) - return None, None + logger.warning("Earthdata authentication error: %s", e) + return None -def extract_soil_moisture(hdf_file, polygon, max_distance=0.1): + +def _get_huc8_polygon(lat: float, lon: float): + """ + HUC8 polygon (list of (lon, lat) tuples) enclosing the point, simplified. + Returns None when the WBD lookup fails or the simplification dependencies + aren't installed. + """ try: - with h5py.File(hdf_file, 'r') as file: - for dataset_name in file: - if 'Soil_Moisture_Retrieval_Data' in dataset_name: - soil_moisture = file[f'{dataset_name}/soil_moisture'][:] - lat = file[f'{dataset_name}/latitude'][:] - lon = file[f'{dataset_name}/longitude'][:] - logging.info(f"Found soil moisture data in {dataset_name}") - break - else: - logging.error("Could not find soil moisture data in the file") - return None - - polygon_obj = Polygon(polygon) - minx, miny, maxx, maxy = polygon_obj.bounds - - # Start with the polygon bounds and gradually expand - for distance in np.linspace(0, max_distance, 5): - expanded_bounds = (minx-distance, miny-distance, maxx+distance, maxy+distance) - mask = (lon >= expanded_bounds[0]) & (lon <= expanded_bounds[2]) & \ - (lat >= expanded_bounds[1]) & (lat <= expanded_bounds[3]) - - lons = lon[mask] - lats = lat[mask] - soil_moisture_masked = soil_moisture[mask] - - valid = (soil_moisture_masked != -9999.0) & (lons != -9999.0) & (lats != -9999.0) - lons = lons[valid] - lats = lats[valid] - soil_moisture_valid = soil_moisture_masked[valid] - - logging.info(f"Found {len(lons)} valid points within expanded bounds (distance: {distance})") - - if len(lons) > 0: - tree = cKDTree(np.column_stack((lons, lats))) - expanded_polygon = polygon_obj.buffer(distance) - mask_polygon = tree.query_ball_point(expanded_polygon.exterior.coords, r=0.01) - mask_polygon = np.unique(np.concatenate(mask_polygon)) - - soil_moisture_in_polygon = 
soil_moisture_valid[mask_polygon] - - if len(soil_moisture_in_polygon) > 0: - average_moisture = np.mean(soil_moisture_in_polygon) - logging.info(f"Found {len(soil_moisture_in_polygon)} points inside or near the polygon") - logging.info(f"Average soil moisture: {average_moisture:.4f}") - return average_moisture, expanded_polygon - - logging.warning("No valid soil moisture data found within or near the polygon") - return None, None + from data.utils.get_poly import get_huc_polygon, simplify_polygon + except ImportError as e: + logger.warning("Polygon utilities unavailable: %s", e) + return None + result = get_huc_polygon(lat, lon, huc_level=8) + if not result: + return None + polygon, _huc_id, _attributes = result + if not polygon: + return None + return simplify_polygon(polygon) + +def _search_granules(polygon, start_date, end_date): + """SMAP granules intersecting the polygon's bbox over the date window.""" + try: + import earthaccess + except ImportError as e: + logger.warning("earthaccess not installed: %s", e) + return [] + lons, lats = zip(*polygon) + bbox = (min(lons), min(lats), max(lons), max(lats)) + try: + granules = earthaccess.search_data( + short_name=SMAP_SHORT_NAME, + version=SMAP_VERSION, + temporal=(_to_date(start_date), _to_date(end_date)), + bounding_box=bbox, + ) except Exception as e: - logging.error(f"Error extracting soil moisture data: {e}") - logging.error("Traceback: ", exc_info=True) - return None, None + logger.warning("SMAP granule search failed: %s", e) + return [] + return list(granules) if granules else [] + -# Add this function to check data availability in the polygon area -def check_data_availability(hdf_file, polygon): +def _granule_date(granule) -> Optional[date]: + """ + Observation date for a granule. SMAP L3 filenames embed YYYYMMDD + (e.g. SMAP_L3_SM_P_E_20240115_R18290_001.h5), so we lift it from there + and fall back to the granule's temporal metadata if the filename pattern + doesn't match. 
+ """ try: - with h5py.File(hdf_file, 'r') as file: - for time_of_day in ['AM', 'PM']: - try: - soil_moisture = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/soil_moisture'][:] - lat = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/latitude'][:] - lon = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/longitude'][:] - break - except KeyError: - continue - else: - logging.error("Could not find soil moisture data") - return - - # Get polygon bounds - poly = Polygon(polygon) - min_lon, min_lat, max_lon, max_lat = poly.bounds - - # Create masks for the area of interest and its surroundings - area_mask = (lat >= min_lat) & (lat <= max_lat) & (lon >= min_lon) & (lon <= max_lon) - surrounding_mask = (lat >= min_lat-1) & (lat <= max_lat+1) & (lon >= min_lon-1) & (lon <= max_lon+1) - - # Check for valid data - valid_data_mask = (soil_moisture != -9999.0) - - # Calculate statistics - total_points = np.sum(area_mask) - valid_points = np.sum(valid_data_mask & area_mask) - surrounding_valid_points = np.sum(valid_data_mask & surrounding_mask) - - logging.info(f"Total points in area of interest: {total_points}") - logging.info(f"Valid data points in area of interest: {valid_points}") - logging.info(f"Percentage of valid data in area: {valid_points/total_points*100:.2f}%") - logging.info(f"Valid data points in surrounding area: {surrounding_valid_points}") - - if valid_points == 0: - logging.warning("No valid data points found within the polygon.") - if surrounding_valid_points > 0: - logging.info("However, valid data points found in the surrounding area.") - else: - logging.warning("No valid data points found in the surrounding area either.") + links = granule.data_links() if hasattr(granule, 'data_links') else [] + for link in links or []: + name = link.rsplit('/', 1)[-1] + for token in name.split('_'): + if len(token) == 8 and token.isdigit(): + try: + return datetime.strptime(token, "%Y%m%d").date() + except ValueError: + continue + except Exception: + pass + try: + 
umm = granule.get("umm", {}) if hasattr(granule, 'get') else {} + beg = (umm.get("TemporalExtent", {}) + .get("RangeDateTime", {}) + .get("BeginningDateTime")) + if beg: + return datetime.strptime(beg[:10], "%Y-%m-%d").date() + except Exception: + pass + return None + + +def _extract_polygon_mean(hdf_path: str, polygon) -> Optional[float]: + """ + Average volumetric soil moisture across all valid SMAP pixels inside the + polygon's bounding box, across both AM and PM passes. Returns None when + no valid pixel intersects (cloudy / RFI day, or polygon entirely off the + EASE-grid for that pass). + + Bbox is sufficient at SMAP's 9 km grid spacing -- the HUC8 footprint is + rarely much larger than a handful of cells and the bbox vs true polygon + distinction is below the retrieval noise floor. + """ + try: + import h5py + import numpy as np + except ImportError as e: + logger.warning("h5py/numpy unavailable: %s", e) + return None + + lons, lats = zip(*polygon) + min_lon, max_lon = min(lons), max(lons) + min_lat, max_lat = min(lats), max(lats) + + # SMAP L3 v006 splits AM and PM passes into separate top-level groups. 
+ pass_groups = [ + 'Soil_Moisture_Retrieval_Data_AM', + 'Soil_Moisture_Retrieval_Data_PM', + 'Soil_Moisture_Retrieval_Data', + ] + collected = [] + try: + with h5py.File(hdf_path, 'r') as f: + for grp_name in pass_groups: + if grp_name not in f: + continue + grp = f[grp_name] + if not all(k in grp for k in ('soil_moisture', 'latitude', 'longitude')): + continue + sm = grp['soil_moisture'][:] + lat = grp['latitude'][:] + lon = grp['longitude'][:] + in_bbox = ((lon >= min_lon) & (lon <= max_lon) & + (lat >= min_lat) & (lat <= max_lat)) + valid = ((sm != FILL_VALUE) & (sm >= VALID_MIN) & (sm <= VALID_MAX) & + (lon != FILL_VALUE) & (lat != FILL_VALUE)) + mask = in_bbox & valid + if mask.any(): + collected.extend(sm[mask].astype(float).tolist()) except Exception as e: - logging.error(f"Error checking data availability: {e}") + logger.warning("Failed to read %s: %s", hdf_path, e) + return None + + if not collected: + return None + return float(sum(collected) / len(collected)) + -def list_nsidc_collections(): +def _download_granule(granule, tmpdir): + """earthaccess.download wrapper that returns the local file path or None.""" try: - nsidc_query = earthaccess.collection_query().daac("NSIDC") - collections = nsidc_query.get() - - logging.info(f"Found {len(collections)} collections from NSIDC-DAAC:") - for collection in collections: - logging.info(f"- Short Name: {collection['umm']['ShortName']}, Version: {collection['umm'].get('Version', 'N/A')}") - - return collections + import earthaccess + except ImportError: + return None + try: + files = earthaccess.download([granule], local_path=tmpdir) except Exception as e: - logging.error(f"Error listing NSIDC collections: {e}") + logger.warning("Granule download failed: %s", e) return None + if not files: + return None + return files[0] + -def visualize_smap_and_polygon(hdf_file, polygon): +def main(lat: float, lon: float, start_date, end_date) -> pd.DataFrame: + """ + Fetch the SMAP L3 enhanced soil-moisture timeseries for the 
HUC8 enclosing + (lat, lon) over [start_date, end_date]. + + Returns a DataFrame with columns ['Date', 'soil_moisture'] where + soil_moisture is volumetric m^3/m^3 in [0, 1]. Multiple AM/PM passes on + the same day are averaged. Empty DataFrame on any failure -- the + combine_data spine treats missing soil moisture the same way it treats + missing SWE (defaults to 0, never drops the row). + """ try: - with h5py.File(hdf_file, 'r') as file: - for dataset_name in file: - if 'Soil_Moisture_Retrieval_Data' in dataset_name: - soil_moisture = file[f'{dataset_name}/soil_moisture'][:] - lat = file[f'{dataset_name}/latitude'][:] - lon = file[f'{dataset_name}/longitude'][:] - break - else: - logging.error("Could not find soil moisture data in the file") - return - - # Create a mask for valid data - valid_mask = (soil_moisture != -9999.0) & (lat != -9999.0) & (lon != -9999.0) - - fig, ax = plt.subplots(figsize=(12, 8)) - - # Plot SMAP data points - sc = ax.scatter(lon[valid_mask], lat[valid_mask], c=soil_moisture[valid_mask], - cmap='viridis', s=1, alpha=0.5) - plt.colorbar(sc, label='Soil Moisture') - - # Plot the polygon - poly = mplPolygon(polygon, facecolor='none', edgecolor='red', linewidth=2) - ax.add_patch(poly) - - # Set the extent to focus on the area around the polygon - poly_bounds = Polygon(polygon).bounds - ax.set_xlim(poly_bounds[0] - 1, poly_bounds[2] + 1) - ax.set_ylim(poly_bounds[1] - 1, poly_bounds[3] + 1) - - plt.title('SMAP Data and Polygon') - plt.xlabel('Longitude') - plt.ylabel('Latitude') - plt.show() - + polygon = _get_huc8_polygon(lat, lon) except Exception as e: - logging.error(f"Error visualizing SMAP data and polygon: {e}") - -def main(start_date, end_date, lat, lon, visual): - auth = get_earthdata_auth() - - huc8_polygon = get_huc_polygon(lat, lon, huc_level=8) - if not huc8_polygon: - logging.error("Failed to retrieve HUC8 polygon") - return - - simplified_polygon = simplify_polygon(huc8_polygon) - logging.info(f"Simplified polygon coordinates: 
{simplified_polygon}") - - downloaded_file, temp_dir = search_and_download_smap_data(start_date, end_date, auth, simplified_polygon) - - if downloaded_file and temp_dir: - try: - # Visualize SMAP data and polygon - visualize_smap_and_polygon(downloaded_file, simplified_polygon) - - # Extract soil moisture data - average_soil_moisture, used_polygon = extract_soil_moisture(downloaded_file, simplified_polygon) - - if average_soil_moisture is not None: - logging.info(f"Average soil moisture: {average_soil_moisture:.4f}") - if visual and matplotlib_available: + logger.warning("HUC8 polygon lookup failed: %s", e) + return _empty() + if not polygon: + logger.warning("No HUC8 polygon found for (%s, %s)", lat, lon) + return _empty() + + if _login_earthdata() is None: + return _empty() + + granules = _search_granules(polygon, start_date, end_date) + if not granules: + logger.warning("No SMAP granules found in [%s, %s] for the polygon", + start_date, end_date) + return _empty() + logger.info("Found %d SMAP granules", len(granules)) + + rows = [] + with tempfile.TemporaryDirectory(prefix='smap_') as tmpdir: + for granule in granules: + obs_date = _granule_date(granule) + if obs_date is None: + continue + path = _download_granule(granule, tmpdir) + if not path: + continue + try: + value = _extract_polygon_mean(path, polygon) + finally: + try: + os.remove(path) + except OSError: pass - #visualize_soil_moisture_simple(used_polygon, average_soil_moisture) - else: - logging.error("Failed to calculate soil moisture.") - - finally: - # Clean up: remove the temporary directory - shutil.rmtree(temp_dir) - logging.info(f"Removed temporary directory: {temp_dir}") - else: - logging.error("Failed to find or download SMAP data") + if value is None: + continue + rows.append((obs_date.strftime('%Y-%m-%d'), value)) + + if not rows: + return _empty() + + df = pd.DataFrame(rows, columns=['Date', 'soil_moisture']) + df['soil_moisture'] = pd.to_numeric(df['soil_moisture'], errors='coerce') + # AM 
+ PM passes on the same date collapse to a single daily mean. + daily = df.groupby('Date', as_index=False)['soil_moisture'].mean() + return daily.sort_values('Date').reset_index(drop=True) + + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Calculate average soil moisture for a HUC8 polygon from SMAP L3 data.') - parser.add_argument('--start-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='Start Date in YYYY-MM-DD format') - parser.add_argument('--end-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='End Date in YYYY-MM-DD format') - parser.add_argument('--lat', type=float, required=True, help='Latitude of the point within the desired HUC8 polygon') - parser.add_argument('--lon', type=float, required=True, help='Longitude of the point within the desired HUC8 polygon') - parser.add_argument('--visual', action='store_true', help='Enable matplotlib visualization') + parser = argparse.ArgumentParser( + description='Fetch SMAP L3 enhanced soil moisture timeseries by HUC8.') + parser.add_argument('--lat', type=float, required=True) + parser.add_argument('--lon', type=float, required=True) + parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD') args = parser.parse_args() - - main(args.start_date, args.end_date, args.lat, args.lon, args.visual) \ No newline at end of file + df = main(args.lat, args.lon, args.start_date, args.end_date) + if df.empty: + print("No data") + else: + print(df.to_string(index=False)) diff --git a/openFlowML/data/soilmoisture.py b/openFlowML/data/soilmoisture.py deleted file mode 100644 index 84ccf8b..0000000 --- a/openFlowML/data/soilmoisture.py +++ /dev/null @@ -1,26 +0,0 @@ -import requests -import json - -def fetch_all_table_names(): - base_url = "https://SDMDataAccess.sc.egov.usda.gov" - headers = {"Content-Type": 
"application/json"} - - query_url = f"{base_url}/Tabular/post.rest" - # Query to fetch all table names from the information schema - sql_query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES" - query_data = { - "SERVICE": "query", - "REQUEST": "query", - "QUERY": sql_query, - "FORMAT": "JSON" - } - response = requests.post(query_url, headers=headers, data=json.dumps(query_data)) - - if response.status_code == 200: - return response.json() - else: - return f"Failed to fetch table names: {response.text}" - -# Execute the function to get all table names -result = fetch_all_table_names() -print(result) diff --git a/openFlowML/data/soilmoisture2.py b/openFlowML/data/soilmoisture2.py deleted file mode 100644 index 0dbd0f7..0000000 --- a/openFlowML/data/soilmoisture2.py +++ /dev/null @@ -1,98 +0,0 @@ -import asyncio -import xarray as xr -import pandas as pd -from shapely.geometry import Polygon -from pyproj import CRS -import fsspec -import dask -from dask.distributed import Client -import argparse -import logging -from data.utils.get_poly import get_huc8_polygon, simplify_polygon - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - -async def fetch_soil_moisture_data(polygon: Polygon, date_range: pd.date_range) -> xr.Dataset: - """ - Asynchronously fetch CPC soil moisture data for a given polygon and date range. 
- - Args: - polygon (Polygon): Shapely Polygon object defining the area of interest - date_range (pd.date_range): Date range for data retrieval - - Returns: - xr.Dataset: Dataset containing soil moisture data for the specified region and time - """ - # CPC soil moisture data is available via Google Cloud Storage - gcs_url = "gs://noaa-cpc-pds/soil-moisture/" - - async with fsspec.open_files(gcs_url + "*.nc", mode='rb') as files: - datasets = [xr.open_dataset(file, engine='h5netcdf', chunks={'time': 'auto'}) - for file in files if file.start.date() in date_range] - - combined_data = xr.concat(datasets, dim='time') - - # Ensure CRS matches the polygon CRS (assuming WGS84) - data_crs = CRS.from_epsg(4326) - combined_data = combined_data.rio.write_crs(data_crs) - - # Clip data to polygon - clipped_data = combined_data.rio.clip([polygon], all_touched=True) - - return clipped_data - -async def process_soil_moisture_data(data: xr.Dataset) -> pd.DataFrame: - """ - Process the retrieved soil moisture data. 
- - Args: - data (xr.Dataset): Dataset containing soil moisture data - - Returns: - pd.DataFrame: Processed soil moisture data - """ - # Example processing - adjust as needed - mean_moisture = data['soilw'].mean(dim=['lat', 'lon']) - return mean_moisture.to_dataframe() - -async def main(lat: float, lon: float, start_date: str, end_date: str): - # Set up dask client for parallel processing - client = Client(n_workers=4, threads_per_worker=2, memory_limit='4GB') - - # Get HUC8 polygon - huc8_polygon = get_huc8_polygon(lat, lon) - if not huc8_polygon: - logging.error("No HUC8 polygon found") - await client.close() - return - - # Simplify the polygon - simplified_polygon = simplify_polygon(huc8_polygon) - logging.info(f"Simplified polygon: {simplified_polygon}") - - # Conve - # rt to Shapely Polygon - polygon = Polygon(simplified_polygon) - - date_range = pd.date_range(start=start_date, end=end_date, freq='D') - - data = await fetch_soil_moisture_data(polygon, date_range) - - # Process data using dask for parallelization - processed_data = await dask.compute(process_soil_moisture_data(data))[0] - - print(processed_data) - - # Clean up dask client - await client.close() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Fetch soil moisture data for HUC8 polygon based on lat/lon.') - parser.add_argument('--lat', type=float, required=True, help='Latitude') - parser.add_argument('--lon', type=float, required=True, help='Longitude') - parser.add_argument('--start_date', type=str, required=True, help='Start date (YYYY-MM-DD)') - parser.add_argument('--end_date', type=str, required=True, help='End date (YYYY-MM-DD)') - args = parser.parse_args() - - asyncio.run(main(args.lat, args.lon, args.start_date, args.end_date)) \ No newline at end of file diff --git a/openFlowML/data/utils/data_utils.py b/openFlowML/data/utils/data_utils.py index 0924662..fc2aee3 100644 --- a/openFlowML/data/utils/data_utils.py +++ b/openFlowML/data/utils/data_utils.py @@ 
-2,18 +2,14 @@ import os import time import requests -from datetime import datetime, timedelta, timezone +from datetime import timedelta -# NOTE: h5py, python-dotenv and earthaccess are heavy, optional dependencies -# used only by the soil-moisture data modules. They are imported lazily inside -# the functions that need them so that lightweight consumers (e.g. combine_data -# importing preview_data) work with only the core requirements installed. +# NOTE: python-dotenv is a heavy optional dependency used only by load_vars +# (credential loading for the soil-moisture path). It is imported lazily so +# lightweight consumers like combine_data work with only the core requirements. -# Global variables to store token and expiration -_token = None -_expiration = None - # Additional function to display the beginning and ending of the dataframe +# Additional function to display the beginning and ending of the dataframe def preview_data(df, num_rows=4): logging.info("First few rows:") logging.info(df.head(num_rows)) @@ -59,117 +55,6 @@ def date_chunks(start_date, end_date, max_days=366): yield current, chunk_end current = chunk_end + timedelta(days=1) -def get_earthdata_auth(): - """ - Create and return an authenticated earthaccess Auth instance. - """ - from earthaccess import Auth - auth = Auth() - - username = os.getenv("EARTHDATA_USERNAME") - password = os.getenv("EARTHDATA_PASSWORD") - - logging.info(f"Attempting to authenticate with username: {username}") - - if username and password: - if auth.login(strategy="environment"): - logging.info("Successfully authenticated using environment variables") - return auth - else: - logging.warning("Authentication failed using environment variables") - else: - logging.warning("Environment variables not set or empty") - - raise RuntimeError("Failed to authenticate with NASA Earthdata Login") - - -def appeears_login(): - """ - Log in to AppEEARS and obtain a token. 
- """ - global _token, _expiration - url = "https://appeears.earthdatacloud.nasa.gov/api/login" - username = os.getenv("EARTHDATA_USERNAME") - password = os.getenv("EARTHDATA_PASSWORD") - - if not username or not password: - raise ValueError("EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables must be set.") - - try: - response = requests.post(url, auth=(username, password)) - response.raise_for_status() - data = response.json() - _token = data['token'] - # Parse the expiration date string manually - expiration_str = data['expiration'] - # Remove the 'Z' at the end and split the string - date_part, time_part = expiration_str[:-1].split('T') - year, month, day = map(int, date_part.split('-')) - hour, minute, second = map(int, time_part.split(':')) - _expiration = datetime(year, month, day, hour, minute, second, tzinfo=timezone.utc) - return _token - except requests.RequestException as e: - print(f"Login failed: {e}") - return None - -def appeears_logout(): - """ - Log out from AppEEARS and invalidate the current token. - """ - global _token, _expiration - if not _token: - print("No active session to log out from.") - return True - - url = "https://appeears.earthdatacloud.nasa.gov/api/logout" - headers = {'Authorization': f'Bearer {_token}'} - try: - response = requests.post(url, headers=headers) - if response.status_code == 204: - _token = None - _expiration = None - return True - else: - print(f"Logout failed: {response.text}") - return False - except requests.RequestException as e: - print(f"Logout failed: {e}") - return False - -def get_smap_data_bounds(hdf_file): - """ - Get the actual bounding box of the SMAP data from the HDF file, excluding fill values. 
- """ - import h5py - try: - with h5py.File(hdf_file, 'r') as file: - for time_of_day in ['AM', 'PM']: - try: - lat_dataset = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/latitude'] - lon_dataset = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/longitude'] - - # Filter out fill values (assuming -9999.0 is the fill value) - valid_lat = lat_dataset[lat_dataset[:] != -9999.0] - valid_lon = lon_dataset[lon_dataset[:] != -9999.0] - - if len(valid_lat) > 0 and len(valid_lon) > 0: - min_lat = valid_lat.min() - max_lat = valid_lat.max() - min_lon = valid_lon.min() - max_lon = valid_lon.max() - - logging.info(f"Actual SMAP data bounds: Lon ({min_lon}, {max_lon}), Lat ({min_lat}, {max_lat})") - return (min_lon, min_lat, max_lon, max_lat) - else: - logging.warning(f"No valid data found for {time_of_day}") - except KeyError: - continue - logging.error("Could not find valid latitude or longitude data in the file") - return None - except Exception as e: - logging.error(f"Error getting SMAP data bounds: {e}") - return None - def load_vars(): """ @@ -180,14 +65,18 @@ def load_vars(): kill the interpreter just because credentials are absent (e.g. in CI or when running the test suite). 
""" - from dotenv import load_dotenv + try: + from dotenv import load_dotenv + except ImportError: + logging.warning("python-dotenv not installed; skipping creds.env load") + load_dotenv = None script_dir = os.path.dirname(os.path.abspath(__file__)) cred_env_path = os.path.join(script_dir, 'creds.env') - if os.path.exists(cred_env_path): + if load_dotenv is not None and os.path.exists(cred_env_path): load_dotenv(cred_env_path) logging.info(f"Loaded environment variables from {cred_env_path}") - else: + elif load_dotenv is not None: logging.warning(f"creds.env file not found at {cred_env_path}") required = ["EARTHDATA_USERNAME", "EARTHDATA_PASSWORD", diff --git a/openFlowML/normalize_data.py b/openFlowML/normalize_data.py index a8fdc92..8ae0104 100644 --- a/openFlowML/normalize_data.py +++ b/openFlowML/normalize_data.py @@ -27,9 +27,10 @@ # Required: rows missing any of these are dropped (no pooled-mean fill). CORE_REQUIRED = ['TMIN', 'TMAX', 'Min Flow', 'Max Flow'] -# Optional columns that get scaled when present. SWE is slow-varying; if it's -# missing for a row, default it to 0 rather than dropping the row. -OPTIONAL_NUMERIC = ['SWE'] +# Optional columns that get scaled when present. SWE and soil_moisture are +# slow-varying; if either is missing for a row, default to 0 rather than +# dropping the row. +OPTIONAL_NUMERIC = ['SWE', 'soil_moisture'] NUMERIC_COLUMNS = CORE_REQUIRED + OPTIONAL_NUMERIC # Streamflow is log-normal -- log1p before z-scoring is standard hydrology # practice. Temperature/SWE stay linear. @@ -145,11 +146,14 @@ def normalize_data(data, artifacts_dir=None): for column in present_numeric: data[column] = pd.to_numeric(data[column], errors='coerce') - # SWE is slowly varying and often legitimately zero -- any remaining - # missing value here defaults to 0 ("no snow data") instead of forcing - # the row out, so a station with no nearby SNOTEL still contributes. 
+ # SWE and soil_moisture are slowly varying and often legitimately near + # zero -- any remaining missing value here defaults to 0 instead of + # forcing the row out, so a station with no nearby SNOTEL / no SMAP + # retrieval that day still contributes. if 'SWE' in data.columns: data['SWE'] = data['SWE'].fillna(0.0) + if 'soil_moisture' in data.columns: + data['soil_moisture'] = data['soil_moisture'].fillna(0.0) # combine_data owns per-station gap handling for flow + temperature; # anything still missing in the core columns here is dropped rather diff --git a/openFlowML/train.py b/openFlowML/train.py index bbd6993..6712df3 100644 --- a/openFlowML/train.py +++ b/openFlowML/train.py @@ -81,7 +81,8 @@ def main(): # 2. Sanity-check spine outputs the rest of training relies on. required = {'station_idx', 'basin_idx', 'site_id', 'Date', - 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', + 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', + 'SWE', 'soil_moisture', 'doy_sin', 'doy_cos'} missing = required - set(data.columns) if missing: diff --git a/openFlowML/windowing.py b/openFlowML/windowing.py index 1e5399c..9a9dfd2 100644 --- a/openFlowML/windowing.py +++ b/openFlowML/windowing.py @@ -25,11 +25,16 @@ logger = logging.getLogger(__name__) # Encoder window: everything we know up to the prediction time. Flow is here -# (these are observations), and SWE -- current snowpack is a strong predictor -# of snowmelt-fed runoff in Colorado. -ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', 'doy_sin', 'doy_cos'] +# (these are observations); SWE -- current snowpack is a strong predictor of +# snowmelt-fed runoff in Colorado; soil_moisture (SMAP L3 enhanced) is the +# antecedent wetness state that gates how much new precipitation becomes runoff +# vs infiltrates. +ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', + 'SWE', 'soil_moisture', + 'doy_sin', 'doy_cos'] # Decoder window: ONLY features available at forecast time. 
No flow (that's -# what we're predicting), no SWE (no skillful 14-day SWE forecast exists). +# what we're predicting), no SWE / no soil_moisture (neither has a skillful +# 14-day forecast available). DECODER_FEATURES = ['TMIN', 'TMAX', 'doy_sin', 'doy_cos'] # Target: log-z-scored flow during the decoder days. Both columns are already # log1p-transformed and z-scored by normalize_data, so MSE/Huber here behaves. diff --git a/tests/test_combine_data.py b/tests/test_combine_data.py index 131f570..17e7b4a 100644 --- a/tests/test_combine_data.py +++ b/tests/test_combine_data.py @@ -111,6 +111,37 @@ def test_merge_defaults_swe_to_zero_when_not_provided(): assert (merged['SWE'] == 0.0).all() +def test_merge_attaches_interpolated_soil_moisture_when_provided(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # Sparse SMAP retrievals -- only every 3rd day, like a real revisit cadence. + sm = pd.DataFrame({'Date': dates[::3], + 'soil_moisture': [0.15, 0.20, 0.25, 0.30, 0.25, 0.20, 0.15]}) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + sm_data=sm) + assert 'soil_moisture' in merged.columns + assert not merged['soil_moisture'].isnull().any() + interior = merged.loc[ + pd.to_datetime(merged['Date']) <= pd.Timestamp('2022-01-19'), + 'soil_moisture'] + assert interior.max() <= 0.30 and interior.min() >= 0.15 + trailing = merged.loc[ + pd.to_datetime(merged['Date']) > pd.Timestamp('2022-01-19'), + 'soil_moisture'] + # Trailing rows past the last known SM date default to 0 (same as SWE). 
+ assert (trailing == 0.0).all() + + +def test_merge_defaults_soil_moisture_to_zero_when_not_provided(): + noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert 'soil_moisture' in merged.columns + assert (merged['soil_moisture'] == 0.0).all() + + def test_merge_records_huc8_when_provided(): noaa, flow = _make_frames() merged = combine_data.merge_dataframes( diff --git a/tests/test_nasa_moisture.py b/tests/test_nasa_moisture.py new file mode 100644 index 0000000..97ad1cb --- /dev/null +++ b/tests/test_nasa_moisture.py @@ -0,0 +1,232 @@ +""" +Tests for the consolidated SMAP soil-moisture fetcher (nasa_moisture.main). + +We mock at the earthaccess + polygon-lookup boundaries so the test runs offline +and deterministically; the on-disk extraction is exercised by writing a tiny +real HDF5 file and reading it back through _extract_polygon_mean. +""" + +from datetime import date, datetime + +import numpy as np +import pandas as pd +import pytest + +# h5py is part of requirements-data.txt; skip if not installed. +h5py = pytest.importorskip('h5py') + +from data import nasa_moisture + + +# Polygon big enough to contain the synthetic grids used below. 
+_POLYGON = [(-106.5, 39.5), (-104.5, 39.5), (-104.5, 41.5), (-106.5, 41.5)] + + +def _write_smap_h5(path, sm_am, sm_pm=None): + """Write a minimal SMAP-shaped HDF5 file at `path`.""" + lat = np.array([ + [40.0, 40.0, 40.0], + [40.5, 40.5, 40.5], + [41.0, 41.0, 41.0], + ]) + lon = np.array([ + [-106.0, -105.5, -105.0], + [-106.0, -105.5, -105.0], + [-106.0, -105.5, -105.0], + ]) + with h5py.File(path, 'w') as f: + am = f.create_group('Soil_Moisture_Retrieval_Data_AM') + am.create_dataset('soil_moisture', data=sm_am) + am.create_dataset('latitude', data=lat) + am.create_dataset('longitude', data=lon) + if sm_pm is not None: + pm = f.create_group('Soil_Moisture_Retrieval_Data_PM') + pm.create_dataset('soil_moisture', data=sm_pm) + pm.create_dataset('latitude', data=lat) + pm.create_dataset('longitude', data=lon) + + +class _FakeGranule: + """Stand-in for an earthaccess granule with the bits nasa_moisture uses.""" + + def __init__(self, filename): + self._filename = filename + + def data_links(self): + return [f"https://example.test/{self._filename}"] + + +def test_extract_polygon_mean_averages_valid_pixels_only(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.array([ + [-9999.0, 0.25, 0.30], + [0.20, 0.35, 0.40], + [-9999.0, 0.45, 0.50], + ]) + _write_smap_h5(path, sm) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + expected = (0.25 + 0.30 + 0.20 + 0.35 + 0.40 + 0.45 + 0.50) / 7 + assert value == pytest.approx(expected) + + +def test_extract_polygon_mean_returns_none_when_all_fill(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), -9999.0) + _write_smap_h5(path, sm) + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) is None + + +def test_extract_polygon_mean_combines_am_and_pm(tmp_path): + path = str(tmp_path / 'sm.h5') + sm_am = np.full((3, 3), 0.20) + sm_pm = np.full((3, 3), 0.40) + _write_smap_h5(path, sm_am, sm_pm=sm_pm) + # 9 AM pixels at 0.2 + 9 PM pixels at 0.4 -> grand mean of 0.30. 
+ assert nasa_moisture._extract_polygon_mean(path, _POLYGON) == pytest.approx(0.30) + + +def test_extract_polygon_mean_clips_to_valid_range(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.array([ + [0.30, 0.40, 5.0], # 5.0 > VALID_MAX -- rejected + [0.20, -0.10, 0.50], # -0.10 < VALID_MIN -- rejected + [0.10, 0.20, 0.30], + ]) + _write_smap_h5(path, sm) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + kept = [0.30, 0.40, 0.20, 0.50, 0.10, 0.20, 0.30] + assert value == pytest.approx(sum(kept) / len(kept)) + + +def test_granule_date_parses_from_filename(): + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R18290_001.h5') + assert nasa_moisture._granule_date(g) == date(2024, 1, 15) + + +def test_granule_date_returns_none_when_no_pattern(): + g = _FakeGranule('arbitrary_filename.h5') + assert nasa_moisture._granule_date(g) is None + + +def test_main_returns_empty_when_polygon_lookup_fails(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: None) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + assert list(out.columns) == ['Date', 'soil_moisture'] + + +def test_main_returns_empty_when_auth_fails(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: None) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + + +def test_main_returns_empty_when_no_granules(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: []) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + + +def test_main_end_to_end_with_mocked_search_and_download(monkeypatch, tmp_path): + # Two distinct dates, each with a known mean. 
Build the granule fixtures. + g1_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240115_R18290_001.h5') + g2_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240116_R18290_001.h5') + _write_smap_h5(g1_path, np.full((3, 3), 0.30)) + _write_smap_h5(g2_path, np.full((3, 3), 0.50)) + + g1 = _FakeGranule('SMAP_L3_SM_P_E_20240115_R18290_001.h5') + g2 = _FakeGranule('SMAP_L3_SM_P_E_20240116_R18290_001.h5') + + download_map = {id(g1): g1_path, id(g2): g2_path} + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g1, g2]) + + def fake_download(granule, tmpdir): + # Copy fixture into the temp dir so the real cleanup path runs cleanly. + import shutil + src = download_map[id(granule)] + dst = f"{tmpdir}/{src.rsplit('/', 1)[-1]}" + shutil.copy(src, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-16') + assert list(out.columns) == ['Date', 'soil_moisture'] + assert len(out) == 2 + assert out.iloc[0]['Date'] == '2024-01-15' + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.30) + assert out.iloc[1]['Date'] == '2024-01-16' + assert out.iloc[1]['soil_moisture'] == pytest.approx(0.50) + + +def test_main_collapses_multiple_granules_on_same_date(monkeypatch, tmp_path): + # Two granules tagged with the same date -- the daily series collapses + # them via mean (single row, averaged value). 
+ a_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240120_pass_a_001.h5') + b_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240120_pass_b_001.h5') + _write_smap_h5(a_path, np.full((3, 3), 0.20)) + _write_smap_h5(b_path, np.full((3, 3), 0.40)) + + g_a = _FakeGranule(a_path.rsplit('/', 1)[-1]) + g_b = _FakeGranule(b_path.rsplit('/', 1)[-1]) + download_map = {id(g_a): a_path, id(g_b): b_path} + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g_a, g_b]) + + def fake_download(granule, tmpdir): + import shutil + src = download_map[id(granule)] + dst = f"{tmpdir}/{src.rsplit('/', 1)[-1]}" + shutil.copy(src, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-20', '2024-01-20') + assert len(out) == 1 + assert out.iloc[0]['Date'] == '2024-01-20' + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.30) + + +def test_main_skips_granules_that_fail_download(monkeypatch, tmp_path): + good_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240117_R001_001.h5') + _write_smap_h5(good_path, np.full((3, 3), 0.25)) + + g_bad = _FakeGranule('SMAP_L3_SM_P_E_20240115_R000_001.h5') + g_good = _FakeGranule('SMAP_L3_SM_P_E_20240117_R001_001.h5') + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g_bad, g_good]) + + def fake_download(granule, tmpdir): + if granule is g_bad: + return None + import shutil + dst = f"{tmpdir}/{good_path.rsplit('/', 1)[-1]}" + shutil.copy(good_path, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-17') + 
assert len(out) == 1 + assert out.iloc[0]['Date'] == '2024-01-17' + + +def test_to_date_accepts_string_datetime_and_date(): + assert nasa_moisture._to_date('2024-01-15') == date(2024, 1, 15) + assert nasa_moisture._to_date(datetime(2024, 1, 15, 12, 0)) == date(2024, 1, 15) + assert nasa_moisture._to_date(date(2024, 1, 15)) == date(2024, 1, 15) diff --git a/tests/test_normalize_data.py b/tests/test_normalize_data.py index d291f99..34ae4b5 100644 --- a/tests/test_normalize_data.py +++ b/tests/test_normalize_data.py @@ -112,6 +112,30 @@ def test_normalize_fills_missing_swe_with_zero_not_dropping_rows(): assert out['SWE'].isnull().sum() == 0 +def test_normalize_scales_soil_moisture_when_present(tmp_path): + import json + df = _make_combined() + df['soil_moisture'] = np.linspace(0.05, 0.45, 20) + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'soil_moisture' in scalers + # Identity transform (no log1p) for soil moisture; z-scored on the raw scale. + assert scalers['soil_moisture']['transform'] == 'identity' + assert abs(out['soil_moisture'].mean()) < 1e-6 + assert abs(out['soil_moisture'].std(ddof=0) - 1.0) < 1e-6 + + +def test_normalize_fills_missing_soil_moisture_with_zero(): + df = _make_combined() + df['soil_moisture'] = [np.nan] * len(df) + out = normalize_data.normalize_data(df) + assert out is not None + # Missing soil moisture should NOT drop rows. 
+ assert len(out) == len(df) + assert out['soil_moisture'].isnull().sum() == 0 + + def test_flow_columns_are_log_transformed_before_scaling(tmp_path): import json df = _make_combined() diff --git a/tests/test_windowing.py b/tests/test_windowing.py index a4fad59..9493854 100644 --- a/tests/test_windowing.py +++ b/tests/test_windowing.py @@ -20,6 +20,7 @@ def _make_station_frame(site_id, station_idx, basin_idx, n_days=120, start='2022 'TMIN': rng.standard_normal(n_days), 'TMAX': rng.standard_normal(n_days), 'SWE': rng.standard_normal(n_days), + 'soil_moisture': rng.standard_normal(n_days), 'doy_sin': np.sin(2 * np.pi * np.arange(n_days) / 365), 'doy_cos': np.cos(2 * np.pi * np.arange(n_days) / 365), }) @@ -39,6 +40,9 @@ def test_decoder_features_carry_no_flow_information(): assert 'Max Flow' not in windowing.DECODER_FEATURES # SWE is also a current-conditions feature -- no skillful 14-day forecast. assert 'SWE' not in windowing.DECODER_FEATURES + # Same for soil moisture: SMAP is encoder-only, never in the decoder window. 
+ assert 'soil_moisture' not in windowing.DECODER_FEATURES + assert 'soil_moisture' in windowing.ENCODER_FEATURES def test_build_windows_shapes_are_correct(): @@ -145,3 +149,9 @@ def test_build_windows_rejects_dataframe_missing_required_columns(): df = _make_station_frame('USGS:A', 1, 1).drop(columns=['SWE']) with pytest.raises(ValueError): windowing.build_windows(df) + + +def test_build_windows_rejects_dataframe_missing_soil_moisture(): + df = _make_station_frame('USGS:A', 1, 1).drop(columns=['soil_moisture']) + with pytest.raises(ValueError): + windowing.build_windows(df) From 71fcc391c79765d18ee5aba999deca6d365947a3 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 18 May 2026 01:45:10 +0000 Subject: [PATCH 2/3] Phase 4 fixes: SM quality flag, granule cache, smarter missing-value handling The Phase 4 review flagged three issues that would substantively bias the SMAP signal during training; this fixes all three so the with-vs-without-SMAP ablation actually measures the feature, not the imputation noise. 1. Better missing-value handling for soil_moisture - fillna(0) said "Sahara desert" on every SMAP gap (RFI, frozen ground, sensor outage). Replaced with: forward-fill + back-fill (SM is slow- varying) -> station median -> 0 only if the station has no observations at all. - New sm_observed indicator (0/1) tells the model when soil_moisture is a real / short-gap-interpolated retrieval vs imputed via the fallback. Wired through normalize_data (NOT z-scored -- binary semantics) and into windowing.ENCODER_FEATURES (encoder-only, like soil_moisture). 2. SMAP retrieval_qual_flag filtering - _extract_polygon_mean now drops pixels where bit 0 of the recommended- quality flag is set. Wintertime Colorado retrievals routinely include frozen-ground pixels that look numerically valid but the SMAP team flags as not recommended; without this filter they corrupted the polygon mean. 
- Falls back to the prior behavior (no quality filter) when the dataset is absent in the granule, since older versions sometimes omit it. 3. Granule cache to avoid the obvious redownload-per-station waste - SPL3SMP_E granules are global daily files (~30-100 MB each), so every station sees the same granule for a given day. New module-level _GRANULE_PATH_CACHE plus a persistent cache dir (OPENFLOW_SMAP_CACHE_DIR, defaults to /tmp/openflow_smap_cache) means each granule is fetched once per process and -- with the env var set to a stable path -- once across runs. - Replaced the TemporaryDirectory in main() since deleting the cache after every per-station call defeated the point. Tests: 4 new cases for quality-flag filtering (recommended / all-rejected / dataset-missing / disabled), 2 new cases for the granule cache (in-process dedup and disk-resume), 3 new cases for combine_data SM handling (indicator truth, median fallback, zero-only-as-last-resort), 2 new cases for normalize_data preserving binary sm_observed. Full local suite: 102 passed. --- openFlowML/combine_data.py | 24 ++++-- openFlowML/data/nasa_moisture.py | 131 ++++++++++++++++++++++++------- openFlowML/normalize_data.py | 22 ++++-- openFlowML/train.py | 2 +- openFlowML/windowing.py | 6 +- tests/test_combine_data.py | 68 +++++++++++++--- tests/test_nasa_moisture.py | 121 +++++++++++++++++++++++++++- tests/test_normalize_data.py | 23 ++++++ tests/test_windowing.py | 4 + 9 files changed, 347 insertions(+), 54 deletions(-) diff --git a/openFlowML/combine_data.py b/openFlowML/combine_data.py index 78eb0c8..647679a 100644 --- a/openFlowML/combine_data.py +++ b/openFlowML/combine_data.py @@ -106,9 +106,9 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, else: combined['SWE'] = float('nan') - # SMAP soil moisture: slow-varying surface state; same shape as SWE - # handling -- generous interior interpolation limit, missing values - # default to 0 instead of dropping the row. 
+ # SMAP soil moisture: slow-varying surface state; generous interior + # interpolation limit (matches SWE). Edge handling happens AFTER the + # core-column dropna below, where we have the final row set. if sm_data is not None and not sm_data.empty: sm_daily = _to_daily_series(sm_data, ['soil_moisture'], daily_index) sm_daily['soil_moisture'] = pd.to_numeric(sm_daily['soil_moisture'], errors='coerce') @@ -121,9 +121,23 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, # are NOT in CORE_COLUMNS, so a missing one never drops a row. before = len(combined) combined = combined.dropna(subset=CORE_COLUMNS) - # Any remaining SWE/SM gaps default to 0 ("no data" / out-of-season). + # SWE: 0 means "no snow", which IS a legitimate default; keep that. combined['SWE'] = combined['SWE'].fillna(0.0) - combined['soil_moisture'] = combined['soil_moisture'].fillna(0.0) + # Soil moisture: 0 means "Sahara desert", which is NOT a legitimate + # default for a SMAP gap (RFI / frozen ground / sensor outage). We + # carry the last interpolated value forward and backward (SM is slow- + # varying), then fall back to the station's median observed SM where + # ffill/bfill can't reach, and only to 0 if the station has no + # observations at all. An sm_observed indicator (1 = real or short- + # gap interpolated retrieval, 0 = imputed via ffill / median fallback) + # gives the model a way to tell the truth from the imputation. 
+ sm_series = combined['soil_moisture'] + combined['sm_observed'] = sm_series.notna().astype('int64') + sm_series = sm_series.ffill().bfill() + median = sm_series.median(skipna=True) + if pd.isna(median): + median = 0.0 + combined['soil_moisture'] = sm_series.fillna(median) logger.info( "Site %s: %d/%d daily rows usable after gap handling", site_id, len(combined), before, diff --git a/openFlowML/data/nasa_moisture.py b/openFlowML/data/nasa_moisture.py index 1b920c4..c1137ae 100644 --- a/openFlowML/data/nasa_moisture.py +++ b/openFlowML/data/nasa_moisture.py @@ -3,17 +3,23 @@ Given a station's lat/lon and a date window, this module looks up the enclosing HUC8 polygon, searches NASA Earthdata for every SMAP granule whose footprint -intersects that polygon over the window, downloads them one at a time, -extracts the average volumetric soil moisture inside the polygon for each -granule, and returns a single daily series. +intersects that polygon over the window, downloads them (with a granule cache +shared across calls so neighboring stations don't redownload the same global +daily file), filters each granule to recommended-quality pixels inside the +polygon's bbox, and returns the daily polygon mean. Public API: main(lat, lon, start_date, end_date) -> DataFrame[Date, soil_moisture] A failure at any step (HUC8 lookup, Earthdata auth, search, download, extraction) degrades to an empty DataFrame -- combine_data treats missing -soil moisture as "no data" (defaults to 0 in the spine) rather than dropping -the row, mirroring the SWE handling. +soil moisture as "no SMAP today" (forward-filled, then site-median fallback, +plus an sm_observed indicator) rather than dropping the row. + +Performance: SPL3SMP_E granules are global daily files (~30-100 MB each), so +two stations in different HUC8s on the same day pull the same granule. 
The +module-level _GRANULE_PATH_CACHE deduplicates within a single process; set +OPENFLOW_SMAP_CACHE_DIR to a persistent path to keep granules across runs. """ import argparse @@ -37,6 +43,13 @@ # Valid SMAP retrieval range for volumetric soil moisture (m^3/m^3). VALID_MIN = 0.0 VALID_MAX = 1.0 +# retrieval_qual_flag bit 0: 0 = retrieval is recommended quality, 1 = not. +# Conservative filter: drop any pixel where bit 0 is set. +QUAL_RECOMMENDED_BIT = 0 + +# Process-level cache so the per-station SMAP loop doesn't re-download the +# same global daily granule for every basin. Keyed by granule filename. +_GRANULE_PATH_CACHE: dict = {} def _empty() -> pd.DataFrame: @@ -146,12 +159,19 @@ def _granule_date(granule) -> Optional[date]: return None -def _extract_polygon_mean(hdf_path: str, polygon) -> Optional[float]: +def _extract_polygon_mean(hdf_path: str, polygon, + use_quality_flag: bool = True) -> Optional[float]: """ Average volumetric soil moisture across all valid SMAP pixels inside the polygon's bounding box, across both AM and PM passes. Returns None when - no valid pixel intersects (cloudy / RFI day, or polygon entirely off the - EASE-grid for that pass). + no valid pixel intersects (cloudy / RFI day, frozen ground, or polygon + entirely off the EASE-grid for that pass). + + Pixels are filtered against the explicit fill value, the [0, 1] valid + range, and -- when the dataset is present -- the SMAP retrieval_qual_flag + bit 0 ("recommended quality"). Without the quality filter, winter Colorado + retrievals include frozen-ground pixels that look numerically plausible + but the SMAP team flags as not recommended. 
Bbox is sufficient at SMAP's 9 km grid spacing -- the HUC8 footprint is rarely much larger than a handful of cells and the bbox vs true polygon @@ -191,6 +211,13 @@ def _extract_polygon_mean(hdf_path: str, polygon) -> Optional[float]: (lat >= min_lat) & (lat <= max_lat)) valid = ((sm != FILL_VALUE) & (sm >= VALID_MIN) & (sm <= VALID_MAX) & (lon != FILL_VALUE) & (lat != FILL_VALUE)) + # Quality flag: drop pixels where the recommended-quality bit + # is set. Older granules occasionally omit the dataset; in that + # case we proceed without the filter rather than dropping all. + if use_quality_flag and 'retrieval_qual_flag' in grp: + qflag = grp['retrieval_qual_flag'][:].astype('int32') + recommended = ((qflag >> QUAL_RECOMMENDED_BIT) & 1) == 0 + valid = valid & recommended mask = in_bbox & valid if mask.any(): collected.extend(sm[mask].astype(float).tolist()) @@ -203,20 +230,67 @@ def _extract_polygon_mean(hdf_path: str, polygon) -> Optional[float]: return float(sum(collected) / len(collected)) -def _download_granule(granule, tmpdir): - """earthaccess.download wrapper that returns the local file path or None.""" +def _granule_cache_key(granule) -> Optional[str]: + """ + Stable id for a granule, used as the cache key. Falls back through + `.data_links()` (the most reliable per-granule identifier) before giving + up; granules without a usable id are simply not cached. + """ + try: + links = granule.data_links() if hasattr(granule, 'data_links') else [] + for link in links or []: + name = link.rsplit('/', 1)[-1] + if name: + return name + except Exception: + pass + return None + + +def _get_cache_dir() -> str: + """ + Persistent SMAP granule cache directory. Defaults to a subdirectory of the + system tempdir; override with OPENFLOW_SMAP_CACHE_DIR to share the cache + across runs (e.g. a CI cache action). 
+ """ + base = (os.environ.get('OPENFLOW_SMAP_CACHE_DIR') or + os.path.join(tempfile.gettempdir(), 'openflow_smap_cache')) + os.makedirs(base, exist_ok=True) + return base + + +def _download_granule(granule, cache_dir): + """ + Download `granule` into `cache_dir`, or return its cached path. Returns + None on failure (and does NOT delete the cached file -- the cache is + intentionally process-lifetime+). + """ + key = _granule_cache_key(granule) + if key: + cached = _GRANULE_PATH_CACHE.get(key) + if cached and os.path.exists(cached): + logger.debug("Granule cache hit: %s", key) + return cached + # Also check disk in case a previous process populated the dir. + on_disk = os.path.join(cache_dir, key) + if os.path.exists(on_disk): + _GRANULE_PATH_CACHE[key] = on_disk + return on_disk try: import earthaccess except ImportError: return None try: - files = earthaccess.download([granule], local_path=tmpdir) + files = earthaccess.download([granule], local_path=cache_dir) except Exception as e: logger.warning("Granule download failed: %s", e) return None if not files: return None - return files[0] + path = files[0] + if key: + _GRANULE_PATH_CACHE[key] = path + return path def main(lat: float, lon: float, start_date, end_date) -> pd.DataFrame: @@ -249,25 +323,22 @@ def main(lat: float, lon: float, start_date, end_date) -> pd.DataFrame: return _empty() logger.info("Found %d SMAP granules", len(granules)) + cache_dir = _get_cache_dir() rows = [] - with tempfile.TemporaryDirectory(prefix='smap_') as tmpdir: - for granule in granules: - obs_date = _granule_date(granule) - if obs_date is None: - continue - path = _download_granule(granule, tmpdir) - if not path: - continue - try: - value = _extract_polygon_mean(path, polygon) - finally: - try: - os.remove(path) - except OSError: - pass - if value is None: - continue - rows.append((obs_date.strftime('%Y-%m-%d'), value)) + for granule in granules: + obs_date = _granule_date(granule) + if obs_date is None: + continue + path = 
_download_granule(granule, cache_dir) + if not path: + continue + # Cached files are intentionally NOT removed -- the next station's + # call to main() in this process will hit the cache for the same + # global daily granule. + value = _extract_polygon_mean(path, polygon) + if value is None: + continue + rows.append((obs_date.strftime('%Y-%m-%d'), value)) if not rows: return _empty() diff --git a/openFlowML/normalize_data.py b/openFlowML/normalize_data.py index 8ae0104..cc8122c 100644 --- a/openFlowML/normalize_data.py +++ b/openFlowML/normalize_data.py @@ -27,10 +27,17 @@ # Required: rows missing any of these are dropped (no pooled-mean fill). CORE_REQUIRED = ['TMIN', 'TMAX', 'Min Flow', 'Max Flow'] -# Optional columns that get scaled when present. SWE and soil_moisture are -# slow-varying; if either is missing for a row, default to 0 rather than -# dropping the row. +# Optional continuous columns that get scaled when present. SWE and +# soil_moisture are slow-varying; if either is missing for a row, default to +# 0 rather than dropping the row (combine_data does smarter SM handling, but +# this is the safety net). OPTIONAL_NUMERIC = ['SWE', 'soil_moisture'] +# Binary indicator columns: included as model inputs but NOT z-scored. Scaling +# a 0/1 indicator destroys the semantics (the model needs to see 0 vs 1 as +# distinct cases, not as samples from a centered normal). sm_observed marks +# rows where soil_moisture is a real / short-gap-interpolated SMAP retrieval +# vs imputed via the median fallback in combine_data. +INDICATOR_COLUMNS = ['sm_observed'] NUMERIC_COLUMNS = CORE_REQUIRED + OPTIONAL_NUMERIC # Streamflow is log-normal -- log1p before z-scoring is standard hydrology # practice. Temperature/SWE stay linear. @@ -141,9 +148,11 @@ def normalize_data(data, artifacts_dir=None): if missing_core: raise ValueError(f"Missing required columns in the data: {missing_core}") - # Coerce every numeric column we know about, including SWE if present. 
+ # Coerce every numeric column we know about, including SWE if present, + # plus the binary indicator columns (still numeric, just not scaled). present_numeric = [c for c in NUMERIC_COLUMNS if c in data.columns] - for column in present_numeric: + present_indicator = [c for c in INDICATOR_COLUMNS if c in data.columns] + for column in present_numeric + present_indicator: data[column] = pd.to_numeric(data[column], errors='coerce') # SWE and soil_moisture are slowly varying and often legitimately near @@ -154,6 +163,9 @@ def normalize_data(data, artifacts_dir=None): data['SWE'] = data['SWE'].fillna(0.0) if 'soil_moisture' in data.columns: data['soil_moisture'] = data['soil_moisture'].fillna(0.0) + # Indicators default to 0 ("not observed") when missing. + for column in present_indicator: + data[column] = data[column].fillna(0).astype('int64') # combine_data owns per-station gap handling for flow + temperature; # anything still missing in the core columns here is dropped rather diff --git a/openFlowML/train.py b/openFlowML/train.py index 6712df3..ba88367 100644 --- a/openFlowML/train.py +++ b/openFlowML/train.py @@ -82,7 +82,7 @@ def main(): # 2. Sanity-check spine outputs the rest of training relies on. required = {'station_idx', 'basin_idx', 'site_id', 'Date', 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', - 'SWE', 'soil_moisture', + 'SWE', 'soil_moisture', 'sm_observed', 'doy_sin', 'doy_cos'} missing = required - set(data.columns) if missing: diff --git a/openFlowML/windowing.py b/openFlowML/windowing.py index 9a9dfd2..c3c44d5 100644 --- a/openFlowML/windowing.py +++ b/openFlowML/windowing.py @@ -28,9 +28,11 @@ # (these are observations); SWE -- current snowpack is a strong predictor of # snowmelt-fed runoff in Colorado; soil_moisture (SMAP L3 enhanced) is the # antecedent wetness state that gates how much new precipitation becomes runoff -# vs infiltrates. 
+# vs infiltrates; sm_observed is the 0/1 indicator that flags whether +# soil_moisture for that day was a real SMAP retrieval or imputed by +# combine_data's median fallback. ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', - 'SWE', 'soil_moisture', + 'SWE', 'soil_moisture', 'sm_observed', 'doy_sin', 'doy_cos'] # Decoder window: ONLY features available at forecast time. No flow (that's # what we're predicting), no SWE / no soil_moisture (neither has a skillful diff --git a/tests/test_combine_data.py b/tests/test_combine_data.py index 17e7b4a..b47c825 100644 --- a/tests/test_combine_data.py +++ b/tests/test_combine_data.py @@ -123,23 +123,71 @@ def test_merge_attaches_interpolated_soil_moisture_when_provided(): sm_data=sm) assert 'soil_moisture' in merged.columns assert not merged['soil_moisture'].isnull().any() - interior = merged.loc[ - pd.to_datetime(merged['Date']) <= pd.Timestamp('2022-01-19'), - 'soil_moisture'] - assert interior.max() <= 0.30 and interior.min() >= 0.15 - trailing = merged.loc[ - pd.to_datetime(merged['Date']) > pd.Timestamp('2022-01-19'), - 'soil_moisture'] - # Trailing rows past the last known SM date default to 0 (same as SWE). - assert (trailing == 0.0).all() + # Every value must lie within the observed range -- ffill/bfill carries + # the edge values to the trailing days, not 0. + sm_min, sm_max = 0.15, 0.30 + assert merged['soil_moisture'].max() <= sm_max + assert merged['soil_moisture'].min() >= sm_min -def test_merge_defaults_soil_moisture_to_zero_when_not_provided(): +def test_merge_soil_moisture_indicator_flags_real_observations(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # SMAP retrievals on a sparse subset of days -- the rest get interpolated + # or ffilled / bfilled. 
+ sm = pd.DataFrame({'Date': dates[::3], + 'soil_moisture': [0.15, 0.20, 0.25, 0.30, 0.25, 0.20, 0.15]}) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + sm_data=sm) + assert 'sm_observed' in merged.columns + # The indicator is 1 on real-or-interpolated rows (everything from the + # first observation to the last observation, since MAX_SM_GAP_DAYS=30 + # exceeds the 3-day spacing) and 0 on rows imputed via ffill/bfill. + rows = merged.set_index(pd.to_datetime(merged['Date'])) + # Jan 19 is the last sparse observation; Jan 20 is ffill / bfill territory. + assert rows.loc['2022-01-19', 'sm_observed'] == 1 + assert rows.loc['2022-01-20', 'sm_observed'] == 0 + + +def test_merge_soil_moisture_falls_back_to_site_median_not_zero(): + # If a station has SOME observations but nothing extending to a portion of + # the window, the missing rows should be imputed via ffill/bfill (which here + # leaves no gap), then the median if a gap remained -- never 0. + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # Only two observations, well within the window. Long stretches before and + # after these get ffilled / bfilled to those known values, NOT to 0.
+ rows = merged.set_index(pd.to_datetime(merged['Date'])) + assert rows.loc['2022-01-01', 'sm_observed'] == 0 # before first obs + assert rows.loc['2022-01-06', 'sm_observed'] == 1 # first obs + assert rows.loc['2022-01-08', 'sm_observed'] == 1 # interior interpolated + assert rows.loc['2022-01-11', 'sm_observed'] == 1 # second obs + assert rows.loc['2022-01-20', 'sm_observed'] == 0 # after last obs + + +def test_merge_defaults_soil_moisture_to_zero_only_when_no_observations(): noaa, flow = _make_frames() merged = combine_data.merge_dataframes( noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) assert 'soil_moisture' in merged.columns + # With no SMAP data at all, fall back to 0 (last-resort default). assert (merged['soil_moisture'] == 0.0).all() + # And the indicator correctly says "none of these are real observations". + assert (merged['sm_observed'] == 0).all() def test_merge_records_huc8_when_provided(): diff --git a/tests/test_nasa_moisture.py b/tests/test_nasa_moisture.py index 97ad1cb..322bf30 100644 --- a/tests/test_nasa_moisture.py +++ b/tests/test_nasa_moisture.py @@ -6,6 +6,7 @@ real HDF5 file and reading it back through _extract_polygon_mean. 
""" +import os from datetime import date, datetime import numpy as np @@ -22,7 +23,7 @@ _POLYGON = [(-106.5, 39.5), (-104.5, 39.5), (-104.5, 41.5), (-106.5, 41.5)] -def _write_smap_h5(path, sm_am, sm_pm=None): +def _write_smap_h5(path, sm_am, sm_pm=None, qflag_am=None, qflag_pm=None): """Write a minimal SMAP-shaped HDF5 file at `path`.""" lat = np.array([ [40.0, 40.0, 40.0], @@ -39,11 +40,15 @@ def _write_smap_h5(path, sm_am, sm_pm=None): am.create_dataset('soil_moisture', data=sm_am) am.create_dataset('latitude', data=lat) am.create_dataset('longitude', data=lon) + if qflag_am is not None: + am.create_dataset('retrieval_qual_flag', data=qflag_am) if sm_pm is not None: pm = f.create_group('Soil_Moisture_Retrieval_Data_PM') pm.create_dataset('soil_moisture', data=sm_pm) pm.create_dataset('latitude', data=lat) pm.create_dataset('longitude', data=lon) + if qflag_pm is not None: + pm.create_dataset('retrieval_qual_flag', data=qflag_pm) class _FakeGranule: @@ -230,3 +235,117 @@ def test_to_date_accepts_string_datetime_and_date(): assert nasa_moisture._to_date('2024-01-15') == date(2024, 1, 15) assert nasa_moisture._to_date(datetime(2024, 1, 15, 12, 0)) == date(2024, 1, 15) assert nasa_moisture._to_date(date(2024, 1, 15)) == date(2024, 1, 15) + + +def test_extract_polygon_mean_drops_not_recommended_quality_pixels(tmp_path): + path = str(tmp_path / 'sm.h5') + # All pixels have legitimate-looking soil moisture; the quality flag + # rejects half of them (bit 0 = 1 means "not recommended"). The polygon + # mean must reflect only the recommended pixels. + sm = np.full((3, 3), 0.40) + qflag = np.array([ + [1, 1, 1], # all not-recommended + [0, 0, 0], # all recommended + [1, 1, 1], # all not-recommended + ], dtype='int32') + _write_smap_h5(path, sm, qflag_am=qflag) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + # Only the middle row's 3 pixels survive -- all happen to be 0.40. 
+ assert value == pytest.approx(0.40) + + +def test_extract_polygon_mean_returns_none_when_quality_flag_rejects_all(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.40) + qflag = np.ones((3, 3), dtype='int32') # every pixel "not recommended" + _write_smap_h5(path, sm, qflag_am=qflag) + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) is None + + +def test_extract_polygon_mean_proceeds_when_quality_dataset_missing(tmp_path): + # Older granules sometimes omit retrieval_qual_flag; we must not refuse + # to extract in that case. + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.30) + _write_smap_h5(path, sm) # no qflag dataset + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) == pytest.approx(0.30) + + +def test_extract_polygon_mean_ignores_quality_flag_when_disabled(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.40) + qflag = np.ones((3, 3), dtype='int32') # all "not recommended" + _write_smap_h5(path, sm, qflag_am=qflag) + value = nasa_moisture._extract_polygon_mean( + path, _POLYGON, use_quality_flag=False) + # With the filter off, every in-bbox / in-range pixel counts. + assert value == pytest.approx(0.40) + + +def test_granule_cache_dedupes_downloads_within_a_process(monkeypatch, tmp_path): + # Two main() calls hitting the same granule must only invoke the actual + # download once -- the cache lets neighboring stations share global daily + # granules without re-fetching. 
+ nasa_moisture._GRANULE_PATH_CACHE.clear() + cache_dir = str(tmp_path / 'cache') + monkeypatch.setenv('OPENFLOW_SMAP_CACHE_DIR', cache_dir) + + granule_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + _write_smap_h5(granule_path, np.full((3, 3), 0.25)) + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R001_001.h5') + + download_calls = {'n': 0} + + class _FakeEA: + @staticmethod + def download(granules, local_path): + download_calls['n'] += 1 + import shutil + dst = os.path.join(local_path, 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + shutil.copy(granule_path, dst) + return [dst] + + import sys + monkeypatch.setitem(sys.modules, 'earthaccess', _FakeEA) + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g]) + + out1 = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-15') + out2 = nasa_moisture.main(40.1, -105.1, '2024-01-15', '2024-01-15') + + assert not out1.empty and not out2.empty + # The actual download was invoked exactly once across both main() calls. + assert download_calls['n'] == 1 + + +def test_granule_cache_picks_up_existing_disk_files(monkeypatch, tmp_path): + # If a previous process populated the cache dir, a fresh process should + # still hit the cache without re-downloading. + nasa_moisture._GRANULE_PATH_CACHE.clear() + cache_dir = str(tmp_path / 'cache') + os.makedirs(cache_dir, exist_ok=True) + monkeypatch.setenv('OPENFLOW_SMAP_CACHE_DIR', cache_dir) + + # Pre-populate the cache dir with the granule file (simulating a previous + # run that already downloaded it). 
+ cached_path = os.path.join(cache_dir, 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + _write_smap_h5(cached_path, np.full((3, 3), 0.35)) + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R001_001.h5') + + class _FailingEA: + @staticmethod + def download(granules, local_path): + raise AssertionError("download must not be called when cached on disk") + + import sys + monkeypatch.setitem(sys.modules, 'earthaccess', _FailingEA) + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g]) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-15') + assert not out.empty + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.35) diff --git a/tests/test_normalize_data.py b/tests/test_normalize_data.py index 34ae4b5..09a9edc 100644 --- a/tests/test_normalize_data.py +++ b/tests/test_normalize_data.py @@ -136,6 +136,29 @@ def test_normalize_fills_missing_soil_moisture_with_zero(): assert out['soil_moisture'].isnull().sum() == 0 +def test_normalize_keeps_sm_observed_as_binary_indicator(tmp_path): + import json + df = _make_combined() + df['soil_moisture'] = np.linspace(0.05, 0.45, 20) + df['sm_observed'] = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] * 2 + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + # Indicator MUST NOT be z-scored -- the model needs to see 0 vs 1. + assert set(out['sm_observed'].unique()) <= {0, 1} + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'sm_observed' not in scalers + + +def test_normalize_defaults_sm_observed_to_zero_when_missing(): + df = _make_combined() + df['sm_observed'] = [1, np.nan, 1, np.nan] * 5 + out = normalize_data.normalize_data(df) + assert out is not None + # Missing rows default to 0 ("not observed"). 
+ assert out['sm_observed'].isnull().sum() == 0 + assert set(out['sm_observed'].unique()) <= {0, 1} + + def test_flow_columns_are_log_transformed_before_scaling(tmp_path): import json df = _make_combined() diff --git a/tests/test_windowing.py b/tests/test_windowing.py index 9493854..760c77c 100644 --- a/tests/test_windowing.py +++ b/tests/test_windowing.py @@ -21,6 +21,7 @@ def _make_station_frame(site_id, station_idx, basin_idx, n_days=120, start='2022 'TMAX': rng.standard_normal(n_days), 'SWE': rng.standard_normal(n_days), 'soil_moisture': rng.standard_normal(n_days), + 'sm_observed': rng.integers(0, 2, n_days), 'doy_sin': np.sin(2 * np.pi * np.arange(n_days) / 365), 'doy_cos': np.cos(2 * np.pi * np.arange(n_days) / 365), }) @@ -43,6 +44,9 @@ def test_decoder_features_carry_no_flow_information(): # Same for soil moisture: SMAP is encoder-only, never in the decoder window. assert 'soil_moisture' not in windowing.DECODER_FEATURES assert 'soil_moisture' in windowing.ENCODER_FEATURES + # The sm_observed indicator is also encoder-only. + assert 'sm_observed' in windowing.ENCODER_FEATURES + assert 'sm_observed' not in windowing.DECODER_FEATURES def test_build_windows_shapes_are_correct(): From 69b271821ab6116c1a4e9d1189370d6113ec1652 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 18 May 2026 02:15:37 +0000 Subject: [PATCH 3/3] Phase 5: USDM drought + USBR RISE reservoirs (full); CBRFC + S2F (stubs) Adds four new data sources -- two wired as encoder features, two structured as baseline-comparison hooks for train.py. == New encoder features == USDM drought (data/get_drought.py) - Clean public USDM data service per HUC8; weekly snapshots of percent area in each drought category (D0..D4) collapsed to a single intensity index (D0*1 + D1*2 + ... + D4*5, range 0..500). - Weekly snapshots forward-filled to daily in combine_data with a 14-day ffill limit (USDM cadence is ~weekly so 14 days of headroom is safe). 
- 0 means "no drought", which IS the legitimate default for missing rows. USBR RISE reservoirs (data/get_reservoir.py) - REST fetch of daily storage + release for any USBR-managed reservoir, via the public RISE JSON:API. Per-page pagination is followed via the JSON:API `links.next` cursor. - Station -> reservoir(s) mapping lives in .github/reservoir_mapping.txt (commented template shipped; unmapped stations are treated as unregulated and the two reservoir columns + reservoir_observed indicator stay at 0). - For sites with multiple upstream reservoirs, storage and release are summed (total water held back / total outflow). - reservoir_observed (0/1) tells the model when the storage / release columns are real data vs the unregulated default, the same way sm_observed flags real-vs-imputed soil moisture. == Baseline comparison hooks (stubs) == NOAA CBRFC (data/get_cbrfc.py) + USBR S2F (data/get_s2f.py) - Define the train.py integration point (`baseline_predictions(test_samples) -> Optional[np.ndarray]`) and the per-sample fetch API. - Both return None until the historical archive is wired in; train.py treats None as "skip this baseline" and the persistence comparison is unaffected. - The module docstrings spell out the data-access gap and the obvious follow-up paths (AHPS tarball archive for CBRFC; per-basin CSV/PDF scrape for S2F, plus the seasonal->daily disaggregation it needs to be comparable to our 14-day horizon). == Spine wiring == combine_data - merge_dataframes accepts drought_data + reservoir_data; per-source interpolation limits (MAX_DROUGHT_GAP_DAYS=14, MAX_RESERVOIR_GAP_DAYS=14) mirror the SWE / SMAP patterns. - fetch_and_process_data returns a single dict instead of a growing tuple; each source has its own try/except so a single fetch failure can't take down the rest of the spine. - Per-source ablation env vars: OPENFLOW_DISABLE_DROUGHT, OPENFLOW_DISABLE_RESERVOIR (joining OPENFLOW_DISABLE_SMAP). 
Surfaced as workflow_dispatch inputs in ml_training.yml so you can run an ablation from the GitHub UI without touching yaml. normalize_data + windowing - OPTIONAL_NUMERIC now includes drought_index + reservoir_{storage,release}. - INDICATOR_COLUMNS adds reservoir_observed (kept binary, NOT z-scored). - ENCODER_FEATURES grows from 9 to 13 (the new auxiliaries are all history-window features; the decoder window stays clean of them since none has a skillful 14-day forecast). train.py - Reports CBRFC + S2F per-horizon MAE alongside the persistence comparison whenever those modules return a non-None prediction tensor. With the current stubs, they're silently skipped; the integration is ready to activate the moment fetch() returns real data. Tests: 130 passed locally. New test files for drought (8 cases), reservoir (7 cases), and external baselines (7 cases). Existing combine_data, normalize_data, windowing tests updated for the new columns. --- .github/reservoir_mapping.txt | 33 +++++ .github/workflows/ml_training.yml | 16 +++ openFlowML/combine_data.py | 167 +++++++++++++++++----- openFlowML/data/get_cbrfc.py | 216 +++++++++++++++++++++++++++++ openFlowML/data/get_drought.py | 180 ++++++++++++++++++++++++ openFlowML/data/get_reservoir.py | 222 ++++++++++++++++++++++++++++++ openFlowML/data/get_s2f.py | 97 +++++++++++++ openFlowML/normalize_data.py | 35 ++--- openFlowML/train.py | 39 +++++- openFlowML/windowing.py | 22 +-- tests/test_combine_data.py | 67 +++++++++ tests/test_drought.py | 114 +++++++++++++++ tests/test_external_baselines.py | 64 +++++++++ tests/test_normalize_data.py | 28 ++++ tests/test_reservoir.py | 133 ++++++++++++++++++ tests/test_windowing.py | 13 ++ 16 files changed, 1381 insertions(+), 65 deletions(-) create mode 100644 .github/reservoir_mapping.txt create mode 100644 openFlowML/data/get_cbrfc.py create mode 100644 openFlowML/data/get_drought.py create mode 100644 openFlowML/data/get_reservoir.py create mode 100644 openFlowML/data/get_s2f.py 
create mode 100644 tests/test_drought.py create mode 100644 tests/test_external_baselines.py create mode 100644 tests/test_reservoir.py diff --git a/.github/reservoir_mapping.txt b/.github/reservoir_mapping.txt new file mode 100644 index 0000000..2efd5da --- /dev/null +++ b/.github/reservoir_mapping.txt @@ -0,0 +1,33 @@ +# Station -> USBR RISE catalog-item-id mapping for reservoir features. +# +# Format (comma-separated, one record per line): +# <site_id>, <reservoir_name>, <storage_itemId>, <release_itemId> +# +# Where: +# - site_id matches an entry in .github/site_ids.txt (e.g. USGS:09163500) +# - reservoir_name is human-readable, used only for logging +# - storage_itemId is a USBR RISE catalog-item id for a storage timeseries +# (acre-feet, ideally daily) +# - release_itemId is a USBR RISE catalog-item id for an outflow / release +# timeseries (cfs, ideally daily). Use an empty trailing comma if the +# reservoir has no release series available. +# +# A station with no entries here is treated as unregulated -- both reservoir +# features default to 0 and the reservoir_observed indicator stays at 0, +# which is the correct semantic for an unregulated headwater gauge. +# +# To find catalog-item ids, browse https://data.usbr.gov/rise/ or query +# https://data.usbr.gov/rise/api/catalog-item?query=<search term> +# and look for the daily storage / release records you want. +# +# Multiple reservoirs may map to one station; storage values are summed +# (total water held back) and release values are summed (total outflow) at +# read time. +# +# Lines beginning with # are comments and ignored. Empty lines are ignored.
+# +# Example (commented out until you've verified the right itemIds for your +# stations -- the IDs below are illustrative placeholders): +# +# USGS:09163500, Lake Powell, 6126, 6127 +# USGS:09114500, Blue Mesa Reservoir, 6135, 6136 diff --git a/.github/workflows/ml_training.yml b/.github/workflows/ml_training.yml index 05fd329..2675d15 100644 --- a/.github/workflows/ml_training.yml +++ b/.github/workflows/ml_training.yml @@ -4,6 +4,19 @@ on: schedule: - cron: '0 0 * * 1' # Run weekly at midnight on Monday workflow_dispatch: + inputs: + disable_smap: + description: 'Skip SMAP soil moisture (ablation)' + type: boolean + default: false + disable_drought: + description: 'Skip USDM drought (ablation)' + type: boolean + default: false + disable_reservoir: + description: 'Skip USBR reservoir (ablation)' + type: boolean + default: false jobs: train: @@ -34,6 +47,9 @@ jobs: env: EARTHDATA_USERNAME: ${{ secrets.EARTHDATA_USERNAME }} EARTHDATA_PASSWORD: ${{ secrets.EARTHDATA_PASSWORD }} + OPENFLOW_DISABLE_SMAP: ${{ inputs.disable_smap && '1' || '' }} + OPENFLOW_DISABLE_DROUGHT: ${{ inputs.disable_drought && '1' || '' }} + OPENFLOW_DISABLE_RESERVOIR: ${{ inputs.disable_reservoir && '1' || '' }} run: python openFlowML/train.py # The model is not usable for inference without the scaler parameters diff --git a/openFlowML/combine_data.py b/openFlowML/combine_data.py index 647679a..cdeb88f 100644 --- a/openFlowML/combine_data.py +++ b/openFlowML/combine_data.py @@ -45,6 +45,13 @@ # vegetation / frozen ground. Soil moisture itself is slow-varying so the same # generous interior interpolation limit as SWE is appropriate. MAX_SM_GAP_DAYS = 30 +# USDM is published weekly and is forward-filled rather than interpolated; the +# limit below leaves headroom for a missed weekly snapshot. Drought intensity is even slower-varying than SWE.
+MAX_DROUGHT_GAP_DAYS = 14 +# Reservoir storage / release are reported daily by USBR; gaps are short +# (occasional missing days), so the regular MAX_GAP_DAYS interpolation limit +# is enough. Reservoir state is also slow-varying. +MAX_RESERVOIR_GAP_DAYS = 14 def _to_daily_series(df, value_columns, daily_index): @@ -61,16 +68,19 @@ def _to_daily_series(df, value_columns, daily_index): def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, - swe_data=None, sm_data=None, huc8=None): + swe_data=None, sm_data=None, + drought_data=None, reservoir_data=None, + huc8=None): """ - Merge a site's NOAA temperature, flow, SWE and soil-moisture data onto one - regular daily index. + Merge a site's flow, NOAA temperature, SWE, SMAP soil moisture, USDM + drought, and USBR reservoir series onto one regular daily index. Short interior flow/temp gaps (<= MAX_GAP_DAYS) are interpolated time-aware; rows still missing a core value afterwards are dropped (no pooled-mean - fill). SWE and soil_moisture are interpolated with longer limits (slow- - varying) and any remaining missing values default to 0 -- they don't drop - the row. + fill). All Phase 4/5 auxiliary features (SWE, soil_moisture, drought_index, + reservoir_storage, reservoir_release) are slow-varying, get longer + interpolation limits, and missing values are imputed per the column- + specific semantics rather than dropping the row. """ if 'Date' not in noaa_data or 'Date' not in flow_data: raise ValueError("'Date' column missing in one of the dataframes") @@ -117,6 +127,37 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, else: combined['soil_moisture'] = float('nan') + # USDM drought intensity: weekly snapshots, forward-filled. ffill is + # the right move for a step-function weekly series (interpolation + # would smear discrete category changes into ramps). 
+ if drought_data is not None and not drought_data.empty: + d_daily = _to_daily_series(drought_data, ['drought_index'], daily_index) + d_daily['drought_index'] = pd.to_numeric(d_daily['drought_index'], errors='coerce') + d_daily['drought_index'] = d_daily['drought_index'].ffill(limit=MAX_DROUGHT_GAP_DAYS) + combined['drought_index'] = d_daily['drought_index'] + else: + combined['drought_index'] = float('nan') + + # USBR reservoir storage + release: daily series, slow-varying. + # Interpolate short interior gaps. Empty df = unregulated station + # (no mapping entry); the column stays NaN and is handled below. + if reservoir_data is not None and not reservoir_data.empty: + r_daily = _to_daily_series( + reservoir_data, + ['reservoir_storage', 'reservoir_release'], + daily_index, + ) + for col in ('reservoir_storage', 'reservoir_release'): + r_daily[col] = pd.to_numeric(r_daily[col], errors='coerce') + r_daily = r_daily.interpolate(method='time', + limit=MAX_RESERVOIR_GAP_DAYS, + limit_area='inside') + combined['reservoir_storage'] = r_daily['reservoir_storage'] + combined['reservoir_release'] = r_daily['reservoir_release'] + else: + combined['reservoir_storage'] = float('nan') + combined['reservoir_release'] = float('nan') + # Drop rows still missing any core flow/temp value. SWE / soil_moisture # are NOT in CORE_COLUMNS, so a missing one never drops a row. before = len(combined) @@ -138,6 +179,25 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, if pd.isna(median): median = 0.0 combined['soil_moisture'] = sm_series.fillna(median) + + # Drought: 0 means "no drought anywhere in the HUC", which IS a + # legitimate default for a missing weekly value. Fill remaining gaps + # with 0 -- the model can treat absence as "we don't have a drought + # signal here" (which behaves equivalently to no drought). 
+ combined['drought_index'] = combined['drought_index'].fillna(0.0) + + # Reservoir: 0 storage / 0 release reads as "tiny / empty reservoir", + # which IS misleading -- the right semantic for an unregulated gauge + # is "no upstream reservoir at all". A reservoir_observed indicator + # (1 = mapped + retrieval succeeded for this row, 0 = unmapped / no + # data) tells the model when the storage / release columns are real. + # The numeric columns then default to 0 only as a placeholder; the + # indicator carries the truth. + res_series = combined['reservoir_storage'] + combined['reservoir_observed'] = res_series.notna().astype('int64') + combined['reservoir_storage'] = res_series.ffill().bfill().fillna(0.0) + rel_series = combined['reservoir_release'] + combined['reservoir_release'] = rel_series.ffill().bfill().fillna(0.0) logger.info( "Site %s: %d/%d daily rows usable after gap handling", site_id, len(combined), before, @@ -155,15 +215,25 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, return pd.DataFrame() +def _disabled(env_var): + """Truthy-toggle env-var check, used for the per-source ablation levers.""" + return os.getenv(env_var, '').strip() in ('1', 'true', 'True') + + def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): """ - Resolve a site's coordinates and fetch its NOAA temperature, HUC SWE and - SMAP soil-moisture series, and its HUC8 basin id (used as a Phase 3 basin - embedding key). - - Returns (noaa_data, swe_data, sm_data, huc8). Any of the dataframes may be - empty (SWE and SM degrade gracefully; missing NOAA causes the caller to - skip the site). huc8 may be None when the lookup fails. + Resolve a site's coordinates and fetch all its auxiliary timeseries: + NOAA temperature, NRCS SWE, SMAP soil moisture, USDM drought, USBR + reservoir storage + release, and the enclosing HUC8 basin id. 
+ + Returns a dict with keys: + noaa_data, swe_data, sm_data, drought_data, reservoir_data, huc8 + Any data frame may be empty (auxiliary features degrade gracefully; only + missing NOAA causes the caller to skip the site). huc8 may be None. + + Per-source env vars short-circuit the corresponding fetch to empty for + ablation runs: + OPENFLOW_DISABLE_SMAP, OPENFLOW_DISABLE_DROUGHT, OPENFLOW_DISABLE_RESERVOIR """ if prefix == "USGS": coords_dict = get_usgs_coordinates(site_id) @@ -174,7 +244,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if not coords_dict: logger.error(f"Could not resolve coordinates for {prefix}:{site_id}. Skipping...") - return None, None, None, None + return None latitude = float(coords_dict['latitude']) longitude = float(coords_dict['longitude']) @@ -186,7 +256,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if noaa_data is None or noaa_data.empty: logger.warning(f"No NOAA data available for site ID {site_id}. Skipping...") - return None, None, None, None + return None noaa_data = noaa_data.copy() noaa_data['USGS_site_ID'] = site_id @@ -202,26 +272,16 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if not huc8: logger.warning("No HUC8 resolved for %s -- basin embedding will fall back", site_id) - # SWE is a history-window feature; degrade gracefully if the AWDB fetch - # fails -- the row stays, SWE defaults to 0 in merge_dataframes. + # SWE: graceful failure -> empty. try: swe_data = get_swe.get_swe(latitude, longitude, start_date, end_date) except Exception as e: logger.warning("SWE fetch failed for %s: %s", site_id, e) swe_data = pd.DataFrame(columns=['Date', 'SWE']) - # SMAP soil moisture: also a history-window feature, fetched per HUC8. 
- # Lazy import keeps the heavy earthaccess/h5py stack out of the core - # training-env import path; if it's not installed or any step fails - # (auth, search, download, extraction), sm_data ends up empty and - # merge_dataframes treats it as "no data" (defaults to 0, doesn't drop - # the row), mirroring the SWE handling. - # - # OPENFLOW_DISABLE_SMAP=1 short-circuits to empty -- this is the lever - # for the ablation run (train without SMAP and compare on the held-out - # test set). + # SMAP soil moisture: lazy import (heavy deps), env-var lever, graceful. sm_data = pd.DataFrame(columns=['Date', 'soil_moisture']) - if os.getenv('OPENFLOW_DISABLE_SMAP', '').strip() in ('1', 'true', 'True'): + if _disabled('OPENFLOW_DISABLE_SMAP'): logger.info("OPENFLOW_DISABLE_SMAP set -- skipping SMAP for %s", site_id) else: try: @@ -233,9 +293,39 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): except Exception as e: logger.warning("SMAP fetch failed for %s: %s", site_id, e) - # Gap handling and numeric coercion happen in merge_dataframes (per station, - # on the regular daily index) -- not here, and never via a pooled mean. - return noaa_data, swe_data, sm_data, huc8 + # USDM drought: light deps, env-var lever, graceful. + drought_data = pd.DataFrame(columns=['Date', 'drought_index']) + if _disabled('OPENFLOW_DISABLE_DROUGHT'): + logger.info("OPENFLOW_DISABLE_DROUGHT set -- skipping USDM for %s", site_id) + else: + try: + from data import get_drought + drought_data = get_drought.get_drought(latitude, longitude, start_date, end_date) + except Exception as e: + logger.warning("USDM drought fetch failed for %s: %s", site_id, e) + + # USBR reservoir: needs a station->reservoir mapping; empty when no entry. 
+ reservoir_data = pd.DataFrame( + columns=['Date', 'reservoir_storage', 'reservoir_release']) + if _disabled('OPENFLOW_DISABLE_RESERVOIR'): + logger.info("OPENFLOW_DISABLE_RESERVOIR set -- skipping RISE for %s", + site_id) + else: + try: + from data import get_reservoir + reservoir_data = get_reservoir.get_reservoir( + f"{prefix}:{site_id}", start_date, end_date) + except Exception as e: + logger.warning("USBR reservoir fetch failed for %s: %s", site_id, e) + + return { + 'noaa_data': noaa_data, + 'swe_data': swe_data, + 'sm_data': sm_data, + 'drought_data': drought_data, + 'reservoir_data': reservoir_data, + 'huc8': huc8, + } def get_site_ids(filename=None): @@ -291,15 +381,22 @@ def main(training_num_years=7): logger.warning(f"Unrecognized prefix for site ID {site_id}. Skipping...") continue - noaa_dataframe, swe_dataframe, sm_dataframe, huc8 = fetch_and_process_data( + fetched = fetch_and_process_data( prefix, id, start_date, end_date, flow_dataframe) - if noaa_dataframe is None or noaa_dataframe.empty or flow_dataframe.empty: + if (fetched is None + or fetched['noaa_data'] is None + or fetched['noaa_data'].empty + or flow_dataframe.empty): logger.warning(f"No usable data for site ID {site_id}. Skipping...") continue merged = merge_dataframes( - noaa_dataframe, flow_dataframe, site_id, start_date, end_date, - swe_data=swe_dataframe, sm_data=sm_dataframe, huc8=huc8) + fetched['noaa_data'], flow_dataframe, site_id, start_date, end_date, + swe_data=fetched['swe_data'], + sm_data=fetched['sm_data'], + drought_data=fetched['drought_data'], + reservoir_data=fetched['reservoir_data'], + huc8=fetched['huc8']) if merged.empty: logger.warning(f"No usable merged data for site ID {site_id}. 
Skipping...") continue diff --git a/openFlowML/data/get_cbrfc.py b/openFlowML/data/get_cbrfc.py new file mode 100644 index 0000000..4c7c80d --- /dev/null +++ b/openFlowML/data/get_cbrfc.py @@ -0,0 +1,216 @@ +""" +NOAA Colorado Basin River Forecast Center (CBRFC) streamflow forecast baseline. + +A second comparison baseline alongside persistence. The model needs to beat +CBRFC's official short-range forecast to claim it's adding value over what an +operational forecaster has access to. + +Public API: + fetch(site_id, anchor_date, horizon_days) -> DataFrame[Date, cbrfc_flow] + Forecast issued on or before `anchor_date`, predictions for the next + `horizon_days`. Empty DataFrame when no usable forecast exists. + + baseline_predictions(test_samples) -> Optional[np.ndarray] + Stack predictions for the entire test set into a (N, horizon, 2) + tensor (same shape as persistence_baseline output in train.py). Returns + None when CBRFC coverage is sparse enough that the comparison would + be meaningless. + +IMPLEMENTATION STATUS: + The CBRFC publishes operational deterministic forecasts via the + Advanced Hydrologic Prediction Service (AHPS) at water.weather.gov, and + Ensemble Streamflow Prediction (ESP) products through their own portal at + cbrfc.noaa.gov. Both have *current-day* access; the **historical archive** + needed for backtesting against our test-set anchor dates is the gap: + + - AHPS does not expose historical issuance via its REST API; the + archived forecasts live in tarballs at + https://water.weather.gov/ahps/download.php + - CBRFC's ESP archive is accessible per-basin via their THREDDS server + but requires a per-issuance lookup that is meaningfully more involved + than this stub. + + Filling either path in (the obvious follow-up commit) lets the baseline + actually evaluate against the test set. Until then, fetch() returns empty + for any anchor_date != today, and baseline_predictions() returns None so + train.py skips the comparison cleanly. 
+""" + +import logging +from datetime import date, datetime, timedelta +from typing import List, Optional + +import numpy as np +import pandas as pd + +from data.utils import data_utils + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# AHPS public site forecast page (one page per gauge by NWS LID). +AHPS_FORECAST_URL = "https://water.weather.gov/ahps2/hydrograph_to_xml.php" +# Default 14-day horizon matches windowing.DECODER_DAYS. +DEFAULT_HORIZON_DAYS = 14 + + +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'cbrfc_flow']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], '%Y-%m-%d').date() + + +def _ahps_lid_for_site(site_id: str) -> Optional[str]: + """ + Map a USGS / DWR site id to an NWS / AHPS LID (5-letter location id). + The mapping isn't algorithmic; it's a curated lookup. Returns None if no + LID is known for the site, in which case fetch() returns empty. + + Populate this when wiring CBRFC for a specific gauge: + {'USGS:09163500': 'CRSC2', ...} + """ + return _AHPS_LID_TABLE.get(site_id) + + +_AHPS_LID_TABLE: dict = { + # site_id -> NWS LID. Empty by default; fill in per gauge as needed. +} + + +def fetch_current(site_id: str, horizon_days: int = DEFAULT_HORIZON_DAYS) -> pd.DataFrame: + """ + Pull the most recent AHPS forecast for `site_id`. Returns DataFrame + [Date, cbrfc_flow] in cfs, one row per forecast day, or an empty frame + when the site has no LID mapping or AHPS returned no usable forecast. + + The AHPS public forecast page only exposes the current forecast issuance, + so this is the "what does CBRFC think tomorrow's flow is right now" + helper. Historical issuances need the archive (see module docstring). 
+ """ + lid = _ahps_lid_for_site(site_id) + if not lid: + logger.info("No AHPS LID mapped for %s; CBRFC fetch skipped", site_id) + return _empty() + params = {'gage': lid, 'output': 'xml'} + response = data_utils.request_with_retry(AHPS_FORECAST_URL, params=params) + if response is None: + return _empty() + rows = _parse_ahps_forecast_xml(response.text, horizon_days) + if not rows: + return _empty() + return pd.DataFrame(rows, columns=['Date', 'cbrfc_flow']) + + +def _parse_ahps_forecast_xml(text: str, horizon_days: int) -> List[tuple]: + """ + Extract daily forecast (date, cfs) rows from the AHPS hydrograph XML. + + AHPS XML wraps a `<forecast>` block of `<datum>` elements, each with a + `<valid>` ISO timestamp and a `<primary>` numeric value (typically flow + in kcfs or stage in ft -- the gauge metadata tells you which). Lifts the + primary value, collapses sub-daily issuances to daily mean, caps at + horizon_days days from the issuance date. + """ + try: + from xml.etree import ElementTree as ET + except ImportError: + return [] + try: + root = ET.fromstring(text) + except ET.ParseError: + return [] + rows: list = [] + forecast = root.find('forecast') + if forecast is None: + return rows + for datum in forecast.findall('datum'): + valid = datum.findtext('valid') + primary = datum.findtext('primary') + if not valid or primary is None: + continue + try: + d = datetime.fromisoformat(valid.replace('Z', '+00:00')).date() + v = float(primary) + except (ValueError, TypeError): + continue + rows.append((d.strftime('%Y-%m-%d'), v)) + if not rows: + return rows + # Collapse multiple values per day to daily mean.
+ df = pd.DataFrame(rows, columns=['Date', 'cbrfc_flow']) + daily = df.groupby('Date', as_index=False)['cbrfc_flow'].mean() + daily = daily.sort_values('Date').head(horizon_days) + return list(daily.itertuples(index=False, name=None)) + + +def fetch(site_id: str, anchor_date, + horizon_days: int = DEFAULT_HORIZON_DAYS) -> pd.DataFrame: + """ + CBRFC forecast for `site_id` issued on or before `anchor_date`, for the + `horizon_days` days following the anchor. + + For anchor_date == today, this is equivalent to fetch_current. For any + historical anchor_date, the AHPS API does not expose the issuance; this + returns empty until the historical archive integration lands (see module + docstring). + """ + anchor = _to_date(anchor_date) + if anchor >= date.today(): + return fetch_current(site_id, horizon_days=horizon_days) + logger.debug("CBRFC historical forecast for %s @ %s requires archive integration", + site_id, anchor) + return _empty() + + +def baseline_predictions(test_samples) -> Optional[np.ndarray]: + """ + Per-sample CBRFC forecast aligned with `windowing.WindowedSample` test + items. Returns an (N, horizon, target_features) array on the SCALED flow + target space (so it lines up with model_pred), or None when fewer than + one usable forecast was found and the comparison would be vacuous. + + This is the integration point train.py calls; when filled in, it produces + a third row in the per-horizon MAE table alongside model and persistence. + The shape mirrors `_persistence_pred_for_samples` in train.py. + """ + if not test_samples: + return None + # Until the historical CBRFC archive is wired in there's nothing to + # backtest against; signal "no comparison" cleanly to the caller. + found = 0 + horizon = test_samples[0].target_Y.shape[0] + rows = [] + for sample in test_samples: + # Forecast was issued at the end of the encoder window -- that's + # `anchor_date + encoder_days` for a sample whose anchor_date is the + # first encoder day. 
WindowedSample doesn't store the forecast-issue + # date explicitly; reconstruct as anchor + (encoder_days - 1). + forecast_issue = (sample.anchor_date + + pd.Timedelta(days=sample.encoder_X.shape[0] - 1)) + forecast = fetch(sample.site_id, forecast_issue, horizon_days=horizon) + if forecast.empty: + rows.append(np.full((horizon, sample.target_Y.shape[1]), np.nan, + dtype='float32')) + continue + found += 1 + # The fetched cbrfc_flow is on the raw cfs scale; train.py knows how + # to invert the scalers if a non-scaled prediction set is provided. + # Broadcast the single CBRFC value across both Min Flow + Max Flow + # target columns until the upstream product distinguishes them. + values = forecast['cbrfc_flow'].astype('float32').to_numpy() + if len(values) < horizon: + padded = np.full(horizon, np.nan, dtype='float32') + padded[:len(values)] = values + values = padded + rows.append(np.tile(values[:horizon, None], (1, sample.target_Y.shape[1]))) + + if found == 0: + return None + return np.stack(rows).astype('float32') diff --git a/openFlowML/data/get_drought.py b/openFlowML/data/get_drought.py new file mode 100644 index 0000000..d0925ba --- /dev/null +++ b/openFlowML/data/get_drought.py @@ -0,0 +1,180 @@ +""" +US Drought Monitor (USDM) drought intensity timeseries by HUC8. + +Wraps the public USDM data service at usdmdataservices.unl.edu. USDM publishes +weekly snapshots of percent area in each drought category (None / D0 / D1 / +D2 / D3 / D4); we collapse those to a single intensity index per week and +forward-fill to daily. + +Public API: + main(lat, lon, start_date, end_date) -> DataFrame[Date, drought_index] + +drought_index = sum_c (D{c}_percent * weight_c) where weight is 1..5 for +D0..D4. Range 0 (no drought anywhere in HUC) to 500 (entire HUC in D4 +exceptional drought). 0 is the legitimate default for missing rows. 
import argparse
import logging
from datetime import date, datetime
from typing import Optional

import pandas as pd

from data.utils import data_utils
from data import get_swe

logger = logging.getLogger(__name__)
if not logging.getLogger().hasHandlers():
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

USDM_URL = "https://usdmdataservices.unl.edu/api/HUCStatistics/GetWeeklyHUCStatistics"

# Ordinal weights for the intensity index. None has weight 0 (it cancels).
_CATEGORY_WEIGHTS = {'D0': 1, 'D1': 2, 'D2': 3, 'D3': 4, 'D4': 5}

# Keys under which a weekly record may carry its snapshot date, in lookup order.
_DATE_KEYS = ('MapDate', 'ValidStart', 'validStart', 'mapDate')

# (strptime format, number of leading characters that hold the date part).
# Slicing per format lets '20240102T000000' parse as well as '2024-01-02T00:00:00'.
_DATE_FORMATS = (('%Y-%m-%d', 10), ('%Y%m%d', 8))

# Snapshots are weekly, so the one in force on start_date may pre-date it.
# Querying this many extra days gives the forward-fill a value to start from.
_LOOKBACK_DAYS = 7


def _empty() -> pd.DataFrame:
    """Return the zero-row frame used to signal "no drought data"."""
    return pd.DataFrame(columns=['Date', 'drought_index'])


def _to_date(d) -> date:
    """
    Coerce a datetime, date, or 'YYYY-MM-DD...' string to a `date`.

    Raises ValueError when the first ten characters of str(d) are not an
    ISO calendar date.
    """
    if isinstance(d, datetime):
        return d.date()
    if isinstance(d, date):
        return d
    return datetime.strptime(str(d)[:10], '%Y-%m-%d').date()


def _format_mdy(d) -> str:
    """USDM API expects M/D/YYYY (no zero-padding)."""
    d = _to_date(d)
    return f"{d.month}/{d.day}/{d.year}"


def get_drought_weekly(huc_id, start_date, end_date, huc_level=8):
    """
    Fetch raw weekly USDM percent-area-by-category records for a HUC.

    Returns a list of dicts (one per weekly snapshot in the payload), or []
    when huc_id is falsy, the request yields no response, the body is not
    JSON, or the payload is not a list. Entries that are not JSON objects
    are dropped so callers can rely on `.get`.

    Raises ValueError if start_date / end_date cannot be parsed.
    """
    if not huc_id:
        return []
    params = {
        'aoi': str(huc_id),
        'hucLevel': str(huc_level),
        'startdate': _format_mdy(start_date),
        'enddate': _format_mdy(end_date),
        'statisticsType': '2',  # percent area by drought category
    }
    headers = {'Accept': 'application/json'}
    response = data_utils.request_with_retry(USDM_URL, params=params, headers=headers)
    if response is None:
        return []
    try:
        payload = response.json()
    except ValueError:
        return []
    if not isinstance(payload, list):
        return []
    return [record for record in payload if isinstance(record, dict)]


def _record_to_index(record) -> float:
    """
    Collapse a USDM weekly record (percent area per category) into a single
    weighted-intensity index. Missing or unparsable category fields count
    as 0; a record that is not a dict yields 0.0.
    """
    if not isinstance(record, dict):
        return 0.0
    total = 0.0
    for cat, weight in _CATEGORY_WEIGHTS.items():
        try:
            total += float(record.get(cat, 0) or 0) * weight
        except (TypeError, ValueError):
            continue
    return total


def _parse_record_date(record) -> Optional[date]:
    """
    Extract the snapshot date from a USDM record.

    USDM payloads expose the date under a few possible keys and in either
    dashed or compact form; the first key/format combination that parses
    wins. Returns None when nothing parses or the record is not a dict.
    """
    if not isinstance(record, dict):
        return None
    for key in _DATE_KEYS:
        raw = record.get(key)
        if not raw:
            continue
        text = str(raw).strip()
        for fmt, width in _DATE_FORMATS:
            try:
                return datetime.strptime(text[:width], fmt).date()
            except ValueError:
                continue
    return None


def get_drought(lat, lon, start_date, end_date, huc_level=8) -> pd.DataFrame:
    """
    End-to-end: lat/lon -> HUC -> USDM weekly intensity -> daily DataFrame.

    Weekly records are forward-filled across the daily index from start_date
    to end_date. The USDM query is widened by _LOOKBACK_DAYS before
    start_date so the snapshot already in force on the first day is
    available to the fill; days with no earlier snapshot remain NaN.

    Returns DataFrame[Date ('YYYY-MM-DD' strings), drought_index], or an
    empty frame when the HUC lookup fails, no records come back, or no
    record carries a parsable date.
    """
    try:
        huc_id = get_swe.get_huc_id(lat, lon, level=huc_level)
    except Exception as e:
        logger.warning("HUC%d lookup failed for (%s, %s): %s", huc_level, lat, lon, e)
        return _empty()
    if not huc_id:
        logger.warning("Could not resolve HUC%d for (%s, %s)", huc_level, lat, lon)
        return _empty()

    first_day = pd.Timestamp(_to_date(start_date))
    last_day = pd.Timestamp(_to_date(end_date))
    query_start = first_day - pd.Timedelta(days=_LOOKBACK_DAYS)

    records = get_drought_weekly(huc_id, query_start, end_date, huc_level=huc_level)
    if not records:
        logger.warning("No USDM records for HUC %s in [%s, %s]",
                       huc_id, start_date, end_date)
        return _empty()

    rows = []
    for record in records:
        snapshot = _parse_record_date(record)
        if snapshot is None:
            continue
        rows.append((snapshot.strftime('%Y-%m-%d'), _record_to_index(record)))

    if not rows:
        return _empty()

    # reindex(method='ffill') needs a unique, sorted index.
    weekly = (pd.DataFrame(rows, columns=['Date', 'drought_index'])
              .drop_duplicates(subset=['Date'])
              .sort_values('Date'))
    weekly['Date'] = pd.to_datetime(weekly['Date'])
    weekly = weekly.set_index('Date')

    # Forward-fill weekly snapshots over the daily index.
    daily_index = pd.date_range(first_day, last_day, freq='D', name='Date')
    daily = weekly.reindex(daily_index, method='ffill')
    daily = daily.reset_index()
    daily['Date'] = daily['Date'].dt.strftime('%Y-%m-%d')
    return daily[['Date', 'drought_index']]


def main(lat, lon, start_date, end_date, huc_level=8) -> pd.DataFrame:
    """Fetch the daily drought index, preview it, and return it."""
    df = get_drought(lat, lon, start_date, end_date, huc_level=huc_level)
    data_utils.preview_data(df)
    return df


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Fetch USDM drought intensity timeseries by HUC.')
    parser.add_argument('--lat', type=float, required=True)
    parser.add_argument('--lon', type=float, required=True)
    parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD')
    parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD')
    parser.add_argument('--huc-level', type=int, default=8, choices=[2, 4, 6, 8, 10, 12])
    args = parser.parse_args()
    main(args.lat, args.lon, args.start_date, args.end_date, huc_level=args.huc_level)
import argparse
import logging
import os
from datetime import date, datetime
from typing import List, Optional, Tuple
from urllib.parse import urljoin

import pandas as pd

from data.utils import data_utils

logger = logging.getLogger(__name__)
if not logging.getLogger().hasHandlers():
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

RISE_RESULT_URL = "https://data.usbr.gov/rise/api/result"

# Hard cap on pagination requests per series to avoid runaway loops.
_MAX_PAGES = 200

# Reservoir mapping config lives alongside site_ids.txt in .github/.
_DEFAULT_MAPPING = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    '.github', 'reservoir_mapping.txt',
)


def _empty() -> pd.DataFrame:
    """Return the zero-row frame used to signal "no reservoir data"."""
    return pd.DataFrame(columns=['Date', 'reservoir_storage', 'reservoir_release'])


def _to_date(d) -> date:
    """
    Coerce a datetime, date, or 'YYYY-MM-DD...' string to a `date`.

    Raises ValueError when the first ten characters of str(d) are not an
    ISO calendar date.
    """
    if isinstance(d, datetime):
        return d.date()
    if isinstance(d, date):
        return d
    return datetime.strptime(str(d)[:10], '%Y-%m-%d').date()


def load_mapping(path: Optional[str] = None) -> dict:
    """
    Parse the reservoir-mapping config into
    {site_id: [(reservoir_name, storage_itemId, release_itemId), ...]}.

    Each non-blank, non-'#' line is `site_id, reservoir, storage_id[, release_id]`;
    an empty id slot becomes None. Lines with fewer than three fields are
    logged and skipped. Stations without entries are simply absent from the
    returned dict, which the caller treats as "no reservoir for this station".
    A missing or unreadable file yields whatever was parsed so far (possibly {}).
    """
    if path is None:
        path = _DEFAULT_MAPPING
    mapping: dict = {}
    if not os.path.exists(path):
        logger.info("No reservoir mapping file at %s -- all stations unregulated", path)
        return mapping
    try:
        # Explicit encoding: the locale default differs between dev and CI hosts.
        with open(path, 'r', encoding='utf-8') as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith('#'):
                    continue
                parts = [p.strip() for p in line.split(',')]
                if len(parts) < 3:
                    logger.warning("Skipping malformed reservoir mapping line: %s", raw.rstrip())
                    continue
                site_id = parts[0]
                reservoir = parts[1]
                storage = parts[2] or None
                release = parts[3] if len(parts) >= 4 and parts[3] else None
                mapping.setdefault(site_id, []).append((reservoir, storage, release))
    except (OSError, UnicodeError) as e:
        logger.warning("Could not read reservoir mapping %s: %s", path, e)
    return mapping


def _rows_from_payload(payload: dict) -> List[Tuple[str, float]]:
    """Extract (YYYY-MM-DD, value) pairs from one JSON:API result page."""
    rows: List[Tuple[str, float]] = []
    data = payload.get('data')
    if not isinstance(data, list):
        return rows
    for item in data:
        if not isinstance(item, dict):
            continue
        attrs = item.get('attributes') or {}
        if not isinstance(attrs, dict):
            continue
        ts = attrs.get('dateTime') or attrs.get('resultDateTime')
        val = attrs.get('result')
        if ts is None or val is None:
            continue
        try:
            rows.append((str(ts)[:10], float(val)))
        except (TypeError, ValueError):
            continue
    return rows


def _next_page_url(payload: dict, current_url: str) -> Optional[str]:
    """
    Resolve the page's `links.next` against the URL just requested.

    Accepts either a plain string or a JSON:API link object ({'href': ...}),
    and either an absolute URL or a server-relative path. Returns None when
    there is no next link or it points back at the current URL.
    """
    links = payload.get('links')
    nxt = links.get('next') if isinstance(links, dict) else None
    if isinstance(nxt, dict):
        nxt = nxt.get('href')
    if not isinstance(nxt, str) or not nxt:
        return None
    resolved = urljoin(current_url, nxt)
    if resolved == current_url:
        return None
    return resolved


def fetch_rise_series(item_id, start_date, end_date) -> pd.DataFrame:
    """
    Pull a single RISE catalog-item timeseries as a [Date, value] frame.

    The `/result` endpoint is walked page by page via `links.next` until the
    link is exhausted or _MAX_PAGES requests have been made. Multiple values
    on the same date are averaged. Returns an empty [Date, value] frame when
    item_id is falsy or nothing usable comes back; rows collected before a
    mid-walk failure are kept.

    Raises ValueError if start_date / end_date cannot be parsed.
    """
    if not item_id:
        return pd.DataFrame(columns=['Date', 'value'])

    start = _to_date(start_date)
    end = _to_date(end_date)
    params = {
        'itemId': str(item_id),
        'after': start.strftime('%Y-%m-%dT00:00:00Z'),
        'before': end.strftime('%Y-%m-%dT23:59:59Z'),
        'order': 'ASC',
        'itemsPerPage': '10000',
    }
    headers = {'Accept': 'application/vnd.api+json'}

    rows: List[Tuple[str, float]] = []
    url = RISE_RESULT_URL
    next_params = params
    for _ in range(_MAX_PAGES):
        response = data_utils.request_with_retry(url, params=next_params, headers=headers)
        if response is None:
            break
        try:
            payload = response.json()
        except ValueError:
            break
        if not isinstance(payload, dict):
            break
        rows.extend(_rows_from_payload(payload))
        next_url = _next_page_url(payload, url)
        if next_url is None:
            break
        # The next link carries its own query string, so params are dropped.
        url = next_url
        next_params = None

    if not rows:
        return pd.DataFrame(columns=['Date', 'value'])
    df = pd.DataFrame(rows, columns=['Date', 'value'])
    # Collapse any duplicate dates within the series (hourly -> daily).
    df = df.groupby('Date', as_index=False)['value'].mean()
    return df.sort_values('Date').reset_index(drop=True)


def _sum_by_date(frames, col) -> pd.DataFrame:
    """Stack per-reservoir [Date, col] frames and total `col` per date."""
    if not frames:
        return pd.DataFrame(columns=['Date', col])
    combined = pd.concat(frames, ignore_index=True)
    return combined.groupby('Date', as_index=False)[col].sum()


def get_reservoir(site_id, start_date, end_date,
                  mapping_path: Optional[str] = None) -> pd.DataFrame:
    """
    Resolve the reservoir(s) mapped to site_id, fetch storage + release from
    RISE, and return DataFrame[Date, reservoir_storage, reservoir_release].

    The frame has one row per date on which at least one series reported a
    value (it is not reindexed to a dense daily range); a column is NaN on
    dates where only the other series has data. If multiple reservoirs map
    to one site, storage and release are summed per date. Returns an empty
    frame when there is no mapping entry or neither series returned data.
    """
    mapping = load_mapping(mapping_path)
    entries = mapping.get(site_id)
    if not entries:
        return _empty()

    storage_frames = []
    release_frames = []
    for reservoir_name, storage_id, release_id in entries:
        if storage_id:
            ts = fetch_rise_series(storage_id, start_date, end_date)
            if not ts.empty:
                storage_frames.append(ts.rename(columns={'value': 'reservoir_storage'}))
        if release_id:
            tr = fetch_rise_series(release_id, start_date, end_date)
            if not tr.empty:
                release_frames.append(tr.rename(columns={'value': 'reservoir_release'}))
        logger.info("Site %s reservoir %s: storage=%s release=%s",
                    site_id, reservoir_name,
                    'yes' if storage_id else '-',
                    'yes' if release_id else '-')

    storage_total = _sum_by_date(storage_frames, 'reservoir_storage')
    release_total = _sum_by_date(release_frames, 'reservoir_release')

    if storage_total.empty and release_total.empty:
        return _empty()

    merged = (storage_total
              .merge(release_total, on='Date', how='outer')
              .sort_values('Date')
              .reset_index(drop=True))
    return merged[['Date', 'reservoir_storage', 'reservoir_release']]


def main(site_id, start_date, end_date,
         mapping_path: Optional[str] = None) -> pd.DataFrame:
    """Fetch the reservoir frame for a site, preview it, and return it."""
    df = get_reservoir(site_id, start_date, end_date, mapping_path=mapping_path)
    data_utils.preview_data(df)
    return df


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Fetch USBR RISE reservoir storage + release for a site.')
    parser.add_argument('--site-id', type=str, required=True, help='e.g. USGS:09163500')
    parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD')
    parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD')
    parser.add_argument('--mapping', type=str, default=None,
                        help='Override path to reservoir_mapping.txt')
    args = parser.parse_args()
    main(args.site_id, args.start_date, args.end_date, mapping_path=args.mapping)
# S2F products are published as month-by-month CSV / PDF files at
#   https://www.usbr.gov/uc/water/crsp/wsf/        (Upper Colorado WSF)
#   https://www.usbr.gov/pn/hydromet/forecast.html (Pacific Northwest)
# with no clean REST archive. Daily disaggregation also needs a per-site
# daily-share lookup that does not exist in the repo yet. Until both are
# wired in, fetch() returns empty and baseline_predictions() returns None
# so train.py skips this baseline.

import logging
from datetime import date, datetime
from typing import Optional

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)
if not logging.getLogger().hasHandlers():
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def _empty() -> pd.DataFrame:
    """Return the zero-row frame used to signal "no S2F forecast"."""
    return pd.DataFrame(columns=['Date', 's2f_volume_kaf'])


def _to_date(d) -> date:
    """
    Coerce a datetime, date, or 'YYYY-MM-DD...' string to a `date`.

    Raises ValueError when the first ten characters of str(d) are not an
    ISO calendar date.
    """
    if isinstance(d, datetime):
        return d.date()
    if isinstance(d, date):
        return d
    return datetime.strptime(str(d)[:10], '%Y-%m-%d').date()


def fetch(site_id: str, anchor_date) -> pd.DataFrame:
    """
    S2F seasonal-volume forecast for `site_id`, as of the month containing
    `anchor_date`. Intended to return a single-row DataFrame
    [Date, s2f_volume_kaf] when found, or empty otherwise.

    Currently stubbed: this always returns an empty frame until the
    per-basin scraper is built. The function signature is the integration
    point so combine_data / baselines can adopt the real fetch without
    changing callers.
    """
    # Log the anchor as given. Parsing it here would let a malformed date
    # raise from a stub whose contract is to always return empty.
    logger.debug("S2F fetch stubbed for %s @ %s -- archive integration pending",
                 site_id, anchor_date)
    return _empty()


def baseline_predictions(test_samples) -> Optional[np.ndarray]:
    """
    Per-sample S2F-derived daily prediction aligned with WindowedSample test
    items. Intended to return (N, horizon, target_features) on the raw flow
    scale once the archive is wired in; currently always returns None.

    Each sample must expose `site_id` and `anchor_date`. Until fetch()
    yields real data this is a no-op returning None so the caller can skip
    the baseline.
    """
    if not test_samples:
        return None
    have_forecast = any(
        not fetch(sample.site_id, sample.anchor_date).empty
        for sample in test_samples
    )
    if not have_forecast:
        return None
    # The disaggregation step (seasonal kAF -> daily cfs via per-site
    # climatological share) belongs here once fetch() returns real data.
    return None
sm_observed marks -# rows where soil_moisture is a real / short-gap-interpolated SMAP retrieval -# vs imputed via the median fallback in combine_data. -INDICATOR_COLUMNS = ['sm_observed'] +# distinct cases, not as samples from a centered normal). +# sm_observed -- 1 = real / short-gap-interpolated SMAP retrieval +# reservoir_observed -- 1 = station has a USBR mapping AND fetch succeeded +INDICATOR_COLUMNS = ['sm_observed', 'reservoir_observed'] NUMERIC_COLUMNS = CORE_REQUIRED + OPTIONAL_NUMERIC # Streamflow is log-normal -- log1p before z-scoring is standard hydrology # practice. Temperature/SWE stay linear. @@ -155,14 +155,15 @@ def normalize_data(data, artifacts_dir=None): for column in present_numeric + present_indicator: data[column] = pd.to_numeric(data[column], errors='coerce') - # SWE and soil_moisture are slowly varying and often legitimately near - # zero -- any remaining missing value here defaults to 0 instead of - # forcing the row out, so a station with no nearby SNOTEL / no SMAP - # retrieval that day still contributes. - if 'SWE' in data.columns: - data['SWE'] = data['SWE'].fillna(0.0) - if 'soil_moisture' in data.columns: - data['soil_moisture'] = data['soil_moisture'].fillna(0.0) + # All optional numerics are slow-varying auxiliaries; combine_data + # imputes them with smarter per-source logic. The safety net here + # defaults any remaining missing value to 0 instead of dropping the + # row, so a station with no nearby SNOTEL / no SMAP retrieval / + # outside USDM coverage / no upstream reservoir still contributes. + for column in ('SWE', 'soil_moisture', 'drought_index', + 'reservoir_storage', 'reservoir_release'): + if column in data.columns: + data[column] = data[column].fillna(0.0) # Indicators default to 0 ("not observed") when missing. 
for column in present_indicator: data[column] = data[column].fillna(0).astype('int64') diff --git a/openFlowML/train.py b/openFlowML/train.py index ba88367..69a265f 100644 --- a/openFlowML/train.py +++ b/openFlowML/train.py @@ -83,6 +83,8 @@ def main(): required = {'station_idx', 'basin_idx', 'site_id', 'Date', 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', 'soil_moisture', 'sm_observed', + 'drought_index', + 'reservoir_storage', 'reservoir_release', 'reservoir_observed', 'doy_sin', 'doy_cos'} missing = required - set(data.columns) if missing: @@ -144,17 +146,46 @@ def main(): verbose=2, ) - # 7. Evaluate against the persistence baseline on the held-out test set. + # 7. Evaluate against persistence + any external operational baselines + # on the held-out test set. External baselines (CBRFC, S2F) return + # None when their archive integration isn't wired in yet, in which + # case they're silently skipped. if test_inputs is not None and len(splits.test) > 0: model_pred = net.predict(test_inputs, verbose=0) model_mae_per_h = _summarize_per_horizon(test_targets, model_pred) persistence_pred = _persistence_pred_for_samples(splits.test) persistence_mae_per_h = _summarize_per_horizon(test_targets, persistence_pred) + + external_mae_per_h = {} + try: + from data import get_cbrfc + cbrfc_pred = get_cbrfc.baseline_predictions(splits.test) + if cbrfc_pred is not None: + external_mae_per_h['cbrfc'] = _summarize_per_horizon( + test_targets, cbrfc_pred) + except Exception as e: + logger.warning("CBRFC baseline skipped: %s", e) + try: + from data import get_s2f + s2f_pred = get_s2f.baseline_predictions(splits.test) + if s2f_pred is not None: + external_mae_per_h['s2f'] = _summarize_per_horizon( + test_targets, s2f_pred) + except Exception as e: + logger.warning("S2F baseline skipped: %s", e) + logger.info("Test MAE (scaled space) per horizon day:") - for k, (m, p) in enumerate(zip(model_mae_per_h, persistence_mae_per_h), start=1): + for k in range(len(model_mae_per_h)): + m = 
model_mae_per_h[k] + p = persistence_mae_per_h[k] verdict = "BEATS" if m < p else "LOSES_TO" - logger.info(" day %2d: model=%.4f persistence=%.4f (%s baseline)", - k, m, p, verdict) + extras = "".join( + f" {name}={vals[k]:.4f}" + for name, vals in external_mae_per_h.items()) + logger.info(" day %2d: model=%.4f persistence=%.4f (%s persistence)%s", + k + 1, m, p, verdict, extras) + for name in external_mae_per_h: + logger.info("External baseline available: %s", name) # 8. Persist the model + training config alongside the scaler/index JSON # that combine_data already wrote. These five artifacts together are diff --git a/openFlowML/windowing.py b/openFlowML/windowing.py index c3c44d5..57c6a9e 100644 --- a/openFlowML/windowing.py +++ b/openFlowML/windowing.py @@ -24,19 +24,23 @@ logger = logging.getLogger(__name__) -# Encoder window: everything we know up to the prediction time. Flow is here -# (these are observations); SWE -- current snowpack is a strong predictor of -# snowmelt-fed runoff in Colorado; soil_moisture (SMAP L3 enhanced) is the -# antecedent wetness state that gates how much new precipitation becomes runoff -# vs infiltrates; sm_observed is the 0/1 indicator that flags whether -# soil_moisture for that day was a real SMAP retrieval or imputed by -# combine_data's median fallback. +# Encoder window: everything we know up to the prediction time. +# flow (observations); SWE -- snowpack drives snowmelt-fed runoff in CO; +# soil_moisture (SMAP L3 enhanced) -- antecedent wetness gates infiltration +# vs runoff; sm_observed -- 1 = real / short-gap-interpolated SMAP, 0 = +# imputed via combine_data's median fallback; drought_index (USDM) -- HUC8 +# drought intensity 0..500; reservoir_storage / reservoir_release (USBR +# RISE) -- upstream regulation state for regulated rivers, defaults to 0 +# for unmapped (unregulated) stations; reservoir_observed -- 1 when the +# station actually has reservoir data, 0 when it's the unregulated default. 
ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', 'soil_moisture', 'sm_observed', + 'drought_index', + 'reservoir_storage', 'reservoir_release', 'reservoir_observed', 'doy_sin', 'doy_cos'] # Decoder window: ONLY features available at forecast time. No flow (that's -# what we're predicting), no SWE / no soil_moisture (neither has a skillful -# 14-day forecast available). +# what we're predicting), no SWE / no soil_moisture / no drought / no +# reservoir state (none of these have a skillful 14-day forecast available). DECODER_FEATURES = ['TMIN', 'TMAX', 'doy_sin', 'doy_cos'] # Target: log-z-scored flow during the decoder days. Both columns are already # log1p-transformed and z-scored by normalize_data, so MSE/Huber here behaves. diff --git a/tests/test_combine_data.py b/tests/test_combine_data.py index b47c825..45bcc78 100644 --- a/tests/test_combine_data.py +++ b/tests/test_combine_data.py @@ -190,6 +190,73 @@ def test_merge_defaults_soil_moisture_to_zero_only_when_no_observations(): assert (merged['sm_observed'] == 0).all() +def test_merge_attaches_drought_index_with_ffill(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # USDM weekly snapshots on Jan 4 and Jan 11. + drought = pd.DataFrame({ + 'Date': [dates[3], dates[10]], + 'drought_index': [100.0, 250.0], + }) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + drought_data=drought) + assert 'drought_index' in merged.columns + rows = merged.set_index(pd.to_datetime(merged['Date'])) + # Pre-first-snapshot days have no value to ffill from -> 0 (the default). + assert rows.loc['2022-01-01', 'drought_index'] == 0.0 + # The snapshot day takes the first value. + assert rows.loc['2022-01-04', 'drought_index'] == 100.0 + # ffilled through the week. + assert rows.loc['2022-01-10', 'drought_index'] == 100.0 + # Next snapshot kicks in. 
+ assert rows.loc['2022-01-11', 'drought_index'] == 250.0 + assert rows.loc['2022-01-20', 'drought_index'] == 250.0 + + +def test_merge_defaults_drought_index_to_zero_when_not_provided(): + noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert 'drought_index' in merged.columns + assert (merged['drought_index'] == 0.0).all() + + +def test_merge_attaches_reservoir_storage_and_release(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + reservoir = pd.DataFrame({ + 'Date': dates[::5], # every 5 days + 'reservoir_storage': [1000.0, 1100.0, 1050.0, 1000.0], + 'reservoir_release': [50.0, 60.0, 55.0, 50.0], + }) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + reservoir_data=reservoir) + assert 'reservoir_storage' in merged.columns + assert 'reservoir_release' in merged.columns + assert 'reservoir_observed' in merged.columns + # Observed where the reservoir series provided a value (post-interpolation). + assert merged['reservoir_observed'].sum() > 0 + # Storage stays in the observed range after ffill / bfill. + s_min, s_max = 1000.0, 1100.0 + assert merged['reservoir_storage'].max() <= s_max + assert merged['reservoir_storage'].min() >= s_min + + +def test_merge_defaults_reservoir_to_zero_when_unregulated(): + # Unregulated station: no reservoir_data passed -> storage / release default + # to 0 and reservoir_observed stays 0. 
+ noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert (merged['reservoir_storage'] == 0.0).all() + assert (merged['reservoir_release'] == 0.0).all() + assert (merged['reservoir_observed'] == 0).all() + + def test_merge_records_huc8_when_provided(): noaa, flow = _make_frames() merged = combine_data.merge_dataframes( diff --git a/tests/test_drought.py b/tests/test_drought.py new file mode 100644 index 0000000..b755c2d --- /dev/null +++ b/tests/test_drought.py @@ -0,0 +1,114 @@ +""" +Tests for get_drought.main / get_drought helpers (USDM intensity-by-HUC). + +USDM is a clean REST endpoint; we mock with requests_mock and don't hit the +network. The HUC lookup is patched out per test so we can isolate the USDM +parsing + forward-fill logic from the WBD ArcGIS dependency. +""" + +from datetime import datetime + +import pandas as pd +import pytest +import requests_mock as _rm + +from data import get_drought + + +@pytest.fixture +def mock_api(): + with _rm.Mocker() as m: + yield m + + +def test_intensity_weights_collapse_categories_correctly(): + # 100% D2 == weight 3 == index 300; D0..D4 weights are ordinal (1..5). + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 100, 'D3': 0, 'D4': 0}) == 300.0 + # 50% D1 + 50% D3 == 0.5*2 + 0.5*4 = 3 -- but expressed as percent it's + # 50 + 50 -> 50*2 + 50*4 = 300. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 50, 'D2': 0, 'D3': 50, 'D4': 0}) == 300.0 + # No drought anywhere -> 0. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}) == 0.0 + # Maximum: 100% D4 -> 500. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 100}) == 500.0 + + +def test_intensity_treats_missing_categories_as_zero(): + # USDM occasionally omits zero-category fields; missing == 0, not error. 
+ assert get_drought._record_to_index({'D4': 100}) == 500.0 + assert get_drought._record_to_index({}) == 0.0 + + +def test_format_mdy_strips_zero_padding(): + # USDM API requires M/D/YYYY (NOT MM/DD/YYYY). + assert get_drought._format_mdy('2024-01-05') == '1/5/2024' + assert get_drought._format_mdy(datetime(2024, 12, 31)) == '12/31/2024' + + +def test_get_drought_weekly_parses_records(mock_api): + mock_api.get( + get_drought.USDM_URL, + json=[ + {'MapDate': '2024-01-02', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + {'MapDate': '2024-01-09', 'D0': 0, 'D1': 50, 'D2': 50, 'D3': 0, 'D4': 0}, + ], + ) + recs = get_drought.get_drought_weekly('14010001', '2024-01-01', '2024-01-15') + assert len(recs) == 2 + + +def test_get_drought_returns_empty_when_huc_lookup_fails(monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: None) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert df.empty + assert list(df.columns) == ['Date', 'drought_index'] + + +def test_get_drought_forward_fills_weekly_to_daily(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get( + get_drought.USDM_URL, + json=[ + {'MapDate': '2024-01-02', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + {'MapDate': '2024-01-09', 'D0': 0, 'D1': 50, 'D2': 50, 'D3': 0, 'D4': 0}, + ], + ) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + # Daily reindex over Jan 1..15 = 15 rows. + assert len(df) == 15 + rows = df.set_index('Date') + # Days before the first weekly snapshot stay NaN (we used ffill only -- + # the spine fills those with 0). + assert pd.isna(rows.loc['2024-01-01', 'drought_index']) + # Days from Jan 2 through Jan 8 inherit the first snapshot's intensity + # (100% D0 -> weight 1 -> 100). 
+ assert rows.loc['2024-01-02', 'drought_index'] == 100.0 + assert rows.loc['2024-01-08', 'drought_index'] == 100.0 + # Days from Jan 9 onward inherit the second snapshot (50*2 + 50*3 = 250). + assert rows.loc['2024-01-09', 'drought_index'] == 250.0 + assert rows.loc['2024-01-15', 'drought_index'] == 250.0 + + +def test_get_drought_returns_empty_when_api_returns_nothing(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get(get_drought.USDM_URL, json=[]) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert df.empty + + +def test_get_drought_tolerates_alternate_date_keys(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get( + get_drought.USDM_URL, + json=[ + # USDM has been observed to return ValidStart instead of MapDate. + {'ValidStart': '20240102', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + ], + ) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert len(df) == 15 + assert df.set_index('Date').loc['2024-01-02', 'drought_index'] == 100.0 diff --git a/tests/test_external_baselines.py b/tests/test_external_baselines.py new file mode 100644 index 0000000..82dd7b2 --- /dev/null +++ b/tests/test_external_baselines.py @@ -0,0 +1,64 @@ +""" +Smoke tests for the CBRFC + S2F baseline modules. + +Both modules expose a `baseline_predictions(test_samples)` hook that train.py +calls during evaluation. Until the historical archive integrations are wired +in (see module docstrings), both must return None cleanly -- not crash, not +log noise, not pollute the persistence comparison. These tests pin that +contract so the train.py integration stays safe. 
+""" + +import numpy as np +import pandas as pd +import pytest + +from data import get_cbrfc, get_s2f +from windowing import WindowedSample + + +def _sample(site_id='USGS:09163500'): + """A single WindowedSample-shaped object sufficient for baseline_predictions.""" + return WindowedSample( + encoder_X=np.zeros((60, 9), dtype='float32'), + decoder_X=np.zeros((14, 4), dtype='float32'), + target_Y=np.zeros((14, 2), dtype='float32'), + persistence_anchor=np.zeros(2, dtype='float32'), + station_idx=1, + basin_idx=1, + site_id=site_id, + anchor_date=pd.Timestamp('2024-01-01'), + ) + + +def test_cbrfc_baseline_returns_none_when_no_lid_mapped(): + # The default _AHPS_LID_TABLE is empty -> every site has no LID -> every + # fetch returns empty -> the predictions stack is None (skip cleanly). + pred = get_cbrfc.baseline_predictions([_sample()]) + assert pred is None + + +def test_cbrfc_baseline_returns_none_for_empty_test_set(): + assert get_cbrfc.baseline_predictions([]) is None + + +def test_cbrfc_fetch_current_skips_when_no_lid(): + df = get_cbrfc.fetch_current('USGS:UNKNOWN') + assert df.empty + assert list(df.columns) == ['Date', 'cbrfc_flow'] + + +def test_cbrfc_fetch_historical_returns_empty_until_archive_wired(): + # Anchor in the past -> archive lookup is stubbed -> empty. 
+ df = get_cbrfc.fetch('USGS:09163500', '2024-01-01') + assert df.empty + + +def test_s2f_baseline_returns_none_until_archive_wired(): + assert get_s2f.baseline_predictions([_sample()]) is None + assert get_s2f.baseline_predictions([]) is None + + +def test_s2f_fetch_returns_empty_until_archive_wired(): + df = get_s2f.fetch('USGS:09163500', '2024-01-01') + assert df.empty + assert list(df.columns) == ['Date', 's2f_volume_kaf'] diff --git a/tests/test_normalize_data.py b/tests/test_normalize_data.py index 09a9edc..5756cd1 100644 --- a/tests/test_normalize_data.py +++ b/tests/test_normalize_data.py @@ -149,6 +149,34 @@ def test_normalize_keeps_sm_observed_as_binary_indicator(tmp_path): assert 'sm_observed' not in scalers +def test_normalize_scales_drought_and_reservoir_when_present(tmp_path): + import json + df = _make_combined() + df['drought_index'] = np.linspace(0.0, 400.0, 20) + df['reservoir_storage'] = np.linspace(800_000.0, 1_200_000.0, 20) + df['reservoir_release'] = np.linspace(50.0, 250.0, 20) + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + for col in ('drought_index', 'reservoir_storage', 'reservoir_release'): + assert col in scalers + # All three are continuous; z-scored, no log transform. 
+ assert scalers[col]['transform'] == 'identity' + assert abs(out[col].mean()) < 1e-6 + assert abs(out[col].std(ddof=0) - 1.0) < 1e-6 + + +def test_normalize_keeps_reservoir_observed_as_binary(tmp_path): + import json + df = _make_combined() + df['reservoir_observed'] = [1, 0] * 10 + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + assert set(out['reservoir_observed'].unique()) <= {0, 1} + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'reservoir_observed' not in scalers + + def test_normalize_defaults_sm_observed_to_zero_when_missing(): df = _make_combined() df['sm_observed'] = [1, np.nan, 1, np.nan] * 5 diff --git a/tests/test_reservoir.py b/tests/test_reservoir.py new file mode 100644 index 0000000..127e5c6 --- /dev/null +++ b/tests/test_reservoir.py @@ -0,0 +1,133 @@ +""" +Tests for get_reservoir (USBR RISE storage + release). + +The mapping file and the RISE result endpoint are both mocked. We don't hit +data.usbr.gov in the test run; the test asserts the mapping is parsed correctly +and that the JSON:API response shape is normalised into a [Date, value] frame. +""" + +import pandas as pd +import pytest +import requests_mock as _rm + +from data import get_reservoir + + +@pytest.fixture +def mock_api(): + with _rm.Mocker() as m: + yield m + + +def test_load_mapping_skips_comments_and_blanks(tmp_path): + path = tmp_path / 'reservoirs.txt' + path.write_text( + "# header comment\n" + "\n" + "USGS:09163500, Lake Powell, 6126, 6127\n" + "# inline comment line\n" + "USGS:09114500, Blue Mesa, 6135,\n" # no release id + " USGS:09070500 , Dillon , 6200 , 6201 \n" # extra whitespace + ) + mapping = get_reservoir.load_mapping(str(path)) + assert set(mapping) == {'USGS:09163500', 'USGS:09114500', 'USGS:09070500'} + assert mapping['USGS:09163500'] == [('Lake Powell', '6126', '6127')] + # The release id is empty -> None for that slot. 
+ assert mapping['USGS:09114500'] == [('Blue Mesa', '6135', None)] + # Whitespace is stripped. + assert mapping['USGS:09070500'] == [('Dillon', '6200', '6201')] + + +def test_load_mapping_handles_missing_file(tmp_path): + missing = tmp_path / 'does_not_exist.txt' + assert get_reservoir.load_mapping(str(missing)) == {} + + +def test_load_mapping_skips_malformed_lines(tmp_path): + path = tmp_path / 'reservoirs.txt' + path.write_text( + "USGS:09163500, Lake Powell, 6126, 6127\n" + "this line has too few fields\n" + "USGS:09114500, Blue Mesa, 6135\n" + ) + mapping = get_reservoir.load_mapping(str(path)) + assert set(mapping) == {'USGS:09163500', 'USGS:09114500'} + + +def test_fetch_rise_series_parses_jsonapi_response(mock_api): + mock_api.get( + get_reservoir.RISE_RESULT_URL, + json={ + 'data': [ + {'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': 100.0}}, + {'attributes': {'dateTime': '2024-01-02T00:00:00Z', 'result': 200.0}}, + {'attributes': {'dateTime': '2024-01-03T00:00:00Z', 'result': 300.0}}, + ], + 'links': {}, + }, + ) + df = get_reservoir.fetch_rise_series('6126', '2024-01-01', '2024-01-03') + assert list(df['Date']) == ['2024-01-01', '2024-01-02', '2024-01-03'] + assert list(df['value']) == [100.0, 200.0, 300.0] + + +def test_fetch_rise_series_collapses_sub_daily_to_daily_mean(mock_api): + mock_api.get( + get_reservoir.RISE_RESULT_URL, + json={ + 'data': [ + {'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': 100.0}}, + {'attributes': {'dateTime': '2024-01-01T12:00:00Z', 'result': 200.0}}, + ], + 'links': {}, + }, + ) + df = get_reservoir.fetch_rise_series('6126', '2024-01-01', '2024-01-01') + assert len(df) == 1 + assert df.iloc[0]['value'] == 150.0 + + +def test_fetch_rise_series_returns_empty_on_empty_id(mock_api): + # No request must be issued when item_id is falsy. 
+ df = get_reservoir.fetch_rise_series(None, '2024-01-01', '2024-01-03') + assert df.empty + assert mock_api.call_count == 0 + + +def test_get_reservoir_returns_empty_for_unmapped_station(tmp_path): + mapping_path = tmp_path / 'reservoirs.txt' + mapping_path.write_text("USGS:09163500, Lake Powell, 6126, 6127\n") + df = get_reservoir.get_reservoir( + 'USGS:UNREGULATED', '2024-01-01', '2024-01-03', mapping_path=str(mapping_path)) + assert df.empty + assert list(df.columns) == ['Date', 'reservoir_storage', 'reservoir_release'] + + +def test_get_reservoir_sums_multiple_reservoirs(mock_api, tmp_path): + mapping_path = tmp_path / 'reservoirs.txt' + mapping_path.write_text( + "USGS:09163500, ResA, 100, 200\n" + "USGS:09163500, ResB, 101, 201\n" + ) + + # Different responses by query string -- match on itemId. + def _by_item(request, context): + item_id = request.qs.get('itemid', [''])[0] + value_map = { + '100': 1000.0, # ResA storage + '101': 500.0, # ResB storage + '200': 50.0, # ResA release + '201': 25.0, # ResB release + } + v = value_map[item_id] + return { + 'data': [{'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': v}}], + 'links': {}, + } + + mock_api.get(get_reservoir.RISE_RESULT_URL, json=_by_item) + df = get_reservoir.get_reservoir( + 'USGS:09163500', '2024-01-01', '2024-01-01', mapping_path=str(mapping_path)) + assert len(df) == 1 + assert df.iloc[0]['reservoir_storage'] == 1500.0 # 1000 + 500 + assert df.iloc[0]['reservoir_release'] == 75.0 # 50 + 25 diff --git a/tests/test_windowing.py b/tests/test_windowing.py index 760c77c..a306377 100644 --- a/tests/test_windowing.py +++ b/tests/test_windowing.py @@ -22,6 +22,10 @@ def _make_station_frame(site_id, station_idx, basin_idx, n_days=120, start='2022 'SWE': rng.standard_normal(n_days), 'soil_moisture': rng.standard_normal(n_days), 'sm_observed': rng.integers(0, 2, n_days), + 'drought_index': rng.standard_normal(n_days), + 'reservoir_storage': rng.standard_normal(n_days), + 'reservoir_release': 
rng.standard_normal(n_days), + 'reservoir_observed': rng.integers(0, 2, n_days), 'doy_sin': np.sin(2 * np.pi * np.arange(n_days) / 365), 'doy_cos': np.cos(2 * np.pi * np.arange(n_days) / 365), }) @@ -47,6 +51,15 @@ def test_decoder_features_carry_no_flow_information(): # The sm_observed indicator is also encoder-only. assert 'sm_observed' in windowing.ENCODER_FEATURES assert 'sm_observed' not in windowing.DECODER_FEATURES + # USDM drought + USBR reservoir are encoder-only too -- none has a + # skillful 14-day forecast available. + assert 'drought_index' in windowing.ENCODER_FEATURES + assert 'drought_index' not in windowing.DECODER_FEATURES + assert 'reservoir_storage' in windowing.ENCODER_FEATURES + assert 'reservoir_release' in windowing.ENCODER_FEATURES + assert 'reservoir_observed' in windowing.ENCODER_FEATURES + for c in ('reservoir_storage', 'reservoir_release', 'reservoir_observed'): + assert c not in windowing.DECODER_FEATURES def test_build_windows_shapes_are_correct():