diff --git a/.github/reservoir_mapping.txt b/.github/reservoir_mapping.txt new file mode 100644 index 0000000..2efd5da --- /dev/null +++ b/.github/reservoir_mapping.txt @@ -0,0 +1,33 @@ +# Station -> USBR RISE catalog-item-id mapping for reservoir features. +# +# Format (comma-separated, one record per line): +# , , , +# +# Where: +# - site_id matches an entry in .github/site_ids.txt (e.g. USGS:09163500) +# - reservoir_name is human-readable, used only for logging +# - storage_itemId is a USBR RISE catalog-item id for a storage timeseries +# (acre-feet, ideally daily) +# - release_itemId is a USBR RISE catalog-item id for an outflow / release +# timeseries (cfs, ideally daily). Use an empty trailing comma if the +# reservoir has no release series available. +# +# A station with no entries here is treated as unregulated -- both reservoir +# features default to 0 and the reservoir_observed indicator stays at 0, +# which is the correct semantic for an unregulated headwater gauge. +# +# To find catalog-item ids, browse https://data.usbr.gov/rise/ or query +# https://data.usbr.gov/rise/api/catalog-item?query= +# and look for the daily storage / release records you want. +# +# Multiple reservoirs may map to one station; storage values are summed +# (total water held back) and release values are summed (total outflow) at +# read time. +# +# Lines beginning with # are comments and ignored. Empty lines are ignored. +# +# Example (commented out until you've verified the right itemIds for your +# stations -- the IDs below are illustrative placeholders): +# +# USGS:09163500, Lake Powell, 6126, 6127 +# USGS:09114500, Blue Mesa Reservoir, 6135, 6136 diff --git a/.github/workflows/ml_training.yml b/.github/workflows/ml_training.yml index 1fd5fd3..2675d15 100644 --- a/.github/workflows/ml_training.yml +++ b/.github/workflows/ml_training.yml @@ -4,6 +4,19 @@ on: schedule: - cron: '0 0 * * 1' # Run weekly at midnight on Monday workflow_dispatch: + inputs: + disable_smap: + description: 'Skip SMAP soil moisture (ablation)' + type: boolean + default: false + disable_drought: + description: 'Skip USDM drought (ablation)' + type: boolean + default: false + disable_reservoir: + description: 'Skip USBR reservoir (ablation)' + type: boolean + default: false jobs: train: @@ -19,11 +32,24 @@ jobs: python-version: '3.11' - name: Install dependencies + # earthaccess + h5py + shapely are needed for the SMAP soil-moisture + # fetch wired into combine_data. They are installed AFTER tensorflow so + # pip resolves the shared transitive deps (notably typing-extensions) + # against the tensorflow pin. If the version resolver cannot find a + # compatible earthaccess, combine_data still runs -- the SMAP fetch + # silently degrades to "no data" and soil_moisture defaults to 0. run: | python -m pip install --upgrade pip pip install -r openFlowML/requirements.txt + pip install earthaccess h5py shapely - name: Train model + env: + EARTHDATA_USERNAME: ${{ secrets.EARTHDATA_USERNAME }} + EARTHDATA_PASSWORD: ${{ secrets.EARTHDATA_PASSWORD }} + OPENFLOW_DISABLE_SMAP: ${{ inputs.disable_smap && '1' || '' }} + OPENFLOW_DISABLE_DROUGHT: ${{ inputs.disable_drought && '1' || '' }} + OPENFLOW_DISABLE_RESERVOIR: ${{ inputs.disable_reservoir && '1' || '' }} run: python openFlowML/train.py # The model is not usable for inference without the scaler parameters diff --git a/openFlowML/combine_data.py b/openFlowML/combine_data.py index 5ed5fe3..cdeb88f 100644 --- a/openFlowML/combine_data.py +++ b/openFlowML/combine_data.py @@ -20,7 +20,12 @@ gaps are dropped rather than filled with a pooled (cross-station) mean - the per-station frames carry a clean 'site_id' column -TODO (later in Phase 2): wire in SWE as a history-window feature. +Phase 4 (SMAP soil moisture): + - Soil moisture is fetched per station from NASA SMAP L3 enhanced via + nasa_moisture (lazy import; failure degrades to "no SMAP" for that site). + - Treated like SWE in the spine: slow-varying, longer interior interpolation + limit, missing rows default to 0 rather than being dropped. + - Set OPENFLOW_DISABLE_SMAP=1 to run the ablation baseline without SMAP. """ if not logging.getLogger().hasHandlers(): @@ -36,6 +41,17 @@ # SWE changes slowly (snowpack accumulates/melts over weeks), so we tolerate # longer interior gaps in the SWE series before giving up on a value. MAX_SWE_GAP_DAYS = 30 +# SMAP has a ~1-3 day revisit cadence per pass; gaps come from RFI / dense +# vegetation / frozen ground. Soil moisture itself is slow-varying so the same +# generous interior interpolation limit as SWE is appropriate. +MAX_SM_GAP_DAYS = 30 +# USDM is published weekly; ffill across 8-day gaps is the natural cadence +# for ffill-only handling. Drought intensity is even slower-varying than SWE. +MAX_DROUGHT_GAP_DAYS = 14 +# Reservoir storage / release are reported daily by USBR; gaps are short +# (occasional missing days), so the regular MAX_GAP_DAYS interpolation limit +# is enough. Reservoir state is also slow-varying. +MAX_RESERVOIR_GAP_DAYS = 14 def _to_daily_series(df, value_columns, daily_index): @@ -52,15 +68,19 @@ def _to_daily_series(df, value_columns, daily_index): def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, - swe_data=None, huc8=None): + swe_data=None, sm_data=None, + drought_data=None, reservoir_data=None, + huc8=None): """ - Merge a site's NOAA temperature, flow, and SWE data onto one regular daily - index. + Merge a site's flow, NOAA temperature, SWE, SMAP soil moisture, USDM + drought, and USBR reservoir series onto one regular daily index. Short interior flow/temp gaps (<= MAX_GAP_DAYS) are interpolated time-aware; rows still missing a core value afterwards are dropped (no pooled-mean - fill). SWE is interpolated with a longer limit (it's slow-varying) and any - remaining missing values default to 0 -- they don't drop the row. + fill). All Phase 4/5 auxiliary features (SWE, soil_moisture, drought_index, + reservoir_storage, reservoir_release) are slow-varying, get longer + interpolation limits, and missing values are imputed per the column- + specific semantics rather than dropping the row. """ if 'Date' not in noaa_data or 'Date' not in flow_data: raise ValueError("'Date' column missing in one of the dataframes") @@ -96,12 +116,88 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, else: combined['SWE'] = float('nan') - # Drop rows still missing any core flow/temp value. SWE is NOT in - # CORE_COLUMNS, so a missing SWE never drops a row. + # SMAP soil moisture: slow-varying surface state; generous interior + # interpolation limit (matches SWE). Edge handling happens AFTER the + # core-column dropna below, where we have the final row set. + if sm_data is not None and not sm_data.empty: + sm_daily = _to_daily_series(sm_data, ['soil_moisture'], daily_index) + sm_daily['soil_moisture'] = pd.to_numeric(sm_daily['soil_moisture'], errors='coerce') + sm_daily = sm_daily.interpolate(method='time', limit=MAX_SM_GAP_DAYS, limit_area='inside') + combined['soil_moisture'] = sm_daily['soil_moisture'] + else: + combined['soil_moisture'] = float('nan') + + # USDM drought intensity: weekly snapshots, forward-filled. ffill is + # the right move for a step-function weekly series (interpolation + # would smear discrete category changes into ramps). + if drought_data is not None and not drought_data.empty: + d_daily = _to_daily_series(drought_data, ['drought_index'], daily_index) + d_daily['drought_index'] = pd.to_numeric(d_daily['drought_index'], errors='coerce') + d_daily['drought_index'] = d_daily['drought_index'].ffill(limit=MAX_DROUGHT_GAP_DAYS) + combined['drought_index'] = d_daily['drought_index'] + else: + combined['drought_index'] = float('nan') + + # USBR reservoir storage + release: daily series, slow-varying. + # Interpolate short interior gaps. Empty df = unregulated station + # (no mapping entry); the column stays NaN and is handled below. + if reservoir_data is not None and not reservoir_data.empty: + r_daily = _to_daily_series( + reservoir_data, + ['reservoir_storage', 'reservoir_release'], + daily_index, + ) + for col in ('reservoir_storage', 'reservoir_release'): + r_daily[col] = pd.to_numeric(r_daily[col], errors='coerce') + r_daily = r_daily.interpolate(method='time', + limit=MAX_RESERVOIR_GAP_DAYS, + limit_area='inside') + combined['reservoir_storage'] = r_daily['reservoir_storage'] + combined['reservoir_release'] = r_daily['reservoir_release'] + else: + combined['reservoir_storage'] = float('nan') + combined['reservoir_release'] = float('nan') + + # Drop rows still missing any core flow/temp value. SWE / soil_moisture + # are NOT in CORE_COLUMNS, so a missing one never drops a row. before = len(combined) combined = combined.dropna(subset=CORE_COLUMNS) - # Any remaining SWE gaps default to 0 ("no snow data" / out-of-season). + # SWE: 0 means "no snow", which IS a legitimate default; keep that. combined['SWE'] = combined['SWE'].fillna(0.0) + # Soil moisture: 0 means "Sahara desert", which is NOT a legitimate + # default for a SMAP gap (RFI / frozen ground / sensor outage). We + # carry the last interpolated value forward and backward (SM is slow- + # varying), then fall back to the station's median observed SM where + # ffill/bfill can't reach, and only to 0 if the station has no + # observations at all. An sm_observed indicator (1 = real or short- + # gap interpolated retrieval, 0 = imputed via ffill / median fallback) + # gives the model a way to tell the truth from the imputation. + sm_series = combined['soil_moisture'] + combined['sm_observed'] = sm_series.notna().astype('int64') + sm_series = sm_series.ffill().bfill() + median = sm_series.median(skipna=True) + if pd.isna(median): + median = 0.0 + combined['soil_moisture'] = sm_series.fillna(median) + + # Drought: 0 means "no drought anywhere in the HUC", which IS a + # legitimate default for a missing weekly value. Fill remaining gaps + # with 0 -- the model can treat absence as "we don't have a drought + # signal here" (which behaves equivalently to no drought). + combined['drought_index'] = combined['drought_index'].fillna(0.0) + + # Reservoir: 0 storage / 0 release reads as "tiny / empty reservoir", + # which IS misleading -- the right semantic for an unregulated gauge + # is "no upstream reservoir at all". A reservoir_observed indicator + # (1 = mapped + retrieval succeeded for this row, 0 = unmapped / no + # data) tells the model when the storage / release columns are real. + # The numeric columns then default to 0 only as a placeholder; the + # indicator carries the truth. + res_series = combined['reservoir_storage'] + combined['reservoir_observed'] = res_series.notna().astype('int64') + combined['reservoir_storage'] = res_series.ffill().bfill().fillna(0.0) + rel_series = combined['reservoir_release'] + combined['reservoir_release'] = rel_series.ffill().bfill().fillna(0.0) logger.info( "Site %s: %d/%d daily rows usable after gap handling", site_id, len(combined), before, @@ -119,14 +215,25 @@ def merge_dataframes(noaa_data, flow_data, site_id, start_date, end_date, return pd.DataFrame() +def _disabled(env_var): + """Truthy-toggle env-var check, used for the per-source ablation levers.""" + return os.getenv(env_var, '').strip() in ('1', 'true', 'True') + + def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): """ - Resolve a site's coordinates and fetch its NOAA temperature + HUC SWE - series, and its HUC8 basin id (used as a Phase 3 basin embedding key). - - Returns (noaa_data, swe_data, huc8). Either dataframe may be empty (SWE - degrades gracefully; missing NOAA causes the caller to skip the site). - huc8 may be None when the lookup fails. + Resolve a site's coordinates and fetch all its auxiliary timeseries: + NOAA temperature, NRCS SWE, SMAP soil moisture, USDM drought, USBR + reservoir storage + release, and the enclosing HUC8 basin id. + + Returns a dict with keys: + noaa_data, swe_data, sm_data, drought_data, reservoir_data, huc8 + Any data frame may be empty (auxiliary features degrade gracefully; only + missing NOAA causes the caller to skip the site). huc8 may be None. + + Per-source env vars short-circuit the corresponding fetch to empty for + ablation runs: + OPENFLOW_DISABLE_SMAP, OPENFLOW_DISABLE_DROUGHT, OPENFLOW_DISABLE_RESERVOIR """ if prefix == "USGS": coords_dict = get_usgs_coordinates(site_id) @@ -137,7 +244,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if not coords_dict: logger.error(f"Could not resolve coordinates for {prefix}:{site_id}. Skipping...") - return None, None, None + return None latitude = float(coords_dict['latitude']) longitude = float(coords_dict['longitude']) @@ -149,7 +256,7 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if noaa_data is None or noaa_data.empty: logger.warning(f"No NOAA data available for site ID {site_id}. Skipping...") - return None, None, None + return None noaa_data = noaa_data.copy() noaa_data['USGS_site_ID'] = site_id @@ -165,17 +272,60 @@ def fetch_and_process_data(prefix, site_id, start_date, end_date, flow_data): if not huc8: logger.warning("No HUC8 resolved for %s -- basin embedding will fall back", site_id) - # SWE is a history-window feature; degrade gracefully if the AWDB fetch - # fails -- the row stays, SWE defaults to 0 in merge_dataframes. + # SWE: graceful failure -> empty. try: swe_data = get_swe.get_swe(latitude, longitude, start_date, end_date) except Exception as e: logger.warning("SWE fetch failed for %s: %s", site_id, e) swe_data = pd.DataFrame(columns=['Date', 'SWE']) - # Gap handling and numeric coercion happen in merge_dataframes (per station, - # on the regular daily index) -- not here, and never via a pooled mean. - return noaa_data, swe_data, huc8 + # SMAP soil moisture: lazy import (heavy deps), env-var lever, graceful. + sm_data = pd.DataFrame(columns=['Date', 'soil_moisture']) + if _disabled('OPENFLOW_DISABLE_SMAP'): + logger.info("OPENFLOW_DISABLE_SMAP set -- skipping SMAP for %s", site_id) + else: + try: + from data import nasa_moisture + sm_data = nasa_moisture.main(latitude, longitude, start_date, end_date) + except ImportError as e: + logger.warning("Soil-moisture deps unavailable (%s); skipping SMAP for %s", + e, site_id) + except Exception as e: + logger.warning("SMAP fetch failed for %s: %s", site_id, e) + + # USDM drought: light deps, env-var lever, graceful. + drought_data = pd.DataFrame(columns=['Date', 'drought_index']) + if _disabled('OPENFLOW_DISABLE_DROUGHT'): + logger.info("OPENFLOW_DISABLE_DROUGHT set -- skipping USDM for %s", site_id) + else: + try: + from data import get_drought + drought_data = get_drought.get_drought(latitude, longitude, start_date, end_date) + except Exception as e: + logger.warning("USDM drought fetch failed for %s: %s", site_id, e) + + # USBR reservoir: needs a station->reservoir mapping; empty when no entry. + reservoir_data = pd.DataFrame( + columns=['Date', 'reservoir_storage', 'reservoir_release']) + if _disabled('OPENFLOW_DISABLE_RESERVOIR'): + logger.info("OPENFLOW_DISABLE_RESERVOIR set -- skipping RISE for %s", + site_id) + else: + try: + from data import get_reservoir + reservoir_data = get_reservoir.get_reservoir( + f"{prefix}:{site_id}", start_date, end_date) + except Exception as e: + logger.warning("USBR reservoir fetch failed for %s: %s", site_id, e) + + return { + 'noaa_data': noaa_data, + 'swe_data': swe_data, + 'sm_data': sm_data, + 'drought_data': drought_data, + 'reservoir_data': reservoir_data, + 'huc8': huc8, + } def get_site_ids(filename=None): @@ -231,15 +381,22 @@ def main(training_num_years=7): logger.warning(f"Unrecognized prefix for site ID {site_id}. Skipping...") continue - noaa_dataframe, swe_dataframe, huc8 = fetch_and_process_data( + fetched = fetch_and_process_data( prefix, id, start_date, end_date, flow_dataframe) - if noaa_dataframe is None or noaa_dataframe.empty or flow_dataframe.empty: + if (fetched is None + or fetched['noaa_data'] is None + or fetched['noaa_data'].empty + or flow_dataframe.empty): logger.warning(f"No usable data for site ID {site_id}. Skipping...") continue merged = merge_dataframes( - noaa_dataframe, flow_dataframe, site_id, start_date, end_date, - swe_data=swe_dataframe, huc8=huc8) + fetched['noaa_data'], flow_dataframe, site_id, start_date, end_date, + swe_data=fetched['swe_data'], + sm_data=fetched['sm_data'], + drought_data=fetched['drought_data'], + reservoir_data=fetched['reservoir_data'], + huc8=fetched['huc8']) if merged.empty: logger.warning(f"No usable merged data for site ID {site_id}. Skipping...") continue diff --git a/openFlowML/data/appeears.py b/openFlowML/data/appeears.py deleted file mode 100644 index 6db296c..0000000 --- a/openFlowML/data/appeears.py +++ /dev/null @@ -1,415 +0,0 @@ -import shutil -import logging -from data.utils.get_poly import simplify_polygon, validate_polygon, get_huc_polygon -import os -import datetime -import rasterio -from rasterio.plot import show -import numpy as np -from rasterio.mask import mask -from shapely.geometry import box -from data.utils.get_poly import check_polygon_intersection, get_huc_polygon, validate_polygon, simplify_polygon -from data.utils.data_utils import appeears_login, appeears_logout, load_vars, get_earthdata_auth, get_smap_data_bounds -import time -import tempfile -import argparse -import requests -import json -from mpl_toolkits.axes_grid1 import make_axes_locatable -from shapely import Polygon -import importlib.util -import matplotlib.pyplot as plt -from matplotlib.patches import Polygon as mplPolygon -matplotlib_spec = importlib.util.find_spec("matplotlib") -matplotlib_available = matplotlib_spec is not None - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - -load_vars() - -def check_appeears_product(product_id, layer_name): - """ - Check if the specified product and layer are available in AppEEARS. - """ - # First, get the list of all products - products_url = "https://appeears.earthdatacloud.nasa.gov/api/product" - - try: - response = requests.get(products_url, headers=None, timeout=30) - response.raise_for_status() - - # Print raw response for debugging - print("Raw API Response:") - print(response.text[:1000]) # Print first 1000 characters - - products = response.json() - - print(f"Type of products: {type(products)}") - - if isinstance(products, list): - print("Products is a list. Printing first item:") - print(products[0] if products else "Empty list") - elif isinstance(products, dict): - print("Products is a dictionary. Printing keys:") - print(products.keys()) - else: - print(f"Unexpected type for products: {type(products)}") - - print("Available Soil Moisture Products:") - if isinstance(products, list): - for product in products: - if isinstance(product, dict) and 'ProductAndVersion' in product: - # Check if the product is related to soil moisture - if 'soil moisture' in product.get('Description', '').lower(): - print(f"- {product['ProductAndVersion']}: {product.get('Description', 'N/A')}") - print(f" Available: {product.get('Available', 'N/A')}") - print(f" Temporal Extent: {product.get('TemporalExtentStart', 'N/A')} to {product.get('TemporalExtentEnd', 'N/A')}") - print(f" Resolution: {product.get('Resolution', 'N/A')}") - print(f" Source: {product.get('Source', 'N/A')}") - print(" ---") - - except requests.Timeout: - logging.error("Request timed out while checking product availability") - return False - except requests.RequestException as e: - logging.error(f"Error checking product availability: {e}") - return False - except Exception as e: - logging.error(f"Unexpected error: {e}") - return False - - return True # Return True for now, adjust based on actual product availability check - -def get_product_layers(token, product_id): - """ - Get available layers for a specific product from AppEEARS API. - """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/product/{product_id}" - headers = {"Authorization": f"Bearer {token}"} - - try: - response = requests.get(url, headers=headers) - response.raise_for_status() - product_info = response.json() - - logging.info(f"Available layers for {product_id}:") - for layer_name, layer_info in product_info.items(): - logging.info(f"- {layer_name}: {layer_info.get('Description', 'No description available')}") - - return product_info - except requests.RequestException as e: - logging.error(f"Error getting product layers: {e}") - return {} - -def submit_appears_task(token, polygon, start_date, end_date): - """ - Submit a task to AppEEARS API for SMAP data retrieval. - """ - product_id = "SPL3SMP_E.006" - layers = get_product_layers(token, product_id) - - if not layers: - logging.error(f"No layers found for product {product_id}") - return None - - # Choose the first available layer related to soil moisture - soil_moisture_layer = next((layer_name for layer_name, layer_info in layers.items() - if 'soil_moisture' in layer_name.lower() or 'soil_moisture' in layer_info.get('Description', '').lower()), - None) - - if not soil_moisture_layer: - logging.error(f"No soil moisture layer found for product {product_id}") - return None - - logging.info(f"Selected layer: {soil_moisture_layer}") - - url = "https://appeears.earthdatacloud.nasa.gov/api/task" - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - } - - task_payload = { - "task_type": "area", - "task_name": "SMAP_Soil_Moisture_Extraction", - "params": { - "dates": [ - { - "startDate": start_date.strftime("%m-%d-%Y"), - "endDate": end_date.strftime("%m-%d-%Y") - } - ], - "layers": [ - { - "product": product_id, - "layer": soil_moisture_layer - } - ], - "output": { - "format": { - "type": "geotiff" - }, - "projection": "geographic" - }, - "geo": { - "type": "FeatureCollection", - "features": [ - { - "type": "Feature", - "properties": {}, - "geometry": { - "type": "Polygon", - "coordinates": [polygon] - } - } - ] - } - } - } - - try: - logging.info(f"Submitting task to URL: {url}") - logging.info(f"Headers: {headers}") - logging.info(f"Payload: {json.dumps(task_payload, indent=2)}") - - response = requests.post(url, headers=headers, data=json.dumps(task_payload), timeout=30) - - logging.info(f"Response status code: {response.status_code}") - logging.info(f"Response headers: {response.headers}") - logging.info(f"Response content: {response.text}") - - if response.status_code == 202: - return response.json()["task_id"] - else: - logging.error(f"Failed to submit task. Status code: {response.status_code}") - logging.error(f"Response: {response.text}") - return None - except requests.Timeout: - logging.error("Request timed out while submitting task") - return None - except requests.RequestException as e: - logging.error(f"Error submitting task: {e}") - return None - -def check_task_status(token, task_id): - """ - Check the status of an AppEEARS task. - """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/task/{task_id}" - headers = {"Authorization": f"Bearer {token}"} - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - return response.json()["status"] - else: - logging.error(f"Failed to check task status. Status code: {response.status_code}") - return None - -def download_task_results(token, task_id, output_dir): - """ - Download the results of a completed AppEEARS task. - """ - url = f"https://appeears.earthdatacloud.nasa.gov/api/bundle/{task_id}" - headers = {"Authorization": f"Bearer {token}"} - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - bundle_info = response.json() - for file_info in bundle_info["files"]: - file_id = file_info["file_id"] - file_name = file_info["file_name"] - download_url = f"https://appeears.earthdatacloud.nasa.gov/api/bundle/{task_id}/{file_id}" - - file_response = requests.get(download_url, headers=headers, allow_redirects=True) - if file_response.status_code == 200: - with open(f"{output_dir}/{file_name}", "wb") as f: - f.write(file_response.content) - logging.info(f"Downloaded: {file_name}") - else: - logging.error(f"Failed to download {file_name}. Status code: {file_response.status_code}") - else: - logging.error(f"Failed to get bundle info. Status code: {response.status_code}") - -def extract_soil_moisture_from_geotiff(geotiff_path, polygon): - """ - Extract soil moisture data for the given polygon from the GeoTIFF file using rasterio. - """ - try: - with rasterio.open(geotiff_path) as src: - # Create a GeoJSON-like geometry object - geom = {"type": "Polygon", "coordinates": [polygon]} - - # Mask the raster with the polygon - out_image, out_transform = mask(src, [geom], crop=True) - - # Get the data from the masked raster - data = out_image[0] # Assuming it's a single-band raster - - # Calculate average soil moisture - valid_data = data[data != src.nodata] - if len(valid_data) > 0: - average_moisture = np.mean(valid_data) - return average_moisture - else: - logging.warning("No valid data found within the polygon") - return None - - except Exception as e: - logging.error(f"Error extracting soil moisture data: {e}") - logging.error("Traceback: ", exc_info=True) - return None - -def visualize_smap_data(geotiff_path, polygon, average_moisture): - """ - Create a detailed visualization of SMAP data with the polygon overlay. - """ - try: - with rasterio.open(geotiff_path) as src: - fig, ax = plt.subplots(figsize=(12, 8)) - - # Plot the SMAP data - show(src, ax=ax, cmap='YlGnBu') - - # Plot the polygon - poly = mplPolygon(polygon, facecolor='none', edgecolor='red', linewidth=2) - ax.add_patch(poly) - - # Add colorbar - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(ax.images[0], cax=cax, label='Soil Moisture') - - # Set the extent to focus on the polygon - min_lon, min_lat, max_lon, max_lat = Polygon(polygon).bounds - ax.set_xlim(min_lon - 0.5, max_lon + 0.5) - ax.set_ylim(min_lat - 0.5, max_lat + 0.5) - - # Add title and labels - plt.title(f'SMAP Soil Moisture Data with HUC8 Polygon\nAverage Soil Moisture: {average_moisture:.4f}') - plt.xlabel('Longitude') - plt.ylabel('Latitude') - - # Add grid - ax.grid(True, linestyle='--', alpha=0.5) - - # Show the plot - plt.tight_layout() - plt.show() - - except Exception as e: - logging.error(f"Error visualizing SMAP data: {e}") - logging.error("Traceback: ", exc_info=True) - -# Add a function to verify the token: -def verify_token(token): - url = "https://appeears.earthdatacloud.nasa.gov/api/user" - headers = {"Authorization": f"Bearer {token}"} - try: - response = requests.get(url, headers=headers) - if response.status_code == 200: - logging.info("Token verified successfully") - return True - else: - logging.error(f"Failed to verify token. Status code: {response.status_code}") - logging.error(f"Response: {response.text}") - return False - except requests.RequestException as e: - logging.error(f"Error verifying token: {e}") - return False - -def main(start_date, end_date, lat, lon, visual): - # Login to AppEEARS - token = appeears_login() - if not token: - logging.error("Failed to login to AppEEARS. Please check your credentials.") - return - - # Check if the product and layer are available - product_id = "SPL3SMP_E.003" - layer_name = "soil_moisture" - if not check_appeears_product(product_id, layer_name): - logging.error("Required product or layer is not available. Exiting.") - appeears_logout() - return - - # Get the HUC8 polygon - huc8_polygon = get_huc_polygon(lat, lon, 8) - if not huc8_polygon: - logging.error("Failed to retrieve HUC8 polygon") - appeears_logout() - return - - # Simplify & validate the polygon - simplified_polygon = simplify_polygon(huc8_polygon) - simplified_polygon = validate_polygon(simplified_polygon) - logging.debug(f"Validated polygon coordinates: {simplified_polygon}") - - # Submit AppEEARS task - task_id = submit_appears_task(token, simplified_polygon, start_date, end_date) - if not task_id: - logging.error("Failed to submit AppEEARS task") - appeears_logout() - return - - # Wait for task to complete - max_retries = 30 # 30 minutes maximum wait time - retries = 0 - while retries < max_retries: - status = check_task_status(token, task_id) - if status == "done": - logging.info("Task completed successfully") - break - elif status == "error": - logging.error("Task failed") - appeears_logout() - return - elif status is None: - logging.error("Failed to check task status") - appeears_logout() - return - time.sleep(60) # Wait for 60 seconds before checking again - retries += 1 - else: - logging.error("Task did not complete within the maximum wait time") - appeears_logout() - return - - # Download results - output_dir = tempfile.mkdtemp() - download_task_results(token, task_id, output_dir) - - # Process and visualize results - geotiff_file = next((f for f in os.listdir(output_dir) if f.endswith('.tif')), None) - if geotiff_file: - geotiff_path = os.path.join(output_dir, geotiff_file) - average_soil_moisture = extract_soil_moisture_from_geotiff(geotiff_path, simplified_polygon) - - if average_soil_moisture is not None: - logging.info(f"Average soil moisture: {average_soil_moisture:.4f}") - - if visual: - visualize_smap_data(geotiff_path, simplified_polygon, average_soil_moisture) - else: - logging.error("Failed to calculate average soil moisture") - else: - logging.error("No GeoTIFF file found in the downloaded results") - - # Clean up - shutil.rmtree(output_dir) - - # Logout from AppEEARS - appeears_logout() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Calculate average soil moisture for a HUC8 polygon from SMAP L3 data.') - parser.add_argument('--start-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='Start Date in YYYY-MM-DD format') - parser.add_argument('--end-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='End Date in YYYY-MM-DD format') - parser.add_argument('--lat', type=float, required=True, help='Latitude of the point within the desired HUC8 polygon') - parser.add_argument('--lon', type=float, required=True, help='Longitude of the point within the desired HUC8 polygon') - parser.add_argument('--visual', action='store_true', help='Enable matplotlib visualization') - args = parser.parse_args() - - main(args.start_date, args.end_date, args.lat, args.lon, args.visual) \ No newline at end of file diff --git a/openFlowML/data/get_cbrfc.py b/openFlowML/data/get_cbrfc.py new file mode 100644 index 0000000..4c7c80d --- /dev/null +++ b/openFlowML/data/get_cbrfc.py @@ -0,0 +1,216 @@ +""" +NOAA Colorado Basin River Forecast Center (CBRFC) streamflow forecast baseline. + +A second comparison baseline alongside persistence. The model needs to beat +CBRFC's official short-range forecast to claim it's adding value over what an +operational forecaster has access to. + +Public API: + fetch(site_id, anchor_date, horizon_days) -> DataFrame[Date, cbrfc_flow] + Forecast issued on or before `anchor_date`, predictions for the next + `horizon_days`. Empty DataFrame when no usable forecast exists. + + baseline_predictions(test_samples) -> Optional[np.ndarray] + Stack predictions for the entire test set into a (N, horizon, 2) + tensor (same shape as persistence_baseline output in train.py). Returns + None when CBRFC coverage is sparse enough that the comparison would + be meaningless. + +IMPLEMENTATION STATUS: + The CBRFC publishes operational deterministic forecasts via the + Advanced Hydrologic Prediction Service (AHPS) at water.weather.gov, and + Ensemble Streamflow Prediction (ESP) products through their own portal at + cbrfc.noaa.gov. Both have *current-day* access; the **historical archive** + needed for backtesting against our test-set anchor dates is the gap: + + - AHPS does not expose historical issuance via its REST API; the + archived forecasts live in tarballs at + https://water.weather.gov/ahps/download.php + - CBRFC's ESP archive is accessible per-basin via their THREDDS server + but requires a per-issuance lookup that is meaningfully more involved + than this stub. + + Filling either path in (the obvious follow-up commit) lets the baseline + actually evaluate against the test set. Until then, fetch() returns empty + for any anchor_date != today, and baseline_predictions() returns None so + train.py skips the comparison cleanly. +""" + +import logging +from datetime import date, datetime, timedelta +from typing import List, Optional + +import numpy as np +import pandas as pd + +from data.utils import data_utils + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# AHPS public site forecast page (one page per gauge by NWS LID). +AHPS_FORECAST_URL = "https://water.weather.gov/ahps2/hydrograph_to_xml.php" +# Default 14-day horizon matches windowing.DECODER_DAYS. +DEFAULT_HORIZON_DAYS = 14 + + +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'cbrfc_flow']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], '%Y-%m-%d').date() + + +def _ahps_lid_for_site(site_id: str) -> Optional[str]: + """ + Map a USGS / DWR site id to an NWS / AHPS LID (5-letter location id). + The mapping isn't algorithmic; it's a curated lookup. Returns None if no + LID is known for the site, in which case fetch() returns empty. + + Populate this when wiring CBRFC for a specific gauge: + {'USGS:09163500': 'CRSC2', ...} + """ + return _AHPS_LID_TABLE.get(site_id) + + +_AHPS_LID_TABLE: dict = { + # site_id -> NWS LID. Empty by default; fill in per gauge as needed. +} + + +def fetch_current(site_id: str, horizon_days: int = DEFAULT_HORIZON_DAYS) -> pd.DataFrame: + """ + Pull the most recent AHPS forecast for `site_id`. Returns DataFrame + [Date, cbrfc_flow] in cfs, one row per forecast day, or an empty frame + when the site has no LID mapping or AHPS returned no usable forecast. + + The AHPS public forecast page only exposes the current forecast issuance, + so this is the "what does CBRFC think tomorrow's flow is right now" + helper. Historical issuances need the archive (see module docstring). + """ + lid = _ahps_lid_for_site(site_id) + if not lid: + logger.info("No AHPS LID mapped for %s; CBRFC fetch skipped", site_id) + return _empty() + params = {'gage': lid, 'output': 'xml'} + response = data_utils.request_with_retry(AHPS_FORECAST_URL, params=params) + if response is None: + return _empty() + rows = _parse_ahps_forecast_xml(response.text, horizon_days) + if not rows: + return _empty() + return pd.DataFrame(rows, columns=['Date', 'cbrfc_flow']) + + +def _parse_ahps_forecast_xml(text: str, horizon_days: int) -> List[tuple]: + """ + Extract daily forecast (date, cfs) rows from the AHPS hydrograph XML. + + AHPS XML wraps a `` block of `` elements, each with a + `` ISO timestamp and a `` numeric value (typically flow + in kcfs or stage in ft -- the gauge metadata tells you which). Lifts the + primary value, collapses sub-daily issuances to daily mean, caps at + horizon_days days from the issuance date. + """ + try: + from xml.etree import ElementTree as ET + except ImportError: + return [] + try: + root = ET.fromstring(text) + except ET.ParseError: + return [] + rows: list = [] + forecast = root.find('forecast') + if forecast is None: + return rows + for datum in forecast.findall('datum'): + valid = datum.findtext('valid') + primary = datum.findtext('primary') + if not valid or primary is None: + continue + try: + d = datetime.fromisoformat(valid.replace('Z', '+00:00')).date() + v = float(primary) + except (ValueError, TypeError): + continue + rows.append((d.strftime('%Y-%m-%d'), v)) + if not rows: + return rows + # Collapse multiple values per day to daily mean. + df = pd.DataFrame(rows, columns=['Date', 'cbrfc_flow']) + daily = df.groupby('Date', as_index=False)['cbrfc_flow'].mean() + daily = daily.sort_values('Date').head(horizon_days) + return list(daily.itertuples(index=False, name=None)) + + +def fetch(site_id: str, anchor_date, + horizon_days: int = DEFAULT_HORIZON_DAYS) -> pd.DataFrame: + """ + CBRFC forecast for `site_id` issued on or before `anchor_date`, for the + `horizon_days` days following the anchor. + + For anchor_date == today, this is equivalent to fetch_current. For any + historical anchor_date, the AHPS API does not expose the issuance; this + returns empty until the historical archive integration lands (see module + docstring). + """ + anchor = _to_date(anchor_date) + if anchor >= date.today(): + return fetch_current(site_id, horizon_days=horizon_days) + logger.debug("CBRFC historical forecast for %s @ %s requires archive integration", + site_id, anchor) + return _empty() + + +def baseline_predictions(test_samples) -> Optional[np.ndarray]: + """ + Per-sample CBRFC forecast aligned with `windowing.WindowedSample` test + items. Returns an (N, horizon, target_features) array on the SCALED flow + target space (so it lines up with model_pred), or None when fewer than + one usable forecast was found and the comparison would be vacuous. + + This is the integration point train.py calls; when filled in, it produces + a third row in the per-horizon MAE table alongside model and persistence. + The shape mirrors `_persistence_pred_for_samples` in train.py. + """ + if not test_samples: + return None + # Until the historical CBRFC archive is wired in there's nothing to + # backtest against; signal "no comparison" cleanly to the caller. + found = 0 + horizon = test_samples[0].target_Y.shape[0] + rows = [] + for sample in test_samples: + # Forecast was issued at the end of the encoder window -- that's + # `anchor_date + encoder_days` for a sample whose anchor_date is the + # first encoder day. WindowedSample doesn't store the forecast-issue + # date explicitly; reconstruct as anchor + (encoder_days - 1). + forecast_issue = (sample.anchor_date + + pd.Timedelta(days=sample.encoder_X.shape[0] - 1)) + forecast = fetch(sample.site_id, forecast_issue, horizon_days=horizon) + if forecast.empty: + rows.append(np.full((horizon, sample.target_Y.shape[1]), np.nan, + dtype='float32')) + continue + found += 1 + # The fetched cbrfc_flow is on the raw cfs scale; train.py knows how + # to invert the scalers if a non-scaled prediction set is provided. + # Broadcast the single CBRFC value across both Min Flow + Max Flow + # target columns until the upstream product distinguishes them. + values = forecast['cbrfc_flow'].astype('float32').to_numpy() + if len(values) < horizon: + padded = np.full(horizon, np.nan, dtype='float32') + padded[:len(values)] = values + values = padded + rows.append(np.tile(values[:horizon, None], (1, sample.target_Y.shape[1]))) + + if found == 0: + return None + return np.stack(rows).astype('float32') diff --git a/openFlowML/data/get_drought.py b/openFlowML/data/get_drought.py new file mode 100644 index 0000000..d0925ba --- /dev/null +++ b/openFlowML/data/get_drought.py @@ -0,0 +1,180 @@ +""" +US Drought Monitor (USDM) drought intensity timeseries by HUC8. + +Wraps the public USDM data service at usdmdataservices.unl.edu. USDM publishes +weekly snapshots of percent area in each drought category (None / D0 / D1 / +D2 / D3 / D4); we collapse those to a single intensity index per week and +forward-fill to daily. + +Public API: + main(lat, lon, start_date, end_date) -> DataFrame[Date, drought_index] + +drought_index = sum_c (D{c}_percent * weight_c) where weight is 1..5 for +D0..D4. Range 0 (no drought anywhere in HUC) to 500 (entire HUC in D4 +exceptional drought). 0 is the legitimate default for missing rows. +""" + +import argparse +import logging +from datetime import date, datetime +from typing import Optional + +import pandas as pd + +from data.utils import data_utils +from data import get_swe + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +USDM_URL = "https://usdmdataservices.unl.edu/api/HUCStatistics/GetWeeklyHUCStatistics" + +# Ordinal weights for the intensity index. None has weight 0 (it cancels). +_CATEGORY_WEIGHTS = {'D0': 1, 'D1': 2, 'D2': 3, 'D3': 4, 'D4': 5} + + +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'drought_index']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], '%Y-%m-%d').date() + + +def _format_mdy(d) -> str: + """USDM API expects M/D/YYYY (no zero-padding).""" + d = _to_date(d) + return f"{d.month}/{d.day}/{d.year}" + + +def get_drought_weekly(huc_id, start_date, end_date, huc_level=8): + """ + Fetch raw weekly USDM percent-area-by-category records for a HUC. + + Returns a list of dicts (one per ISO week) or [] on any failure. + """ + if not huc_id: + return [] + params = { + 'aoi': str(huc_id), + 'hucLevel': str(huc_level), + 'startdate': _format_mdy(start_date), + 'enddate': _format_mdy(end_date), + 'statisticsType': '2', # percent area by drought category + } + headers = {'Accept': 'application/json'} + response = data_utils.request_with_retry(USDM_URL, params=params, headers=headers) + if response is None: + return [] + try: + payload = response.json() + except ValueError: + return [] + if not isinstance(payload, list): + return [] + return payload + + +def _record_to_index(record) -> float: + """ + Collapse a USDM weekly record (percent area per category) into a single + weighted-intensity index. Missing category fields treated as 0. + """ + total = 0.0 + for cat, weight in _CATEGORY_WEIGHTS.items(): + try: + total += float(record.get(cat, 0) or 0) * weight + except (TypeError, ValueError): + continue + return total + + +def _parse_record_date(record) -> Optional[date]: + """ + USDM payloads expose the snapshot date under a few possible keys; tolerate + all of them. + """ + for key in ('MapDate', 'ValidStart', 'validStart', 'mapDate'): + raw = record.get(key) + if not raw: + continue + raw = str(raw)[:10] + for fmt in ('%Y-%m-%d', '%Y%m%d'): + try: + return datetime.strptime(raw, fmt).date() + except ValueError: + continue + return None + + +def get_drought(lat, lon, start_date, end_date, huc_level=8) -> pd.DataFrame: + """ + End-to-end: lat/lon -> HUC -> USDM weekly intensity -> daily DataFrame. + + Weekly records are forward-filled across the daily index (USDM is a weekly + snapshot; values stay constant for the week). Empty DataFrame on any + failure (HUC lookup, API, parsing) so the spine treats it as "no data". + """ + try: + huc_id = get_swe.get_huc_id(lat, lon, level=huc_level) + except Exception as e: + logger.warning("HUC%d lookup failed for (%s, %s): %s", huc_level, lat, lon, e) + return _empty() + if not huc_id: + logger.warning("Could not resolve HUC%d for (%s, %s)", huc_level, lat, lon) + return _empty() + + records = get_drought_weekly(huc_id, start_date, end_date, huc_level=huc_level) + if not records: + logger.warning("No USDM records for HUC %s in [%s, %s]", + huc_id, start_date, end_date) + return _empty() + + rows = [] + for record in records: + d = _parse_record_date(record) + if d is None: + continue + rows.append((d.strftime('%Y-%m-%d'), _record_to_index(record))) + + if not rows: + return _empty() + + weekly = (pd.DataFrame(rows, columns=['Date', 'drought_index']) + .drop_duplicates(subset=['Date']) + .sort_values('Date')) + weekly['Date'] = pd.to_datetime(weekly['Date']) + weekly = weekly.set_index('Date') + + # Forward-fill weekly snapshots over the daily index. + daily_index = pd.date_range( + pd.Timestamp(_to_date(start_date)), + pd.Timestamp(_to_date(end_date)), + freq='D', name='Date', + ) + daily = weekly.reindex(daily_index, method='ffill') + daily = daily.reset_index() + daily['Date'] = daily['Date'].dt.strftime('%Y-%m-%d') + return daily[['Date', 'drought_index']] + + +def main(lat, lon, start_date, end_date, huc_level=8) -> pd.DataFrame: + df = get_drought(lat, lon, start_date, end_date, huc_level=huc_level) + data_utils.preview_data(df) + return df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Fetch USDM drought intensity timeseries by HUC.') + parser.add_argument('--lat', type=float, required=True) + parser.add_argument('--lon', type=float, required=True) + parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--huc-level', type=int, default=8, choices=[2, 4, 6, 8, 10, 12]) + args = parser.parse_args() + main(args.lat, args.lon, args.start_date, args.end_date, huc_level=args.huc_level) diff --git a/openFlowML/data/get_reservoir.py b/openFlowML/data/get_reservoir.py new file mode 100644 index 0000000..7ec5357 --- /dev/null +++ b/openFlowML/data/get_reservoir.py @@ -0,0 +1,222 @@ +""" +USBR RISE reservoir storage + release timeseries by site. + +For regulated rivers, downstream flow is heavily driven by reservoir +operations (releases) and storage state (how full the basin is). This module +maps a site_id to one or more upstream reservoirs via .github/reservoir_mapping.txt +and pulls daily storage (acre-feet) + release (cfs) from the public USBR RISE +catalog API. + +Public API: + main(site_id, start_date, end_date) -> DataFrame[Date, reservoir_storage, + reservoir_release] + +Stations with no mapping entry (e.g. unregulated headwater gauges) get an +empty DataFrame back; combine_data treats that as "0 / not observed", which +is the correct semantic. +""" + +import argparse +import logging +import os +from datetime import date, datetime +from typing import List, Optional, Tuple + +import pandas as pd + +from data.utils import data_utils + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +RISE_RESULT_URL = "https://data.usbr.gov/rise/api/result" + +# Reservoir mapping config lives alongside site_ids.txt in .github/. +_DEFAULT_MAPPING = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + '.github', 'reservoir_mapping.txt', +) + + +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'reservoir_storage', 'reservoir_release']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], '%Y-%m-%d').date() + + +def load_mapping(path: Optional[str] = None) -> dict: + """ + Parse the reservoir-mapping config into + {site_id: [(reservoir_name, storage_itemId, release_itemId), ...]}. + + Stations without entries are simply absent from the returned dict, which + the caller treats as "no reservoir for this station". + """ + if path is None: + path = _DEFAULT_MAPPING + mapping: dict = {} + if not os.path.exists(path): + logger.info("No reservoir mapping file at %s -- all stations unregulated", path) + return mapping + try: + with open(path, 'r') as f: + for raw in f: + line = raw.strip() + if not line or line.startswith('#'): + continue + parts = [p.strip() for p in line.split(',')] + if len(parts) < 3: + logger.warning("Skipping malformed reservoir mapping line: %s", raw.rstrip()) + continue + site_id = parts[0] + reservoir = parts[1] + storage = parts[2] or None + release = parts[3] if len(parts) >= 4 and parts[3] else None + mapping.setdefault(site_id, []).append((reservoir, storage, release)) + except OSError as e: + logger.warning("Could not read reservoir mapping %s: %s", path, e) + return mapping + + +def fetch_rise_series(item_id, start_date, end_date) -> pd.DataFrame: + """ + Pull a single RISE catalog-item timeseries as a [Date, value] frame. + + RISE's `/result` endpoint paginates; the public API caps page size, so we + walk forward until the next-page URL is exhausted. Empty DataFrame on any + failure -- the spine treats it as missing. + """ + if not item_id: + return pd.DataFrame(columns=['Date', 'value']) + + start = _to_date(start_date) + end = _to_date(end_date) + params = { + 'itemId': str(item_id), + 'after': start.strftime('%Y-%m-%dT00:00:00Z'), + 'before': end.strftime('%Y-%m-%dT23:59:59Z'), + 'order': 'ASC', + 'itemsPerPage': '10000', + } + headers = {'Accept': 'application/vnd.api+json'} + + rows: List[Tuple[str, float]] = [] + url = RISE_RESULT_URL + next_params = params + safety = 200 # hard cap on pagination loops to avoid runaway requests + while url and safety > 0: + safety -= 1 + response = data_utils.request_with_retry(url, params=next_params, headers=headers) + if response is None: + break + try: + payload = response.json() + except ValueError: + break + for item in (payload.get('data') or []): + attrs = item.get('attributes') or {} + ts = attrs.get('dateTime') or attrs.get('resultDateTime') + val = attrs.get('result') + if ts is None or val is None: + continue + try: + rows.append((str(ts)[:10], float(val))) + except (TypeError, ValueError): + continue + # Next-page link: JSON:API uses links.next as an absolute URL with + # query string baked in; once we follow it we drop params. + links = payload.get('links') or {} + nxt = links.get('next') + if not nxt or nxt == url: + break + url = nxt + next_params = None + + if not rows: + return pd.DataFrame(columns=['Date', 'value']) + df = pd.DataFrame(rows, columns=['Date', 'value']) + # Collapse any duplicate dates within the series (hourly -> daily). + df = df.groupby('Date', as_index=False)['value'].mean() + return df.sort_values('Date').reset_index(drop=True) + + +def get_reservoir(site_id, start_date, end_date, + mapping_path: Optional[str] = None) -> pd.DataFrame: + """ + Resolve the reservoir(s) mapped to site_id, fetch storage + release from + RISE, and return a daily DataFrame[Date, reservoir_storage, reservoir_release] + over the requested window. + + If multiple reservoirs map to one site, storage and release are summed + (total water held back upstream, total outflow). Returns an empty frame + when there is no mapping entry, mirroring the SWE / SMAP graceful pattern. + """ + mapping = load_mapping(mapping_path) + entries = mapping.get(site_id) + if not entries: + return _empty() + + storage_frames = [] + release_frames = [] + for reservoir_name, storage_id, release_id in entries: + if storage_id: + ts = fetch_rise_series(storage_id, start_date, end_date) + if not ts.empty: + ts = ts.rename(columns={'value': 'reservoir_storage'}) + storage_frames.append(ts) + if release_id: + tr = fetch_rise_series(release_id, start_date, end_date) + if not tr.empty: + tr = tr.rename(columns={'value': 'reservoir_release'}) + release_frames.append(tr) + logger.info("Site %s reservoir %s: storage=%s release=%s", + site_id, reservoir_name, + 'yes' if storage_id else '-', + 'yes' if release_id else '-') + + def _sum_by_date(frames, col): + if not frames: + return pd.DataFrame(columns=['Date', col]) + combined = pd.concat(frames, ignore_index=True) + return combined.groupby('Date', as_index=False)[col].sum() + + storage_total = _sum_by_date(storage_frames, 'reservoir_storage') + release_total = _sum_by_date(release_frames, 'reservoir_release') + + if storage_total.empty and release_total.empty: + return _empty() + + merged = (storage_total + .merge(release_total, on='Date', how='outer') + .sort_values('Date') + .reset_index(drop=True)) + if 'reservoir_storage' not in merged.columns: + merged['reservoir_storage'] = float('nan') + if 'reservoir_release' not in merged.columns: + merged['reservoir_release'] = float('nan') + return merged[['Date', 'reservoir_storage', 'reservoir_release']] + + +def main(site_id, start_date, end_date, + mapping_path: Optional[str] = None) -> pd.DataFrame: + df = get_reservoir(site_id, start_date, end_date, mapping_path=mapping_path) + data_utils.preview_data(df) + return df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Fetch USBR RISE reservoir storage + release for a site.') + parser.add_argument('--site-id', type=str, required=True, help='e.g. USGS:09163500') + parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--mapping', type=str, default=None, + help='Override path to reservoir_mapping.txt') + args = parser.parse_args() + main(args.site_id, args.start_date, args.end_date, mapping_path=args.mapping) diff --git a/openFlowML/data/get_s2f.py b/openFlowML/data/get_s2f.py new file mode 100644 index 0000000..9b7921e --- /dev/null +++ b/openFlowML/data/get_s2f.py @@ -0,0 +1,97 @@ +""" +USBR Snow-to-Flow (S2F) seasonal forecast baseline. + +Reclamation publishes monthly seasonal water-supply forecasts (typically +April-July or April-September runoff *volume*) for major reservoirs in the +Upper Colorado and other basins. These are the operational statistical +baselines water managers actually use. + +Public API: + fetch(site_id, anchor_date) -> DataFrame[Date, s2f_volume_kaf] + baseline_predictions(test_samples) -> Optional[np.ndarray] + +IMPLEMENTATION STATUS + TIMESCALE CAVEAT: + S2F predicts a *seasonal volume* (e.g. April-July total runoff into + Reservoir X, in kAF). Our model predicts *daily* flow over a 14-day + horizon. Comparing them apples-to-apples is awkward: you'd have to either + + (a) aggregate our 14-day prediction into a fraction-of-seasonal-volume + contribution and compare to S2F's fractional progress estimate, or + (b) disaggregate S2F's seasonal volume into a daily climatological + shape and compare daily. + + Approach (b) is what this stub assumes when filled in -- given a seasonal + forecast V_kAF and a per-site climatology of the daily distribution + inside the forecast season, predict daily_flow[d] = V_kAF * daily_share[d]. + That requires a per-site daily-share lookup that doesn't exist in the + repo yet; building it is part of the follow-up. + + The data fetch itself: USBR publishes S2F products at + https://www.usbr.gov/uc/water/crsp/wsf/ (Upper Colorado WSF) + https://www.usbr.gov/pn/hydromet/forecast.html (Pacific Northwest) + as month-by-month CSV / PDFs that need scraping. There's no clean REST + archive. Until that's wired, fetch() returns empty and + baseline_predictions() returns None so train.py skips this baseline. +""" + +import logging +from datetime import date, datetime +from typing import Optional + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 's2f_volume_kaf']) + + +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], '%Y-%m-%d').date() + + +def fetch(site_id: str, anchor_date) -> pd.DataFrame: + """ + S2F seasonal-volume forecast for `site_id`, as of the month containing + `anchor_date`. Returns a single-row DataFrame [Date, s2f_volume_kaf] + when found, or empty otherwise. + + Currently stubbed: USBR S2F doesn't expose a clean REST archive, so this + always returns empty until the per-basin scraper is built. The function + signature is the integration point so combine_data / baselines can adopt + the real fetch without changing callers. + """ + logger.debug("S2F fetch stubbed for %s @ %s -- archive integration pending", + site_id, _to_date(anchor_date)) + return _empty() + + +def baseline_predictions(test_samples) -> Optional[np.ndarray]: + """ + Per-sample S2F-derived daily prediction aligned with WindowedSample test + items. Returns (N, horizon, target_features) on the raw flow scale when + feasible, or None when the S2F archive isn't wired in (current state). + + See the module docstring for the daily-disaggregation approach this + expects when fully implemented; until then it's a no-op returning None + so train.py reports "S2F: not available" and moves on. + """ + if not test_samples: + return None + found = 0 + for sample in test_samples: + if not fetch(sample.site_id, sample.anchor_date).empty: + found += 1 + if found == 0: + return None + # The disaggregation step (seasonal kAF -> daily cfs via per-site + # climatological share) belongs here once fetch() returns real data. + return None diff --git a/openFlowML/data/nasa_moisture.py b/openFlowML/data/nasa_moisture.py index 84eda14..c1137ae 100644 --- a/openFlowML/data/nasa_moisture.py +++ b/openFlowML/data/nasa_moisture.py @@ -1,330 +1,365 @@ -import datetime -import numpy as np -import h5py -import shutil -from scipy.spatial import cKDTree -from shapely.ops import transform -from shapely.geometry import Polygon -import logging +""" +SMAP L3 enhanced soil moisture (SPL3SMP_E v006) timeseries by HUC8. + +Given a station's lat/lon and a date window, this module looks up the enclosing +HUC8 polygon, searches NASA Earthdata for every SMAP granule whose footprint +intersects that polygon over the window, downloads them (with a granule cache +shared across calls so neighboring stations don't redownload the same global +daily file), filters each granule to recommended-quality pixels inside the +polygon's bbox, and returns the daily polygon mean. + +Public API: + main(lat, lon, start_date, end_date) -> DataFrame[Date, soil_moisture] + +A failure at any step (HUC8 lookup, Earthdata auth, search, download, +extraction) degrades to an empty DataFrame -- combine_data treats missing +soil moisture as "no SMAP today" (forward-filled, then site-median fallback, +plus an sm_observed indicator) rather than dropping the row. + +Performance: SPL3SMP_E granules are global daily files (~30-100 MB each), so +two stations in different HUC8s on the same day pull the same granule. The +module-level _GRANULE_PATH_CACHE deduplicates within a single process; set +OPENFLOW_SMAP_CACHE_DIR to a persistent path to keep granules across runs. +""" + import argparse -from earthaccess import * -from data.utils.get_poly import check_polygon_intersection, get_huc_polygon, validate_polygon, simplify_polygon -from data.utils.data_utils import load_vars, get_earthdata_auth, get_smap_data_bounds +import logging import os import tempfile -# Conditionally import matplotlib -import importlib.util -from shapely.ops import transform -import earthaccess -import matplotlib.pyplot as plt -from matplotlib.patches import Polygon as mplPolygon -matplotlib_spec = importlib.util.find_spec("matplotlib") -matplotlib_available = matplotlib_spec is not None +from datetime import date, datetime +from typing import Optional + +import pandas as pd + +logger = logging.getLogger(__name__) +if not logging.getLogger().hasHandlers(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# NSIDC SMAP L3 enhanced (9 km) passive radiometer soil moisture. +SMAP_SHORT_NAME = "SPL3SMP_E" +SMAP_VERSION = "006" +# HDF5 fill value for missing pixels in the SMAP product. +FILL_VALUE = -9999.0 +# Valid SMAP retrieval range for volumetric soil moisture (m^3/m^3). +VALID_MIN = 0.0 +VALID_MAX = 1.0 +# retrieval_qual_flag bit 0: 0 = retrieval is recommended quality, 1 = not. +# Conservative filter: drop any pixel where bit 0 is set. +QUAL_RECOMMENDED_BIT = 0 + +# Process-level cache so the per-station SMAP loop doesn't re-download the +# same global daily granule for every basin. Keyed by granule filename. +_GRANULE_PATH_CACHE: dict = {} + -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +def _empty() -> pd.DataFrame: + return pd.DataFrame(columns=['Date', 'soil_moisture']) -load_vars() -def search_and_download_smap_data(start_date, end_date, auth, simplified_polygon): +def _to_date(d) -> date: + if isinstance(d, datetime): + return d.date() + if isinstance(d, date): + return d + return datetime.strptime(str(d)[:10], "%Y-%m-%d").date() + + +def _login_earthdata(): """ - Search for SMAP L3 data between the given dates that intersect with the given polygon, - and download the smallest intersecting granule. + Authenticate with NASA Earthdata via earthaccess. Returns the Auth on + success, None on any failure -- the caller short-circuits to an empty + series so a missing credential doesn't bring down combine_data. """ try: - # Ensure we're authenticated - if not auth.authenticated: - logging.info("Not logged in, attempting to log in...") - if not auth.login(strategy="environment"): - raise RuntimeError("Failed to authenticate with NASA Earthdata Login") - else: - logging.info("Already authenticated, proceeding with search and download") - - # Search for SPL3SMP_E collection - collection_query = earthaccess.DataCollections().short_name("SPL3SMP_E").version("006") - collections = collection_query.get() - - if not collections: - logging.error("SMAP L3 SM_P_E collection not found") - return None - - collection = collections[0] - concept_id = collection.concept_id() - logging.info(f"Found SMAP_L3_SM_P_E collection with concept_id: {concept_id}") - - - # Calculate bounding box from simplified_polygon - lons, lats = zip(*simplified_polygon) - min_lon, max_lon = min(lons), max(lons) - min_lat, max_lat = min(lats), max(lats) - - # Now search for granules using DataGranules - granule_query = (earthaccess.DataGranules() - .concept_id(concept_id) - .temporal(start_date, end_date) - .bounding_box(min_lon, min_lat, max_lon, max_lat)) - - granule_hits = granule_query.hits() - logging.info(f"Number of granules found: {granule_hits}") - - if granule_hits == 0: - logging.warning(f"No SMAP data found from {start_date} to {end_date}") - return None - - granules = granule_query.get_all() - - if not granules: - logging.warning(f"No granules retrieved despite positive hit count") - return None - - # Log the number of granules found - logging.info(f"Retrieved {len(granules)} granules") - - # Find the smallest granule - smallest_granule = min(granules, key=lambda g: g.size()) - logging.info(f"Smallest intersecting granule size: {smallest_granule.size()} MB") - - # Create a temporary directory that won't be automatically deleted - temp_dir = tempfile.mkdtemp() - logging.info(f"Created temporary directory: {temp_dir}") - - # Download only the smallest granule - try: - downloaded_files = earthaccess.download(smallest_granule, local_path=temp_dir) - except Exception as e: - logging.error(f"Error during download: {str(e)}") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - - if downloaded_files: - downloaded_file = downloaded_files[0] - logging.info(f"Successfully downloaded: {downloaded_file}") - - # Verify that the file exists - if os.path.exists(downloaded_file): - logging.info(f"File exists at {downloaded_file}") - file_size = os.path.getsize(downloaded_file) - logging.info(f"File size: {file_size} bytes") - else: - logging.error(f"File does not exist at {downloaded_file}") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - - return downloaded_file, temp_dir - else: - logging.error("Failed to download SMAP data") - if temp_dir: - shutil.rmtree(temp_dir) - return None, None - + import earthaccess + except ImportError as e: + logger.warning("earthaccess not installed: %s", e) + return None + try: + auth = earthaccess.login(strategy="environment") + if auth is not None and getattr(auth, 'authenticated', False): + return auth + logger.warning("Earthdata authentication did not succeed") + return None except Exception as e: - logging.error(f"Error searching or downloading SMAP data: {e}") - logging.error("Traceback: ", exc_info=True) - if temp_dir: - shutil.rmtree(temp_dir) - return None, None + logger.warning("Earthdata authentication error: %s", e) + return None + -def extract_soil_moisture(hdf_file, polygon, max_distance=0.1): +def _get_huc8_polygon(lat: float, lon: float): + """ + HUC8 polygon (list of (lon, lat) tuples) enclosing the point, simplified. + Returns None when the WBD lookup fails or the simplification dependencies + aren't installed. + """ try: - with h5py.File(hdf_file, 'r') as file: - for dataset_name in file: - if 'Soil_Moisture_Retrieval_Data' in dataset_name: - soil_moisture = file[f'{dataset_name}/soil_moisture'][:] - lat = file[f'{dataset_name}/latitude'][:] - lon = file[f'{dataset_name}/longitude'][:] - logging.info(f"Found soil moisture data in {dataset_name}") - break - else: - logging.error("Could not find soil moisture data in the file") - return None - - polygon_obj = Polygon(polygon) - minx, miny, maxx, maxy = polygon_obj.bounds - - # Start with the polygon bounds and gradually expand - for distance in np.linspace(0, max_distance, 5): - expanded_bounds = (minx-distance, miny-distance, maxx+distance, maxy+distance) - mask = (lon >= expanded_bounds[0]) & (lon <= expanded_bounds[2]) & \ - (lat >= expanded_bounds[1]) & (lat <= expanded_bounds[3]) - - lons = lon[mask] - lats = lat[mask] - soil_moisture_masked = soil_moisture[mask] - - valid = (soil_moisture_masked != -9999.0) & (lons != -9999.0) & (lats != -9999.0) - lons = lons[valid] - lats = lats[valid] - soil_moisture_valid = soil_moisture_masked[valid] - - logging.info(f"Found {len(lons)} valid points within expanded bounds (distance: {distance})") - - if len(lons) > 0: - tree = cKDTree(np.column_stack((lons, lats))) - expanded_polygon = polygon_obj.buffer(distance) - mask_polygon = tree.query_ball_point(expanded_polygon.exterior.coords, r=0.01) - mask_polygon = np.unique(np.concatenate(mask_polygon)) - - soil_moisture_in_polygon = soil_moisture_valid[mask_polygon] - - if len(soil_moisture_in_polygon) > 0: - average_moisture = np.mean(soil_moisture_in_polygon) - logging.info(f"Found {len(soil_moisture_in_polygon)} points inside or near the polygon") - logging.info(f"Average soil moisture: {average_moisture:.4f}") - return average_moisture, expanded_polygon - - logging.warning("No valid soil moisture data found within or near the polygon") - return None, None + from data.utils.get_poly import get_huc_polygon, simplify_polygon + except ImportError as e: + logger.warning("Polygon utilities unavailable: %s", e) + return None + result = get_huc_polygon(lat, lon, huc_level=8) + if not result: + return None + polygon, _huc_id, _attributes = result + if not polygon: + return None + return simplify_polygon(polygon) + +def _search_granules(polygon, start_date, end_date): + """SMAP granules intersecting the polygon's bbox over the date window.""" + try: + import earthaccess + except ImportError as e: + logger.warning("earthaccess not installed: %s", e) + return [] + lons, lats = zip(*polygon) + bbox = (min(lons), min(lats), max(lons), max(lats)) + try: + granules = earthaccess.search_data( + short_name=SMAP_SHORT_NAME, + version=SMAP_VERSION, + temporal=(_to_date(start_date), _to_date(end_date)), + bounding_box=bbox, + ) except Exception as e: - logging.error(f"Error extracting soil moisture data: {e}") - logging.error("Traceback: ", exc_info=True) - return None, None + logger.warning("SMAP granule search failed: %s", e) + return [] + return list(granules) if granules else [] -# Add this function to check data availability in the polygon area -def check_data_availability(hdf_file, polygon): + +def _granule_date(granule) -> Optional[date]: + """ + Observation date for a granule. SMAP L3 filenames embed YYYYMMDD + (e.g. SMAP_L3_SM_P_E_20240115_R18290_001.h5), so we lift it from there + and fall back to the granule's temporal metadata if the filename pattern + doesn't match. + """ try: - with h5py.File(hdf_file, 'r') as file: - for time_of_day in ['AM', 'PM']: - try: - soil_moisture = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/soil_moisture'][:] - lat = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/latitude'][:] - lon = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/longitude'][:] - break - except KeyError: - continue - else: - logging.error("Could not find soil moisture data") - return - - # Get polygon bounds - poly = Polygon(polygon) - min_lon, min_lat, max_lon, max_lat = poly.bounds - - # Create masks for the area of interest and its surroundings - area_mask = (lat >= min_lat) & (lat <= max_lat) & (lon >= min_lon) & (lon <= max_lon) - surrounding_mask = (lat >= min_lat-1) & (lat <= max_lat+1) & (lon >= min_lon-1) & (lon <= max_lon+1) - - # Check for valid data - valid_data_mask = (soil_moisture != -9999.0) - - # Calculate statistics - total_points = np.sum(area_mask) - valid_points = np.sum(valid_data_mask & area_mask) - surrounding_valid_points = np.sum(valid_data_mask & surrounding_mask) - - logging.info(f"Total points in area of interest: {total_points}") - logging.info(f"Valid data points in area of interest: {valid_points}") - logging.info(f"Percentage of valid data in area: {valid_points/total_points*100:.2f}%") - logging.info(f"Valid data points in surrounding area: {surrounding_valid_points}") - - if valid_points == 0: - logging.warning("No valid data points found within the polygon.") - if surrounding_valid_points > 0: - logging.info("However, valid data points found in the surrounding area.") - else: - logging.warning("No valid data points found in the surrounding area either.") + links = granule.data_links() if hasattr(granule, 'data_links') else [] + for link in links or []: + name = link.rsplit('/', 1)[-1] + for token in name.split('_'): + if len(token) == 8 and token.isdigit(): + try: + return datetime.strptime(token, "%Y%m%d").date() + except ValueError: + continue + except Exception: + pass + try: + umm = granule.get("umm", {}) if hasattr(granule, 'get') else {} + beg = (umm.get("TemporalExtent", {}) + .get("RangeDateTime", {}) + .get("BeginningDateTime")) + if beg: + return datetime.strptime(beg[:10], "%Y-%m-%d").date() + except Exception: + pass + return None + + +def _extract_polygon_mean(hdf_path: str, polygon, + use_quality_flag: bool = True) -> Optional[float]: + """ + Average volumetric soil moisture across all valid SMAP pixels inside the + polygon's bounding box, across both AM and PM passes. Returns None when + no valid pixel intersects (cloudy / RFI day, frozen ground, or polygon + entirely off the EASE-grid for that pass). + + Pixels are filtered against the explicit fill value, the [0, 1] valid + range, and -- when the dataset is present -- the SMAP retrieval_qual_flag + bit 0 ("recommended quality"). Without the quality filter, winter Colorado + retrievals include frozen-ground pixels that look numerically plausible + but the SMAP team flags as not recommended. + + Bbox is sufficient at SMAP's 9 km grid spacing -- the HUC8 footprint is + rarely much larger than a handful of cells and the bbox vs true polygon + distinction is below the retrieval noise floor. + """ + try: + import h5py + import numpy as np + except ImportError as e: + logger.warning("h5py/numpy unavailable: %s", e) + return None + + lons, lats = zip(*polygon) + min_lon, max_lon = min(lons), max(lons) + min_lat, max_lat = min(lats), max(lats) + # SMAP L3 v006 splits AM and PM passes into separate top-level groups. + pass_groups = [ + 'Soil_Moisture_Retrieval_Data_AM', + 'Soil_Moisture_Retrieval_Data_PM', + 'Soil_Moisture_Retrieval_Data', + ] + + collected = [] + try: + with h5py.File(hdf_path, 'r') as f: + for grp_name in pass_groups: + if grp_name not in f: + continue + grp = f[grp_name] + if not all(k in grp for k in ('soil_moisture', 'latitude', 'longitude')): + continue + sm = grp['soil_moisture'][:] + lat = grp['latitude'][:] + lon = grp['longitude'][:] + in_bbox = ((lon >= min_lon) & (lon <= max_lon) & + (lat >= min_lat) & (lat <= max_lat)) + valid = ((sm != FILL_VALUE) & (sm >= VALID_MIN) & (sm <= VALID_MAX) & + (lon != FILL_VALUE) & (lat != FILL_VALUE)) + # Quality flag: drop pixels where the recommended-quality bit + # is set. Older granules occasionally omit the dataset; in that + # case we proceed without the filter rather than dropping all. + if use_quality_flag and 'retrieval_qual_flag' in grp: + qflag = grp['retrieval_qual_flag'][:].astype('int32') + recommended = ((qflag >> QUAL_RECOMMENDED_BIT) & 1) == 0 + valid = valid & recommended + mask = in_bbox & valid + if mask.any(): + collected.extend(sm[mask].astype(float).tolist()) except Exception as e: - logging.error(f"Error checking data availability: {e}") + logger.warning("Failed to read %s: %s", hdf_path, e) + return None + + if not collected: + return None + return float(sum(collected) / len(collected)) + + +def _granule_cache_key(granule) -> Optional[str]: + """ + Stable id for a granule, used as the cache key. Falls back through + `.data_links()` (the most reliable per-granule identifier) before giving + up; granules without a usable id are simply not cached. + """ + try: + links = granule.data_links() if hasattr(granule, 'data_links') else [] + for link in links or []: + name = link.rsplit('/', 1)[-1] + if name: + return name + except Exception: + pass + return None + + +def _get_cache_dir() -> str: + """ + Persistent SMAP granule cache directory. Defaults to a subdirectory of the + system tempdir; override with OPENFLOW_SMAP_CACHE_DIR to share the cache + across runs (e.g. a CI cache action). + """ + base = (os.environ.get('OPENFLOW_SMAP_CACHE_DIR') or + os.path.join(tempfile.gettempdir(), 'openflow_smap_cache')) + os.makedirs(base, exist_ok=True) + return base -def list_nsidc_collections(): + +def _download_granule(granule, cache_dir): + """ + Download `granule` into `cache_dir`, or return its cached path. Returns + None on failure (and does NOT delete the cached file -- the cache is + intentionally process-lifetime+). + """ + key = _granule_cache_key(granule) + if key: + cached = _GRANULE_PATH_CACHE.get(key) + if cached and os.path.exists(cached): + logger.debug("Granule cache hit: %s", key) + return cached + # Also check disk in case a previous process populated the dir. + on_disk = os.path.join(cache_dir, key) + if os.path.exists(on_disk): + _GRANULE_PATH_CACHE[key] = on_disk + return on_disk + try: + import earthaccess + except ImportError: + return None try: - nsidc_query = earthaccess.collection_query().daac("NSIDC") - collections = nsidc_query.get() - - logging.info(f"Found {len(collections)} collections from NSIDC-DAAC:") - for collection in collections: - logging.info(f"- Short Name: {collection['umm']['ShortName']}, Version: {collection['umm'].get('Version', 'N/A')}") - - return collections + files = earthaccess.download([granule], local_path=cache_dir) except Exception as e: - logging.error(f"Error listing NSIDC collections: {e}") + logger.warning("Granule download failed: %s", e) + return None + if not files: return None + path = files[0] + if key: + _GRANULE_PATH_CACHE[key] = path + return path -def visualize_smap_and_polygon(hdf_file, polygon): + +def main(lat: float, lon: float, start_date, end_date) -> pd.DataFrame: + """ + Fetch the SMAP L3 enhanced soil-moisture timeseries for the HUC8 enclosing + (lat, lon) over [start_date, end_date]. + + Returns a DataFrame with columns ['Date', 'soil_moisture'] where + soil_moisture is volumetric m^3/m^3 in [0, 1]. Multiple AM/PM passes on + the same day are averaged. Empty DataFrame on any failure -- the + combine_data spine treats missing soil moisture the same way it treats + missing SWE (defaults to 0, never drops the row). + """ try: - with h5py.File(hdf_file, 'r') as file: - for dataset_name in file: - if 'Soil_Moisture_Retrieval_Data' in dataset_name: - soil_moisture = file[f'{dataset_name}/soil_moisture'][:] - lat = file[f'{dataset_name}/latitude'][:] - lon = file[f'{dataset_name}/longitude'][:] - break - else: - logging.error("Could not find soil moisture data in the file") - return - - # Create a mask for valid data - valid_mask = (soil_moisture != -9999.0) & (lat != -9999.0) & (lon != -9999.0) - - fig, ax = plt.subplots(figsize=(12, 8)) - - # Plot SMAP data points - sc = ax.scatter(lon[valid_mask], lat[valid_mask], c=soil_moisture[valid_mask], - cmap='viridis', s=1, alpha=0.5) - plt.colorbar(sc, label='Soil Moisture') - - # Plot the polygon - poly = mplPolygon(polygon, facecolor='none', edgecolor='red', linewidth=2) - ax.add_patch(poly) - - # Set the extent to focus on the area around the polygon - poly_bounds = Polygon(polygon).bounds - ax.set_xlim(poly_bounds[0] - 1, poly_bounds[2] + 1) - ax.set_ylim(poly_bounds[1] - 1, poly_bounds[3] + 1) - - plt.title('SMAP Data and Polygon') - plt.xlabel('Longitude') - plt.ylabel('Latitude') - plt.show() - + polygon = _get_huc8_polygon(lat, lon) except Exception as e: - logging.error(f"Error visualizing SMAP data and polygon: {e}") - -def main(start_date, end_date, lat, lon, visual): - auth = get_earthdata_auth() - - huc8_polygon = get_huc_polygon(lat, lon, huc_level=8) - if not huc8_polygon: - logging.error("Failed to retrieve HUC8 polygon") - return - - simplified_polygon = simplify_polygon(huc8_polygon) - logging.info(f"Simplified polygon coordinates: {simplified_polygon}") - - downloaded_file, temp_dir = search_and_download_smap_data(start_date, end_date, auth, simplified_polygon) - - if downloaded_file and temp_dir: - try: - # Visualize SMAP data and polygon - visualize_smap_and_polygon(downloaded_file, simplified_polygon) - - # Extract soil moisture data - average_soil_moisture, used_polygon = extract_soil_moisture(downloaded_file, simplified_polygon) - - if average_soil_moisture is not None: - logging.info(f"Average soil moisture: {average_soil_moisture:.4f}") - if visual and matplotlib_available: - pass - #visualize_soil_moisture_simple(used_polygon, average_soil_moisture) - else: - logging.error("Failed to calculate soil moisture.") - - finally: - # Clean up: remove the temporary directory - shutil.rmtree(temp_dir) - logging.info(f"Removed temporary directory: {temp_dir}") - else: - logging.error("Failed to find or download SMAP data") + logger.warning("HUC8 polygon lookup failed: %s", e) + return _empty() + if not polygon: + logger.warning("No HUC8 polygon found for (%s, %s)", lat, lon) + return _empty() + + if _login_earthdata() is None: + return _empty() + + granules = _search_granules(polygon, start_date, end_date) + if not granules: + logger.warning("No SMAP granules found in [%s, %s] for the polygon", + start_date, end_date) + return _empty() + logger.info("Found %d SMAP granules", len(granules)) + + cache_dir = _get_cache_dir() + rows = [] + for granule in granules: + obs_date = _granule_date(granule) + if obs_date is None: + continue + path = _download_granule(granule, cache_dir) + if not path: + continue + # Cached files are intentionally NOT removed -- the next station's + # call to main() in this process will hit the cache for the same + # global daily granule. + value = _extract_polygon_mean(path, polygon) + if value is None: + continue + rows.append((obs_date.strftime('%Y-%m-%d'), value)) + + if not rows: + return _empty() + + df = pd.DataFrame(rows, columns=['Date', 'soil_moisture']) + df['soil_moisture'] = pd.to_numeric(df['soil_moisture'], errors='coerce') + # AM + PM passes on the same date collapse to a single daily mean. + daily = df.groupby('Date', as_index=False)['soil_moisture'].mean() + return daily.sort_values('Date').reset_index(drop=True) + + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Calculate average soil moisture for a HUC8 polygon from SMAP L3 data.') - parser.add_argument('--start-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='Start Date in YYYY-MM-DD format') - parser.add_argument('--end-date', type=lambda d: datetime.datetime.strptime(d, '%Y-%m-%d').date(), required=True, help='End Date in YYYY-MM-DD format') - parser.add_argument('--lat', type=float, required=True, help='Latitude of the point within the desired HUC8 polygon') - parser.add_argument('--lon', type=float, required=True, help='Longitude of the point within the desired HUC8 polygon') - parser.add_argument('--visual', action='store_true', help='Enable matplotlib visualization') + parser = argparse.ArgumentParser( + description='Fetch SMAP L3 enhanced soil moisture timeseries by HUC8.') + parser.add_argument('--lat', type=float, required=True) + parser.add_argument('--lon', type=float, required=True) + parser.add_argument('--start-date', type=str, required=True, help='YYYY-MM-DD') + parser.add_argument('--end-date', type=str, required=True, help='YYYY-MM-DD') args = parser.parse_args() - - main(args.start_date, args.end_date, args.lat, args.lon, args.visual) \ No newline at end of file + df = main(args.lat, args.lon, args.start_date, args.end_date) + if df.empty: + print("No data") + else: + print(df.to_string(index=False)) diff --git a/openFlowML/data/soilmoisture.py b/openFlowML/data/soilmoisture.py deleted file mode 100644 index 84ccf8b..0000000 --- a/openFlowML/data/soilmoisture.py +++ /dev/null @@ -1,26 +0,0 @@ -import requests -import json - -def fetch_all_table_names(): - base_url = "https://SDMDataAccess.sc.egov.usda.gov" - headers = {"Content-Type": "application/json"} - - query_url = f"{base_url}/Tabular/post.rest" - # Query to fetch all table names from the information schema - sql_query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES" - query_data = { - "SERVICE": "query", - "REQUEST": "query", - "QUERY": sql_query, - "FORMAT": "JSON" - } - response = requests.post(query_url, headers=headers, data=json.dumps(query_data)) - - if response.status_code == 200: - return response.json() - else: - return f"Failed to fetch table names: {response.text}" - -# Execute the function to get all table names -result = fetch_all_table_names() -print(result) diff --git a/openFlowML/data/soilmoisture2.py b/openFlowML/data/soilmoisture2.py deleted file mode 100644 index 0dbd0f7..0000000 --- a/openFlowML/data/soilmoisture2.py +++ /dev/null @@ -1,98 +0,0 @@ -import asyncio -import xarray as xr -import pandas as pd -from shapely.geometry import Polygon -from pyproj import CRS -import fsspec -import dask -from dask.distributed import Client -import argparse -import logging -from data.utils.get_poly import get_huc8_polygon, simplify_polygon - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - -async def fetch_soil_moisture_data(polygon: Polygon, date_range: pd.date_range) -> xr.Dataset: - """ - Asynchronously fetch CPC soil moisture data for a given polygon and date range. - - Args: - polygon (Polygon): Shapely Polygon object defining the area of interest - date_range (pd.date_range): Date range for data retrieval - - Returns: - xr.Dataset: Dataset containing soil moisture data for the specified region and time - """ - # CPC soil moisture data is available via Google Cloud Storage - gcs_url = "gs://noaa-cpc-pds/soil-moisture/" - - async with fsspec.open_files(gcs_url + "*.nc", mode='rb') as files: - datasets = [xr.open_dataset(file, engine='h5netcdf', chunks={'time': 'auto'}) - for file in files if file.start.date() in date_range] - - combined_data = xr.concat(datasets, dim='time') - - # Ensure CRS matches the polygon CRS (assuming WGS84) - data_crs = CRS.from_epsg(4326) - combined_data = combined_data.rio.write_crs(data_crs) - - # Clip data to polygon - clipped_data = combined_data.rio.clip([polygon], all_touched=True) - - return clipped_data - -async def process_soil_moisture_data(data: xr.Dataset) -> pd.DataFrame: - """ - Process the retrieved soil moisture data. - - Args: - data (xr.Dataset): Dataset containing soil moisture data - - Returns: - pd.DataFrame: Processed soil moisture data - """ - # Example processing - adjust as needed - mean_moisture = data['soilw'].mean(dim=['lat', 'lon']) - return mean_moisture.to_dataframe() - -async def main(lat: float, lon: float, start_date: str, end_date: str): - # Set up dask client for parallel processing - client = Client(n_workers=4, threads_per_worker=2, memory_limit='4GB') - - # Get HUC8 polygon - huc8_polygon = get_huc8_polygon(lat, lon) - if not huc8_polygon: - logging.error("No HUC8 polygon found") - await client.close() - return - - # Simplify the polygon - simplified_polygon = simplify_polygon(huc8_polygon) - logging.info(f"Simplified polygon: {simplified_polygon}") - - # Conve - # rt to Shapely Polygon - polygon = Polygon(simplified_polygon) - - date_range = pd.date_range(start=start_date, end=end_date, freq='D') - - data = await fetch_soil_moisture_data(polygon, date_range) - - # Process data using dask for parallelization - processed_data = await dask.compute(process_soil_moisture_data(data))[0] - - print(processed_data) - - # Clean up dask client - await client.close() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Fetch soil moisture data for HUC8 polygon based on lat/lon.') - parser.add_argument('--lat', type=float, required=True, help='Latitude') - parser.add_argument('--lon', type=float, required=True, help='Longitude') - parser.add_argument('--start_date', type=str, required=True, help='Start date (YYYY-MM-DD)') - parser.add_argument('--end_date', type=str, required=True, help='End date (YYYY-MM-DD)') - args = parser.parse_args() - - asyncio.run(main(args.lat, args.lon, args.start_date, args.end_date)) \ No newline at end of file diff --git a/openFlowML/data/utils/data_utils.py b/openFlowML/data/utils/data_utils.py index 0924662..fc2aee3 100644 --- a/openFlowML/data/utils/data_utils.py +++ b/openFlowML/data/utils/data_utils.py @@ -2,18 +2,14 @@ import os import time import requests -from datetime import datetime, timedelta, timezone +from datetime import timedelta -# NOTE: h5py, python-dotenv and earthaccess are heavy, optional dependencies -# used only by the soil-moisture data modules. They are imported lazily inside -# the functions that need them so that lightweight consumers (e.g. combine_data -# importing preview_data) work with only the core requirements installed. +# NOTE: python-dotenv is a heavy optional dependency used only by load_vars +# (credential loading for the soil-moisture path). It is imported lazily so +# lightweight consumers like combine_data work with only the core requirements. -# Global variables to store token and expiration -_token = None -_expiration = None - # Additional function to display the beginning and ending of the dataframe +# Additional function to display the beginning and ending of the dataframe def preview_data(df, num_rows=4): logging.info("First few rows:") logging.info(df.head(num_rows)) @@ -59,117 +55,6 @@ def date_chunks(start_date, end_date, max_days=366): yield current, chunk_end current = chunk_end + timedelta(days=1) -def get_earthdata_auth(): - """ - Create and return an authenticated earthaccess Auth instance. - """ - from earthaccess import Auth - auth = Auth() - - username = os.getenv("EARTHDATA_USERNAME") - password = os.getenv("EARTHDATA_PASSWORD") - - logging.info(f"Attempting to authenticate with username: {username}") - - if username and password: - if auth.login(strategy="environment"): - logging.info("Successfully authenticated using environment variables") - return auth - else: - logging.warning("Authentication failed using environment variables") - else: - logging.warning("Environment variables not set or empty") - - raise RuntimeError("Failed to authenticate with NASA Earthdata Login") - - -def appeears_login(): - """ - Log in to AppEEARS and obtain a token. - """ - global _token, _expiration - url = "https://appeears.earthdatacloud.nasa.gov/api/login" - username = os.getenv("EARTHDATA_USERNAME") - password = os.getenv("EARTHDATA_PASSWORD") - - if not username or not password: - raise ValueError("EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables must be set.") - - try: - response = requests.post(url, auth=(username, password)) - response.raise_for_status() - data = response.json() - _token = data['token'] - # Parse the expiration date string manually - expiration_str = data['expiration'] - # Remove the 'Z' at the end and split the string - date_part, time_part = expiration_str[:-1].split('T') - year, month, day = map(int, date_part.split('-')) - hour, minute, second = map(int, time_part.split(':')) - _expiration = datetime(year, month, day, hour, minute, second, tzinfo=timezone.utc) - return _token - except requests.RequestException as e: - print(f"Login failed: {e}") - return None - -def appeears_logout(): - """ - Log out from AppEEARS and invalidate the current token. - """ - global _token, _expiration - if not _token: - print("No active session to log out from.") - return True - - url = "https://appeears.earthdatacloud.nasa.gov/api/logout" - headers = {'Authorization': f'Bearer {_token}'} - try: - response = requests.post(url, headers=headers) - if response.status_code == 204: - _token = None - _expiration = None - return True - else: - print(f"Logout failed: {response.text}") - return False - except requests.RequestException as e: - print(f"Logout failed: {e}") - return False - -def get_smap_data_bounds(hdf_file): - """ - Get the actual bounding box of the SMAP data from the HDF file, excluding fill values. - """ - import h5py - try: - with h5py.File(hdf_file, 'r') as file: - for time_of_day in ['AM', 'PM']: - try: - lat_dataset = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/latitude'] - lon_dataset = file[f'Soil_Moisture_Retrieval_Data_{time_of_day}/longitude'] - - # Filter out fill values (assuming -9999.0 is the fill value) - valid_lat = lat_dataset[lat_dataset[:] != -9999.0] - valid_lon = lon_dataset[lon_dataset[:] != -9999.0] - - if len(valid_lat) > 0 and len(valid_lon) > 0: - min_lat = valid_lat.min() - max_lat = valid_lat.max() - min_lon = valid_lon.min() - max_lon = valid_lon.max() - - logging.info(f"Actual SMAP data bounds: Lon ({min_lon}, {max_lon}), Lat ({min_lat}, {max_lat})") - return (min_lon, min_lat, max_lon, max_lat) - else: - logging.warning(f"No valid data found for {time_of_day}") - except KeyError: - continue - logging.error("Could not find valid latitude or longitude data in the file") - return None - except Exception as e: - logging.error(f"Error getting SMAP data bounds: {e}") - return None - def load_vars(): """ @@ -180,14 +65,18 @@ def load_vars(): kill the interpreter just because credentials are absent (e.g. in CI or when running the test suite). """ - from dotenv import load_dotenv + try: + from dotenv import load_dotenv + except ImportError: + logging.warning("python-dotenv not installed; skipping creds.env load") + load_dotenv = None script_dir = os.path.dirname(os.path.abspath(__file__)) cred_env_path = os.path.join(script_dir, 'creds.env') - if os.path.exists(cred_env_path): + if load_dotenv is not None and os.path.exists(cred_env_path): load_dotenv(cred_env_path) logging.info(f"Loaded environment variables from {cred_env_path}") - else: + elif load_dotenv is not None: logging.warning(f"creds.env file not found at {cred_env_path}") required = ["EARTHDATA_USERNAME", "EARTHDATA_PASSWORD", diff --git a/openFlowML/normalize_data.py b/openFlowML/normalize_data.py index a8fdc92..c5c463e 100644 --- a/openFlowML/normalize_data.py +++ b/openFlowML/normalize_data.py @@ -27,9 +27,17 @@ # Required: rows missing any of these are dropped (no pooled-mean fill). CORE_REQUIRED = ['TMIN', 'TMAX', 'Min Flow', 'Max Flow'] -# Optional columns that get scaled when present. SWE is slow-varying; if it's -# missing for a row, default it to 0 rather than dropping the row. -OPTIONAL_NUMERIC = ['SWE'] +# Optional continuous columns that get scaled when present. SWE, soil_moisture, +# drought_index, and the two reservoir columns are all slow-varying auxiliaries +# that combine_data imputes when missing rather than dropping the row. +OPTIONAL_NUMERIC = ['SWE', 'soil_moisture', 'drought_index', + 'reservoir_storage', 'reservoir_release'] +# Binary indicator columns: included as model inputs but NOT z-scored. Scaling +# a 0/1 indicator destroys the semantics (the model needs to see 0 vs 1 as +# distinct cases, not as samples from a centered normal). +# sm_observed -- 1 = real / short-gap-interpolated SMAP retrieval +# reservoir_observed -- 1 = station has a USBR mapping AND fetch succeeded +INDICATOR_COLUMNS = ['sm_observed', 'reservoir_observed'] NUMERIC_COLUMNS = CORE_REQUIRED + OPTIONAL_NUMERIC # Streamflow is log-normal -- log1p before z-scoring is standard hydrology # practice. Temperature/SWE stay linear. @@ -140,16 +148,25 @@ def normalize_data(data, artifacts_dir=None): if missing_core: raise ValueError(f"Missing required columns in the data: {missing_core}") - # Coerce every numeric column we know about, including SWE if present. + # Coerce every numeric column we know about, including SWE if present, + # plus the binary indicator columns (still numeric, just not scaled). present_numeric = [c for c in NUMERIC_COLUMNS if c in data.columns] - for column in present_numeric: + present_indicator = [c for c in INDICATOR_COLUMNS if c in data.columns] + for column in present_numeric + present_indicator: data[column] = pd.to_numeric(data[column], errors='coerce') - # SWE is slowly varying and often legitimately zero -- any remaining - # missing value here defaults to 0 ("no snow data") instead of forcing - # the row out, so a station with no nearby SNOTEL still contributes. - if 'SWE' in data.columns: - data['SWE'] = data['SWE'].fillna(0.0) + # All optional numerics are slow-varying auxiliaries; combine_data + # imputes them with smarter per-source logic. The safety net here + # defaults any remaining missing value to 0 instead of dropping the + # row, so a station with no nearby SNOTEL / no SMAP retrieval / + # outside USDM coverage / no upstream reservoir still contributes. + for column in ('SWE', 'soil_moisture', 'drought_index', + 'reservoir_storage', 'reservoir_release'): + if column in data.columns: + data[column] = data[column].fillna(0.0) + # Indicators default to 0 ("not observed") when missing. + for column in present_indicator: + data[column] = data[column].fillna(0).astype('int64') # combine_data owns per-station gap handling for flow + temperature; # anything still missing in the core columns here is dropped rather diff --git a/openFlowML/train.py b/openFlowML/train.py index bbd6993..69a265f 100644 --- a/openFlowML/train.py +++ b/openFlowML/train.py @@ -81,7 +81,10 @@ def main(): # 2. Sanity-check spine outputs the rest of training relies on. required = {'station_idx', 'basin_idx', 'site_id', 'Date', - 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', + 'Min Flow', 'Max Flow', 'TMIN', 'TMAX', + 'SWE', 'soil_moisture', 'sm_observed', + 'drought_index', + 'reservoir_storage', 'reservoir_release', 'reservoir_observed', 'doy_sin', 'doy_cos'} missing = required - set(data.columns) if missing: @@ -143,17 +146,46 @@ def main(): verbose=2, ) - # 7. Evaluate against the persistence baseline on the held-out test set. + # 7. Evaluate against persistence + any external operational baselines + # on the held-out test set. External baselines (CBRFC, S2F) return + # None when their archive integration isn't wired in yet, in which + # case they're silently skipped. if test_inputs is not None and len(splits.test) > 0: model_pred = net.predict(test_inputs, verbose=0) model_mae_per_h = _summarize_per_horizon(test_targets, model_pred) persistence_pred = _persistence_pred_for_samples(splits.test) persistence_mae_per_h = _summarize_per_horizon(test_targets, persistence_pred) + + external_mae_per_h = {} + try: + from data import get_cbrfc + cbrfc_pred = get_cbrfc.baseline_predictions(splits.test) + if cbrfc_pred is not None: + external_mae_per_h['cbrfc'] = _summarize_per_horizon( + test_targets, cbrfc_pred) + except Exception as e: + logger.warning("CBRFC baseline skipped: %s", e) + try: + from data import get_s2f + s2f_pred = get_s2f.baseline_predictions(splits.test) + if s2f_pred is not None: + external_mae_per_h['s2f'] = _summarize_per_horizon( + test_targets, s2f_pred) + except Exception as e: + logger.warning("S2F baseline skipped: %s", e) + logger.info("Test MAE (scaled space) per horizon day:") - for k, (m, p) in enumerate(zip(model_mae_per_h, persistence_mae_per_h), start=1): + for k in range(len(model_mae_per_h)): + m = model_mae_per_h[k] + p = persistence_mae_per_h[k] verdict = "BEATS" if m < p else "LOSES_TO" - logger.info(" day %2d: model=%.4f persistence=%.4f (%s baseline)", - k, m, p, verdict) + extras = "".join( + f" {name}={vals[k]:.4f}" + for name, vals in external_mae_per_h.items()) + logger.info(" day %2d: model=%.4f persistence=%.4f (%s persistence)%s", + k + 1, m, p, verdict, extras) + for name in external_mae_per_h: + logger.info("External baseline available: %s", name) # 8. Persist the model + training config alongside the scaler/index JSON # that combine_data already wrote. These five artifacts together are diff --git a/openFlowML/windowing.py b/openFlowML/windowing.py index 1e5399c..57c6a9e 100644 --- a/openFlowML/windowing.py +++ b/openFlowML/windowing.py @@ -24,12 +24,23 @@ logger = logging.getLogger(__name__) -# Encoder window: everything we know up to the prediction time. Flow is here -# (these are observations), and SWE -- current snowpack is a strong predictor -# of snowmelt-fed runoff in Colorado. -ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', 'SWE', 'doy_sin', 'doy_cos'] +# Encoder window: everything we know up to the prediction time. +# flow (observations); SWE -- snowpack drives snowmelt-fed runoff in CO; +# soil_moisture (SMAP L3 enhanced) -- antecedent wetness gates infiltration +# vs runoff; sm_observed -- 1 = real / short-gap-interpolated SMAP, 0 = +# imputed via combine_data's median fallback; drought_index (USDM) -- HUC8 +# drought intensity 0..500; reservoir_storage / reservoir_release (USBR +# RISE) -- upstream regulation state for regulated rivers, defaults to 0 +# for unmapped (unregulated) stations; reservoir_observed -- 1 when the +# station actually has reservoir data, 0 when it's the unregulated default. +ENCODER_FEATURES = ['Min Flow', 'Max Flow', 'TMIN', 'TMAX', + 'SWE', 'soil_moisture', 'sm_observed', + 'drought_index', + 'reservoir_storage', 'reservoir_release', 'reservoir_observed', + 'doy_sin', 'doy_cos'] # Decoder window: ONLY features available at forecast time. No flow (that's -# what we're predicting), no SWE (no skillful 14-day SWE forecast exists). +# what we're predicting), no SWE / no soil_moisture / no drought / no +# reservoir state (none of these have a skillful 14-day forecast available). DECODER_FEATURES = ['TMIN', 'TMAX', 'doy_sin', 'doy_cos'] # Target: log-z-scored flow during the decoder days. Both columns are already # log1p-transformed and z-scored by normalize_data, so MSE/Huber here behaves. diff --git a/tests/test_combine_data.py b/tests/test_combine_data.py index 131f570..45bcc78 100644 --- a/tests/test_combine_data.py +++ b/tests/test_combine_data.py @@ -111,6 +111,152 @@ def test_merge_defaults_swe_to_zero_when_not_provided(): assert (merged['SWE'] == 0.0).all() +def test_merge_attaches_interpolated_soil_moisture_when_provided(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # Sparse SMAP retrievals -- only every 3rd day, like a real revisit cadence. + sm = pd.DataFrame({'Date': dates[::3], + 'soil_moisture': [0.15, 0.20, 0.25, 0.30, 0.25, 0.20, 0.15]}) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + sm_data=sm) + assert 'soil_moisture' in merged.columns + assert not merged['soil_moisture'].isnull().any() + # Every value must lie within the observed range -- ffill/bfill carries + # the edge values to the trailing days, not 0. + sm_min, sm_max = 0.15, 0.30 + assert merged['soil_moisture'].max() <= sm_max + assert merged['soil_moisture'].min() >= sm_min + + +def test_merge_soil_moisture_indicator_flags_real_observations(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # SMAP retrievals on a sparse subset of days -- the rest get interpolated + # or ffilled / bfilled. + sm = pd.DataFrame({'Date': dates[::3], + 'soil_moisture': [0.15, 0.20, 0.25, 0.30, 0.25, 0.20, 0.15]}) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + sm_data=sm) + assert 'sm_observed' in merged.columns + # The indicator is 1 on real-or-interpolated rows (everything from the + # first observation to the last observation, since MAX_SM_GAP_DAYS=30 + # exceeds the 3-day spacing) and 0 on rows imputed via ffill/bfill. + rows = merged.set_index(pd.to_datetime(merged['Date'])) + # Jan 19 is the last sparse observation; Jan 20 is ffill / bfill territory. + assert rows.loc['2022-01-19', 'sm_observed'] == 1 + assert rows.loc['2022-01-20', 'sm_observed'] == 0 + + +def test_merge_soil_moisture_falls_back_to_site_median_not_zero(): + # If a station has SOME observations but nothing extending to a portion of + # the window, the missing rows should imputed via ffill/bfill (which here + # leaves no gap), then the median if a gap remained -- never 0. + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # Only two observations, well within the window. Long stretches before and + # after these get ffilled / bfilled to those known values, NOT to 0. + sm = pd.DataFrame({'Date': [dates[5], dates[10]], + 'soil_moisture': [0.30, 0.40]}) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + sm_data=sm) + # Leading days bfill to 0.30; trailing days ffill from 0.40. Neither is 0. + assert (merged['soil_moisture'] > 0).all() + # sm_observed: the two real observations + the interior rows between them + # (interpolated within MAX_SM_GAP_DAYS) are flagged as observed (1). Days + # before the first or after the last observation are imputed via ffill/ + # bfill and flagged not-observed (0). + rows = merged.set_index(pd.to_datetime(merged['Date'])) + assert rows.loc['2022-01-01', 'sm_observed'] == 0 # before first obs + assert rows.loc['2022-01-06', 'sm_observed'] == 1 # first obs + assert rows.loc['2022-01-08', 'sm_observed'] == 1 # interior interpolated + assert rows.loc['2022-01-11', 'sm_observed'] == 1 # second obs + assert rows.loc['2022-01-20', 'sm_observed'] == 0 # after last obs + + +def test_merge_defaults_soil_moisture_to_zero_only_when_no_observations(): + noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert 'soil_moisture' in merged.columns + # With no SMAP data at all, fall back to 0 (last-resort default). + assert (merged['soil_moisture'] == 0.0).all() + # And the indicator correctly says "none of these are real observations". + assert (merged['sm_observed'] == 0).all() + + +def test_merge_attaches_drought_index_with_ffill(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + # USDM weekly snapshots on Jan 4 and Jan 11. + drought = pd.DataFrame({ + 'Date': [dates[3], dates[10]], + 'drought_index': [100.0, 250.0], + }) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + drought_data=drought) + assert 'drought_index' in merged.columns + rows = merged.set_index(pd.to_datetime(merged['Date'])) + # Pre-first-snapshot days have no value to ffill from -> 0 (the default). + assert rows.loc['2022-01-01', 'drought_index'] == 0.0 + # The snapshot day takes the first value. + assert rows.loc['2022-01-04', 'drought_index'] == 100.0 + # ffilled through the week. + assert rows.loc['2022-01-10', 'drought_index'] == 100.0 + # Next snapshot kicks in. + assert rows.loc['2022-01-11', 'drought_index'] == 250.0 + assert rows.loc['2022-01-20', 'drought_index'] == 250.0 + + +def test_merge_defaults_drought_index_to_zero_when_not_provided(): + noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert 'drought_index' in merged.columns + assert (merged['drought_index'] == 0.0).all() + + +def test_merge_attaches_reservoir_storage_and_release(): + noaa, flow = _make_frames() + dates = pd.date_range('2022-01-01', '2022-01-20', freq='D') + reservoir = pd.DataFrame({ + 'Date': dates[::5], # every 5 days + 'reservoir_storage': [1000.0, 1100.0, 1050.0, 1000.0], + 'reservoir_release': [50.0, 60.0, 55.0, 50.0], + }) + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', + datetime(2022, 1, 1), datetime(2022, 1, 20), + reservoir_data=reservoir) + assert 'reservoir_storage' in merged.columns + assert 'reservoir_release' in merged.columns + assert 'reservoir_observed' in merged.columns + # Observed where the reservoir series provided a value (post-interpolation). + assert merged['reservoir_observed'].sum() > 0 + # Storage stays in the observed range after ffill / bfill. + s_min, s_max = 1000.0, 1100.0 + assert merged['reservoir_storage'].max() <= s_max + assert merged['reservoir_storage'].min() >= s_min + + +def test_merge_defaults_reservoir_to_zero_when_unregulated(): + # Unregulated station: no reservoir_data passed -> storage / release default + # to 0 and reservoir_observed stays 0. + noaa, flow = _make_frames() + merged = combine_data.merge_dataframes( + noaa, flow, 'USGS:TEST', datetime(2022, 1, 1), datetime(2022, 1, 20)) + assert (merged['reservoir_storage'] == 0.0).all() + assert (merged['reservoir_release'] == 0.0).all() + assert (merged['reservoir_observed'] == 0).all() + + def test_merge_records_huc8_when_provided(): noaa, flow = _make_frames() merged = combine_data.merge_dataframes( diff --git a/tests/test_drought.py b/tests/test_drought.py new file mode 100644 index 0000000..b755c2d --- /dev/null +++ b/tests/test_drought.py @@ -0,0 +1,114 @@ +""" +Tests for get_drought.main / get_drought helpers (USDM intensity-by-HUC). + +USDM is a clean REST endpoint; we mock with requests_mock and don't hit the +network. The HUC lookup is patched out per test so we can isolate the USDM +parsing + forward-fill logic from the WBD ArcGIS dependency. +""" + +from datetime import datetime + +import pandas as pd +import pytest +import requests_mock as _rm + +from data import get_drought + + +@pytest.fixture +def mock_api(): + with _rm.Mocker() as m: + yield m + + +def test_intensity_weights_collapse_categories_correctly(): + # 100% D2 == weight 3 == index 300; D0..D4 weights are ordinal (1..5). + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 100, 'D3': 0, 'D4': 0}) == 300.0 + # 50% D1 + 50% D3 == 0.5*2 + 0.5*4 = 3 -- but expressed as percent it's + # 50 + 50 -> 50*2 + 50*4 = 300. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 50, 'D2': 0, 'D3': 50, 'D4': 0}) == 300.0 + # No drought anywhere -> 0. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}) == 0.0 + # Maximum: 100% D4 -> 500. + assert get_drought._record_to_index( + {'D0': 0, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 100}) == 500.0 + + +def test_intensity_treats_missing_categories_as_zero(): + # USDM occasionally omits zero-category fields; missing == 0, not error. + assert get_drought._record_to_index({'D4': 100}) == 500.0 + assert get_drought._record_to_index({}) == 0.0 + + +def test_format_mdy_strips_zero_padding(): + # USDM API requires M/D/YYYY (NOT MM/DD/YYYY). + assert get_drought._format_mdy('2024-01-05') == '1/5/2024' + assert get_drought._format_mdy(datetime(2024, 12, 31)) == '12/31/2024' + + +def test_get_drought_weekly_parses_records(mock_api): + mock_api.get( + get_drought.USDM_URL, + json=[ + {'MapDate': '2024-01-02', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + {'MapDate': '2024-01-09', 'D0': 0, 'D1': 50, 'D2': 50, 'D3': 0, 'D4': 0}, + ], + ) + recs = get_drought.get_drought_weekly('14010001', '2024-01-01', '2024-01-15') + assert len(recs) == 2 + + +def test_get_drought_returns_empty_when_huc_lookup_fails(monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: None) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert df.empty + assert list(df.columns) == ['Date', 'drought_index'] + + +def test_get_drought_forward_fills_weekly_to_daily(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get( + get_drought.USDM_URL, + json=[ + {'MapDate': '2024-01-02', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + {'MapDate': '2024-01-09', 'D0': 0, 'D1': 50, 'D2': 50, 'D3': 0, 'D4': 0}, + ], + ) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + # Daily reindex over Jan 1..15 = 15 rows. + assert len(df) == 15 + rows = df.set_index('Date') + # Days before the first weekly snapshot stay NaN (we used ffill only -- + # the spine fills those with 0). + assert pd.isna(rows.loc['2024-01-01', 'drought_index']) + # Days from Jan 2 through Jan 8 inherit the first snapshot's intensity + # (100% D0 -> weight 1 -> 100). + assert rows.loc['2024-01-02', 'drought_index'] == 100.0 + assert rows.loc['2024-01-08', 'drought_index'] == 100.0 + # Days from Jan 9 onward inherit the second snapshot (50*2 + 50*3 = 250). + assert rows.loc['2024-01-09', 'drought_index'] == 250.0 + assert rows.loc['2024-01-15', 'drought_index'] == 250.0 + + +def test_get_drought_returns_empty_when_api_returns_nothing(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get(get_drought.USDM_URL, json=[]) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert df.empty + + +def test_get_drought_tolerates_alternate_date_keys(mock_api, monkeypatch): + monkeypatch.setattr(get_drought.get_swe, 'get_huc_id', lambda *a, **kw: '14010001') + mock_api.get( + get_drought.USDM_URL, + json=[ + # USDM has been observed to return ValidStart instead of MapDate. + {'ValidStart': '20240102', 'D0': 100, 'D1': 0, 'D2': 0, 'D3': 0, 'D4': 0}, + ], + ) + df = get_drought.get_drought(40.0, -106.0, '2024-01-01', '2024-01-15') + assert len(df) == 15 + assert df.set_index('Date').loc['2024-01-02', 'drought_index'] == 100.0 diff --git a/tests/test_external_baselines.py b/tests/test_external_baselines.py new file mode 100644 index 0000000..82dd7b2 --- /dev/null +++ b/tests/test_external_baselines.py @@ -0,0 +1,64 @@ +""" +Smoke tests for the CBRFC + S2F baseline modules. + +Both modules expose a `baseline_predictions(test_samples)` hook that train.py +calls during evaluation. Until the historical archive integrations are wired +in (see module docstrings), both must return None cleanly -- not crash, not +log noise, not pollute the persistence comparison. These tests pin that +contract so the train.py integration stays safe. +""" + +import numpy as np +import pandas as pd +import pytest + +from data import get_cbrfc, get_s2f +from windowing import WindowedSample + + +def _sample(site_id='USGS:09163500'): + """A single WindowedSample-shaped object sufficient for baseline_predictions.""" + return WindowedSample( + encoder_X=np.zeros((60, 9), dtype='float32'), + decoder_X=np.zeros((14, 4), dtype='float32'), + target_Y=np.zeros((14, 2), dtype='float32'), + persistence_anchor=np.zeros(2, dtype='float32'), + station_idx=1, + basin_idx=1, + site_id=site_id, + anchor_date=pd.Timestamp('2024-01-01'), + ) + + +def test_cbrfc_baseline_returns_none_when_no_lid_mapped(): + # The default _AHPS_LID_TABLE is empty -> every site has no LID -> every + # fetch returns empty -> the predictions stack is None (skip cleanly). + pred = get_cbrfc.baseline_predictions([_sample()]) + assert pred is None + + +def test_cbrfc_baseline_returns_none_for_empty_test_set(): + assert get_cbrfc.baseline_predictions([]) is None + + +def test_cbrfc_fetch_current_skips_when_no_lid(): + df = get_cbrfc.fetch_current('USGS:UNKNOWN') + assert df.empty + assert list(df.columns) == ['Date', 'cbrfc_flow'] + + +def test_cbrfc_fetch_historical_returns_empty_until_archive_wired(): + # Anchor in the past -> archive lookup is stubbed -> empty. + df = get_cbrfc.fetch('USGS:09163500', '2024-01-01') + assert df.empty + + +def test_s2f_baseline_returns_none_until_archive_wired(): + assert get_s2f.baseline_predictions([_sample()]) is None + assert get_s2f.baseline_predictions([]) is None + + +def test_s2f_fetch_returns_empty_until_archive_wired(): + df = get_s2f.fetch('USGS:09163500', '2024-01-01') + assert df.empty + assert list(df.columns) == ['Date', 's2f_volume_kaf'] diff --git a/tests/test_nasa_moisture.py b/tests/test_nasa_moisture.py new file mode 100644 index 0000000..322bf30 --- /dev/null +++ b/tests/test_nasa_moisture.py @@ -0,0 +1,351 @@ +""" +Tests for the consolidated SMAP soil-moisture fetcher (nasa_moisture.main). + +We mock at the earthaccess + polygon-lookup boundaries so the test runs offline +and deterministically; the on-disk extraction is exercised by writing a tiny +real HDF5 file and reading it back through _extract_polygon_mean. +""" + +import os +from datetime import date, datetime + +import numpy as np +import pandas as pd +import pytest + +# h5py is part of requirements-data.txt; skip if not installed. +h5py = pytest.importorskip('h5py') + +from data import nasa_moisture + + +# Polygon big enough to contain the synthetic grids used below. +_POLYGON = [(-106.5, 39.5), (-104.5, 39.5), (-104.5, 41.5), (-106.5, 41.5)] + + +def _write_smap_h5(path, sm_am, sm_pm=None, qflag_am=None, qflag_pm=None): + """Write a minimal SMAP-shaped HDF5 file at `path`.""" + lat = np.array([ + [40.0, 40.0, 40.0], + [40.5, 40.5, 40.5], + [41.0, 41.0, 41.0], + ]) + lon = np.array([ + [-106.0, -105.5, -105.0], + [-106.0, -105.5, -105.0], + [-106.0, -105.5, -105.0], + ]) + with h5py.File(path, 'w') as f: + am = f.create_group('Soil_Moisture_Retrieval_Data_AM') + am.create_dataset('soil_moisture', data=sm_am) + am.create_dataset('latitude', data=lat) + am.create_dataset('longitude', data=lon) + if qflag_am is not None: + am.create_dataset('retrieval_qual_flag', data=qflag_am) + if sm_pm is not None: + pm = f.create_group('Soil_Moisture_Retrieval_Data_PM') + pm.create_dataset('soil_moisture', data=sm_pm) + pm.create_dataset('latitude', data=lat) + pm.create_dataset('longitude', data=lon) + if qflag_pm is not None: + pm.create_dataset('retrieval_qual_flag', data=qflag_pm) + + +class _FakeGranule: + """Stand-in for an earthaccess granule with the bits nasa_moisture uses.""" + + def __init__(self, filename): + self._filename = filename + + def data_links(self): + return [f"https://example.test/{self._filename}"] + + +def test_extract_polygon_mean_averages_valid_pixels_only(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.array([ + [-9999.0, 0.25, 0.30], + [0.20, 0.35, 0.40], + [-9999.0, 0.45, 0.50], + ]) + _write_smap_h5(path, sm) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + expected = (0.25 + 0.30 + 0.20 + 0.35 + 0.40 + 0.45 + 0.50) / 7 + assert value == pytest.approx(expected) + + +def test_extract_polygon_mean_returns_none_when_all_fill(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), -9999.0) + _write_smap_h5(path, sm) + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) is None + + +def test_extract_polygon_mean_combines_am_and_pm(tmp_path): + path = str(tmp_path / 'sm.h5') + sm_am = np.full((3, 3), 0.20) + sm_pm = np.full((3, 3), 0.40) + _write_smap_h5(path, sm_am, sm_pm=sm_pm) + # 9 AM pixels at 0.2 + 9 PM pixels at 0.4 -> grand mean of 0.30. + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) == pytest.approx(0.30) + + +def test_extract_polygon_mean_clips_to_valid_range(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.array([ + [0.30, 0.40, 5.0], # 5.0 > VALID_MAX -- rejected + [0.20, -0.10, 0.50], # -0.10 < VALID_MIN -- rejected + [0.10, 0.20, 0.30], + ]) + _write_smap_h5(path, sm) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + kept = [0.30, 0.40, 0.20, 0.50, 0.10, 0.20, 0.30] + assert value == pytest.approx(sum(kept) / len(kept)) + + +def test_granule_date_parses_from_filename(): + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R18290_001.h5') + assert nasa_moisture._granule_date(g) == date(2024, 1, 15) + + +def test_granule_date_returns_none_when_no_pattern(): + g = _FakeGranule('arbitrary_filename.h5') + assert nasa_moisture._granule_date(g) is None + + +def test_main_returns_empty_when_polygon_lookup_fails(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: None) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + assert list(out.columns) == ['Date', 'soil_moisture'] + + +def test_main_returns_empty_when_auth_fails(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: None) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + + +def test_main_returns_empty_when_no_granules(monkeypatch): + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: []) + out = nasa_moisture.main(40.0, -105.0, '2024-01-01', '2024-01-03') + assert out.empty + + +def test_main_end_to_end_with_mocked_search_and_download(monkeypatch, tmp_path): + # Two distinct dates, each with a known mean. Build the granule fixtures. + g1_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240115_R18290_001.h5') + g2_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240116_R18290_001.h5') + _write_smap_h5(g1_path, np.full((3, 3), 0.30)) + _write_smap_h5(g2_path, np.full((3, 3), 0.50)) + + g1 = _FakeGranule('SMAP_L3_SM_P_E_20240115_R18290_001.h5') + g2 = _FakeGranule('SMAP_L3_SM_P_E_20240116_R18290_001.h5') + + download_map = {id(g1): g1_path, id(g2): g2_path} + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g1, g2]) + + def fake_download(granule, tmpdir): + # Copy fixture into the temp dir so the real cleanup path runs cleanly. + import shutil + src = download_map[id(granule)] + dst = f"{tmpdir}/{src.rsplit('/', 1)[-1]}" + shutil.copy(src, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-16') + assert list(out.columns) == ['Date', 'soil_moisture'] + assert len(out) == 2 + assert out.iloc[0]['Date'] == '2024-01-15' + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.30) + assert out.iloc[1]['Date'] == '2024-01-16' + assert out.iloc[1]['soil_moisture'] == pytest.approx(0.50) + + +def test_main_collapses_multiple_granules_on_same_date(monkeypatch, tmp_path): + # Two granules tagged with the same date -- the daily series collapses + # them via mean (single row, averaged value). + a_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240120_pass_a_001.h5') + b_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240120_pass_b_001.h5') + _write_smap_h5(a_path, np.full((3, 3), 0.20)) + _write_smap_h5(b_path, np.full((3, 3), 0.40)) + + g_a = _FakeGranule(a_path.rsplit('/', 1)[-1]) + g_b = _FakeGranule(b_path.rsplit('/', 1)[-1]) + download_map = {id(g_a): a_path, id(g_b): b_path} + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g_a, g_b]) + + def fake_download(granule, tmpdir): + import shutil + src = download_map[id(granule)] + dst = f"{tmpdir}/{src.rsplit('/', 1)[-1]}" + shutil.copy(src, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-20', '2024-01-20') + assert len(out) == 1 + assert out.iloc[0]['Date'] == '2024-01-20' + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.30) + + +def test_main_skips_granules_that_fail_download(monkeypatch, tmp_path): + good_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240117_R001_001.h5') + _write_smap_h5(good_path, np.full((3, 3), 0.25)) + + g_bad = _FakeGranule('SMAP_L3_SM_P_E_20240115_R000_001.h5') + g_good = _FakeGranule('SMAP_L3_SM_P_E_20240117_R001_001.h5') + + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g_bad, g_good]) + + def fake_download(granule, tmpdir): + if granule is g_bad: + return None + import shutil + dst = f"{tmpdir}/{good_path.rsplit('/', 1)[-1]}" + shutil.copy(good_path, dst) + return dst + + monkeypatch.setattr(nasa_moisture, '_download_granule', fake_download) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-17') + assert len(out) == 1 + assert out.iloc[0]['Date'] == '2024-01-17' + + +def test_to_date_accepts_string_datetime_and_date(): + assert nasa_moisture._to_date('2024-01-15') == date(2024, 1, 15) + assert nasa_moisture._to_date(datetime(2024, 1, 15, 12, 0)) == date(2024, 1, 15) + assert nasa_moisture._to_date(date(2024, 1, 15)) == date(2024, 1, 15) + + +def test_extract_polygon_mean_drops_not_recommended_quality_pixels(tmp_path): + path = str(tmp_path / 'sm.h5') + # All pixels have legitimate-looking soil moisture; the quality flag + # rejects half of them (bit 0 = 1 means "not recommended"). The polygon + # mean must reflect only the recommended pixels. + sm = np.full((3, 3), 0.40) + qflag = np.array([ + [1, 1, 1], # all not-recommended + [0, 0, 0], # all recommended + [1, 1, 1], # all not-recommended + ], dtype='int32') + _write_smap_h5(path, sm, qflag_am=qflag) + value = nasa_moisture._extract_polygon_mean(path, _POLYGON) + # Only the middle row's 3 pixels survive -- all happen to be 0.40. + assert value == pytest.approx(0.40) + + +def test_extract_polygon_mean_returns_none_when_quality_flag_rejects_all(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.40) + qflag = np.ones((3, 3), dtype='int32') # every pixel "not recommended" + _write_smap_h5(path, sm, qflag_am=qflag) + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) is None + + +def test_extract_polygon_mean_proceeds_when_quality_dataset_missing(tmp_path): + # Older granules sometimes omit retrieval_qual_flag; we must not refuse + # to extract in that case. + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.30) + _write_smap_h5(path, sm) # no qflag dataset + assert nasa_moisture._extract_polygon_mean(path, _POLYGON) == pytest.approx(0.30) + + +def test_extract_polygon_mean_ignores_quality_flag_when_disabled(tmp_path): + path = str(tmp_path / 'sm.h5') + sm = np.full((3, 3), 0.40) + qflag = np.ones((3, 3), dtype='int32') # all "not recommended" + _write_smap_h5(path, sm, qflag_am=qflag) + value = nasa_moisture._extract_polygon_mean( + path, _POLYGON, use_quality_flag=False) + # With the filter off, every in-bbox / in-range pixel counts. + assert value == pytest.approx(0.40) + + +def test_granule_cache_dedupes_downloads_within_a_process(monkeypatch, tmp_path): + # Two main() calls hitting the same granule must only invoke the actual + # download once -- the cache lets neighboring stations share global daily + # granules without re-fetching. + nasa_moisture._GRANULE_PATH_CACHE.clear() + cache_dir = str(tmp_path / 'cache') + monkeypatch.setenv('OPENFLOW_SMAP_CACHE_DIR', cache_dir) + + granule_path = str(tmp_path / 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + _write_smap_h5(granule_path, np.full((3, 3), 0.25)) + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R001_001.h5') + + download_calls = {'n': 0} + + class _FakeEA: + @staticmethod + def download(granules, local_path): + download_calls['n'] += 1 + import shutil + dst = os.path.join(local_path, 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + shutil.copy(granule_path, dst) + return [dst] + + import sys + monkeypatch.setitem(sys.modules, 'earthaccess', _FakeEA) + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g]) + + out1 = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-15') + out2 = nasa_moisture.main(40.1, -105.1, '2024-01-15', '2024-01-15') + + assert not out1.empty and not out2.empty + # The actual download was invoked exactly once across both main() calls. + assert download_calls['n'] == 1 + + +def test_granule_cache_picks_up_existing_disk_files(monkeypatch, tmp_path): + # If a previous process populated the cache dir, a fresh process should + # still hit the cache without re-downloading. + nasa_moisture._GRANULE_PATH_CACHE.clear() + cache_dir = str(tmp_path / 'cache') + os.makedirs(cache_dir, exist_ok=True) + monkeypatch.setenv('OPENFLOW_SMAP_CACHE_DIR', cache_dir) + + # Pre-populate the cache dir with the granule file (simulating a previous + # run that already downloaded it). + cached_path = os.path.join(cache_dir, 'SMAP_L3_SM_P_E_20240115_R001_001.h5') + _write_smap_h5(cached_path, np.full((3, 3), 0.35)) + g = _FakeGranule('SMAP_L3_SM_P_E_20240115_R001_001.h5') + + class _FailingEA: + @staticmethod + def download(granules, local_path): + raise AssertionError("download must not be called when cached on disk") + + import sys + monkeypatch.setitem(sys.modules, 'earthaccess', _FailingEA) + monkeypatch.setattr(nasa_moisture, '_get_huc8_polygon', lambda lat, lon: _POLYGON) + monkeypatch.setattr(nasa_moisture, '_login_earthdata', lambda: object()) + monkeypatch.setattr(nasa_moisture, '_search_granules', + lambda poly, s, e: [g]) + + out = nasa_moisture.main(40.0, -105.0, '2024-01-15', '2024-01-15') + assert not out.empty + assert out.iloc[0]['soil_moisture'] == pytest.approx(0.35) diff --git a/tests/test_normalize_data.py b/tests/test_normalize_data.py index d291f99..5756cd1 100644 --- a/tests/test_normalize_data.py +++ b/tests/test_normalize_data.py @@ -112,6 +112,81 @@ def test_normalize_fills_missing_swe_with_zero_not_dropping_rows(): assert out['SWE'].isnull().sum() == 0 +def test_normalize_scales_soil_moisture_when_present(tmp_path): + import json + df = _make_combined() + df['soil_moisture'] = np.linspace(0.05, 0.45, 20) + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'soil_moisture' in scalers + # Identity transform (no log1p) for soil moisture; z-scored on the raw scale. + assert scalers['soil_moisture']['transform'] == 'identity' + assert abs(out['soil_moisture'].mean()) < 1e-6 + assert abs(out['soil_moisture'].std(ddof=0) - 1.0) < 1e-6 + + +def test_normalize_fills_missing_soil_moisture_with_zero(): + df = _make_combined() + df['soil_moisture'] = [np.nan] * len(df) + out = normalize_data.normalize_data(df) + assert out is not None + # Missing soil moisture should NOT drop rows. + assert len(out) == len(df) + assert out['soil_moisture'].isnull().sum() == 0 + + +def test_normalize_keeps_sm_observed_as_binary_indicator(tmp_path): + import json + df = _make_combined() + df['soil_moisture'] = np.linspace(0.05, 0.45, 20) + df['sm_observed'] = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] * 2 + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + # Indicator MUST NOT be z-scored -- the model needs to see 0 vs 1. + assert set(out['sm_observed'].unique()) <= {0, 1} + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'sm_observed' not in scalers + + +def test_normalize_scales_drought_and_reservoir_when_present(tmp_path): + import json + df = _make_combined() + df['drought_index'] = np.linspace(0.0, 400.0, 20) + df['reservoir_storage'] = np.linspace(800_000.0, 1_200_000.0, 20) + df['reservoir_release'] = np.linspace(50.0, 250.0, 20) + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + for col in ('drought_index', 'reservoir_storage', 'reservoir_release'): + assert col in scalers + # All three are continuous; z-scored, no log transform. + assert scalers[col]['transform'] == 'identity' + assert abs(out[col].mean()) < 1e-6 + assert abs(out[col].std(ddof=0) - 1.0) < 1e-6 + + +def test_normalize_keeps_reservoir_observed_as_binary(tmp_path): + import json + df = _make_combined() + df['reservoir_observed'] = [1, 0] * 10 + out = normalize_data.normalize_data(df, artifacts_dir=str(tmp_path)) + assert out is not None + assert set(out['reservoir_observed'].unique()) <= {0, 1} + scalers = json.loads((tmp_path / 'scalers.json').read_text()) + assert 'reservoir_observed' not in scalers + + +def test_normalize_defaults_sm_observed_to_zero_when_missing(): + df = _make_combined() + df['sm_observed'] = [1, np.nan, 1, np.nan] * 5 + out = normalize_data.normalize_data(df) + assert out is not None + # Missing rows default to 0 ("not observed"). + assert out['sm_observed'].isnull().sum() == 0 + assert set(out['sm_observed'].unique()) <= {0, 1} + + def test_flow_columns_are_log_transformed_before_scaling(tmp_path): import json df = _make_combined() diff --git a/tests/test_reservoir.py b/tests/test_reservoir.py new file mode 100644 index 0000000..127e5c6 --- /dev/null +++ b/tests/test_reservoir.py @@ -0,0 +1,133 @@ +""" +Tests for get_reservoir (USBR RISE storage + release). + +The mapping file and the RISE result endpoint are both mocked. We don't hit +data.usbr.gov in the test run; the test asserts the mapping is parsed correctly +and that the JSON:API response shape is normalised into a [Date, value] frame. +""" + +import pandas as pd +import pytest +import requests_mock as _rm + +from data import get_reservoir + + +@pytest.fixture +def mock_api(): + with _rm.Mocker() as m: + yield m + + +def test_load_mapping_skips_comments_and_blanks(tmp_path): + path = tmp_path / 'reservoirs.txt' + path.write_text( + "# header comment\n" + "\n" + "USGS:09163500, Lake Powell, 6126, 6127\n" + "# inline comment line\n" + "USGS:09114500, Blue Mesa, 6135,\n" # no release id + " USGS:09070500 , Dillon , 6200 , 6201 \n" # extra whitespace + ) + mapping = get_reservoir.load_mapping(str(path)) + assert set(mapping) == {'USGS:09163500', 'USGS:09114500', 'USGS:09070500'} + assert mapping['USGS:09163500'] == [('Lake Powell', '6126', '6127')] + # The release id is empty -> None for that slot. + assert mapping['USGS:09114500'] == [('Blue Mesa', '6135', None)] + # Whitespace is stripped. + assert mapping['USGS:09070500'] == [('Dillon', '6200', '6201')] + + +def test_load_mapping_handles_missing_file(tmp_path): + missing = tmp_path / 'does_not_exist.txt' + assert get_reservoir.load_mapping(str(missing)) == {} + + +def test_load_mapping_skips_malformed_lines(tmp_path): + path = tmp_path / 'reservoirs.txt' + path.write_text( + "USGS:09163500, Lake Powell, 6126, 6127\n" + "this line has too few fields\n" + "USGS:09114500, Blue Mesa, 6135\n" + ) + mapping = get_reservoir.load_mapping(str(path)) + assert set(mapping) == {'USGS:09163500', 'USGS:09114500'} + + +def test_fetch_rise_series_parses_jsonapi_response(mock_api): + mock_api.get( + get_reservoir.RISE_RESULT_URL, + json={ + 'data': [ + {'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': 100.0}}, + {'attributes': {'dateTime': '2024-01-02T00:00:00Z', 'result': 200.0}}, + {'attributes': {'dateTime': '2024-01-03T00:00:00Z', 'result': 300.0}}, + ], + 'links': {}, + }, + ) + df = get_reservoir.fetch_rise_series('6126', '2024-01-01', '2024-01-03') + assert list(df['Date']) == ['2024-01-01', '2024-01-02', '2024-01-03'] + assert list(df['value']) == [100.0, 200.0, 300.0] + + +def test_fetch_rise_series_collapses_sub_daily_to_daily_mean(mock_api): + mock_api.get( + get_reservoir.RISE_RESULT_URL, + json={ + 'data': [ + {'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': 100.0}}, + {'attributes': {'dateTime': '2024-01-01T12:00:00Z', 'result': 200.0}}, + ], + 'links': {}, + }, + ) + df = get_reservoir.fetch_rise_series('6126', '2024-01-01', '2024-01-01') + assert len(df) == 1 + assert df.iloc[0]['value'] == 150.0 + + +def test_fetch_rise_series_returns_empty_on_empty_id(mock_api): + # No request must be issued when item_id is falsy. + df = get_reservoir.fetch_rise_series(None, '2024-01-01', '2024-01-03') + assert df.empty + assert mock_api.call_count == 0 + + +def test_get_reservoir_returns_empty_for_unmapped_station(tmp_path): + mapping_path = tmp_path / 'reservoirs.txt' + mapping_path.write_text("USGS:09163500, Lake Powell, 6126, 6127\n") + df = get_reservoir.get_reservoir( + 'USGS:UNREGULATED', '2024-01-01', '2024-01-03', mapping_path=str(mapping_path)) + assert df.empty + assert list(df.columns) == ['Date', 'reservoir_storage', 'reservoir_release'] + + +def test_get_reservoir_sums_multiple_reservoirs(mock_api, tmp_path): + mapping_path = tmp_path / 'reservoirs.txt' + mapping_path.write_text( + "USGS:09163500, ResA, 100, 200\n" + "USGS:09163500, ResB, 101, 201\n" + ) + + # Different responses by query string -- match on itemId. + def _by_item(request, context): + item_id = request.qs.get('itemid', [''])[0] + value_map = { + '100': 1000.0, # ResA storage + '101': 500.0, # ResB storage + '200': 50.0, # ResA release + '201': 25.0, # ResB release + } + v = value_map[item_id] + return { + 'data': [{'attributes': {'dateTime': '2024-01-01T00:00:00Z', 'result': v}}], + 'links': {}, + } + + mock_api.get(get_reservoir.RISE_RESULT_URL, json=_by_item) + df = get_reservoir.get_reservoir( + 'USGS:09163500', '2024-01-01', '2024-01-01', mapping_path=str(mapping_path)) + assert len(df) == 1 + assert df.iloc[0]['reservoir_storage'] == 1500.0 # 1000 + 500 + assert df.iloc[0]['reservoir_release'] == 75.0 # 50 + 25 diff --git a/tests/test_windowing.py b/tests/test_windowing.py index a4fad59..a306377 100644 --- a/tests/test_windowing.py +++ b/tests/test_windowing.py @@ -20,6 +20,12 @@ def _make_station_frame(site_id, station_idx, basin_idx, n_days=120, start='2022 'TMIN': rng.standard_normal(n_days), 'TMAX': rng.standard_normal(n_days), 'SWE': rng.standard_normal(n_days), + 'soil_moisture': rng.standard_normal(n_days), + 'sm_observed': rng.integers(0, 2, n_days), + 'drought_index': rng.standard_normal(n_days), + 'reservoir_storage': rng.standard_normal(n_days), + 'reservoir_release': rng.standard_normal(n_days), + 'reservoir_observed': rng.integers(0, 2, n_days), 'doy_sin': np.sin(2 * np.pi * np.arange(n_days) / 365), 'doy_cos': np.cos(2 * np.pi * np.arange(n_days) / 365), }) @@ -39,6 +45,21 @@ def test_decoder_features_carry_no_flow_information(): assert 'Max Flow' not in windowing.DECODER_FEATURES # SWE is also a current-conditions feature -- no skillful 14-day forecast. assert 'SWE' not in windowing.DECODER_FEATURES + # Same for soil moisture: SMAP is encoder-only, never in the decoder window. + assert 'soil_moisture' not in windowing.DECODER_FEATURES + assert 'soil_moisture' in windowing.ENCODER_FEATURES + # The sm_observed indicator is also encoder-only. + assert 'sm_observed' in windowing.ENCODER_FEATURES + assert 'sm_observed' not in windowing.DECODER_FEATURES + # USDM drought + USBR reservoir are encoder-only too -- none has a + # skillful 14-day forecast available. + assert 'drought_index' in windowing.ENCODER_FEATURES + assert 'drought_index' not in windowing.DECODER_FEATURES + assert 'reservoir_storage' in windowing.ENCODER_FEATURES + assert 'reservoir_release' in windowing.ENCODER_FEATURES + assert 'reservoir_observed' in windowing.ENCODER_FEATURES + for c in ('reservoir_storage', 'reservoir_release', 'reservoir_observed'): + assert c not in windowing.DECODER_FEATURES def test_build_windows_shapes_are_correct(): @@ -145,3 +166,9 @@ def test_build_windows_rejects_dataframe_missing_required_columns(): df = _make_station_frame('USGS:A', 1, 1).drop(columns=['SWE']) with pytest.raises(ValueError): windowing.build_windows(df) + + +def test_build_windows_rejects_dataframe_missing_soil_moisture(): + df = _make_station_frame('USGS:A', 1, 1).drop(columns=['soil_moisture']) + with pytest.raises(ValueError): + windowing.build_windows(df)