diff --git a/src/fev/metrics.py b/src/fev/metrics.py index 965c316..0091e73 100644 --- a/src/fev/metrics.py +++ b/src/fev/metrics.py @@ -1,10 +1,7 @@ from typing import Any, Callable, Type -import datasets import numpy as np -from fev.constants import PREDICTIONS - MetricConfig = str | dict[str, Any] @@ -19,25 +16,43 @@ def name(self) -> str: return self.__class__.__name__ @staticmethod - def _safemean(arr: np.ndarray) -> float: - """Compute mean of an array, ignoring NaN, Inf, and -Inf values.""" - return float(np.mean(arr[np.isfinite(arr)])) - - @staticmethod - def _get_y_test(test_data: datasets.Dataset, target_column: str) -> np.ndarray: - """ "Return array of shape [len(test_data), horizon] with ground truth values in float64 dtype.""" - return np.array(test_data[target_column], dtype=np.float64) + def _safemean(arr: np.ndarray, axis=None) -> float | np.ndarray: + """Compute mean ignoring NaN, Inf, and -Inf values.""" + mask = ~np.isfinite(arr) + if mask.any(): + arr = np.where(mask, np.nan, arr) + return np.nanmean(arr, axis=axis) def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", ) -> float: + """Compute the metric score. Computed per target dim, then averaged across dims. + + Parameters + ---------- + y_true : np.ndarray [N, H, D] + Ground truth. N=number of time series, H=forecast horizon, D=target dimensions. + y_pred : np.ndarray [N, H, D] + Point forecast predictions, same shape as y_true. + y_past : np.ndarray [total_T, D] + Concatenated historical observations for all items (ragged time axis). + y_past_lengths : np.ndarray [N] + Number of past observations per item. sum(y_past_lengths) == total_T. + q_pred : np.ndarray [N, H, D, Q] + Quantile predictions. Q=len(quantile_levels), or Q=0 if none requested. + seasonality : int + Seasonal period for scaled error metrics (MASE, RMSSE, SQL). + quantile_levels : list[float] + Quantile levels in (0, 1) corresponding to q_pred's last axis. + """ raise NotImplementedError @@ -65,16 +80,16 @@ class MAE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - return np.nanmean(np.abs(y_test - y_pred)) + ) -> float: + per_dim = np.nanmean(np.abs(y_true - y_pred), axis=(0, 1)) # [D] + return float(np.mean(per_dim)) class WAPE(Metric): @@ -86,17 +101,18 @@ def __init__(self, epsilon: float = 0.0) -> None: def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - - return np.nanmean(np.abs(y_test - y_pred)) / max(self.epsilon, np.nanmean(np.abs(y_test))) + ) -> float: + abs_err_per_dim = np.nanmean(np.abs(y_true - y_pred), axis=(0, 1)) # [D] + abs_true_per_dim = np.nanmean(np.abs(y_true), axis=(0, 1)) # [D] + per_dim = abs_err_per_dim / np.maximum(abs_true_per_dim, self.epsilon) + return float(np.mean(per_dim)) class MASE(Metric): @@ -113,21 +129,20 @@ def __init__(self, epsilon: float = 0.0) -> None: def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - + ) -> float: seasonal_error = _abs_seasonal_error_per_item( - past_data=past_data, seasonality=seasonality, target_column=target_column - ) + y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality + ) # [N, D] seasonal_error = np.clip(seasonal_error, self.epsilon, None) - return self._safemean(np.abs(y_test - y_pred) / seasonal_error[:, None]) + scaled = np.abs(y_true - y_pred) / seasonal_error[:, None, :] # [N, H, D] + return float(np.mean(self._safemean(scaled, axis=(0, 1)))) class RMSE(Metric): @@ -136,16 +151,16 @@ class RMSE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - return np.sqrt(np.nanmean((y_test - y_pred) ** 2)) + ) -> float: + per_dim = np.sqrt(np.nanmean((y_true - y_pred) ** 2, axis=(0, 1))) # [D] + return float(np.mean(per_dim)) class RMSSE(Metric): @@ -162,20 +177,20 @@ def __init__(self, epsilon: float = 0.0) -> None: def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) + ) -> float: seasonal_error = _squared_seasonal_error_per_item( - past_data, seasonality=seasonality, target_column=target_column - ) + y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality + ) # [N, D] seasonal_error = np.clip(seasonal_error, self.epsilon, None) - return np.sqrt(self._safemean((y_test - y_pred) ** 2 / seasonal_error[:, None])) + scaled = (y_true - y_pred) ** 2 / seasonal_error[:, None, :] # [N, H, D] + return float(np.mean(np.sqrt(self._safemean(scaled, axis=(0, 1))))) class MSE(Metric): @@ -184,16 +199,16 @@ class MSE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - return np.nanmean((y_test - y_pred) ** 2) + ) -> float: + per_dim = np.nanmean((y_true - y_pred) ** 2, axis=(0, 1)) # [D] + return float(np.mean(per_dim)) class RMSLE(Metric): @@ -202,16 +217,16 @@ class RMSLE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - return np.sqrt(np.nanmean((np.log1p(y_test) - np.log1p(y_pred)) ** 2)) + ) -> float: + per_dim = np.sqrt(np.nanmean((np.log1p(y_true) - np.log1p(y_pred)) ** 2, axis=(0, 1))) # [D] + return float(np.mean(per_dim)) class MAPE(Metric): @@ -220,17 +235,16 @@ class MAPE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - ratio = np.abs(y_test - y_pred) / np.abs(y_test) - return self._safemean(ratio) + ) -> float: + ratio = np.abs(y_true - y_pred) / np.abs(y_true) # [N, H, D] + return float(np.mean(self._safemean(ratio, axis=(0, 1)))) class SMAPE(Metric): @@ -239,16 +253,16 @@ class SMAPE(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - y_test = self._get_y_test(test_data, target_column=target_column) - y_pred = np.array(predictions[PREDICTIONS]) - return self._safemean(2 * np.abs(y_test - y_pred) / (np.abs(y_test) + np.abs(y_pred))) + ) -> float: + val = 2 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)) # [N, H, D] + return float(np.mean(self._safemean(val, axis=(0, 1)))) class MQL(Metric): @@ -259,22 +273,19 @@ class MQL(Metric): def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - if quantile_levels is None or len(quantile_levels) == 0: - raise ValueError(f"{self.__class__.__name__} cannot be computed if quantile_levels is None") - ql = _quantile_loss( - test_data=test_data, - predictions=predictions, - quantile_levels=quantile_levels, - target_column=target_column, - ) - return np.nanmean(ql) + ) -> float: + if len(quantile_levels) == 0: + raise ValueError(f"{self.__class__.__name__} cannot be computed without quantile_levels") + ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] + per_dim = np.nanmean(ql, axis=(0, 1, 3)) # [D] + return float(np.mean(per_dim)) class SQL(Metric): @@ -293,25 +304,22 @@ def __init__(self, epsilon: float = 0.0) -> None: def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - ql = _quantile_loss( - test_data=test_data, - predictions=predictions, - quantile_levels=quantile_levels, - target_column=target_column, - ) - ql_per_time_step = np.nanmean(ql, axis=2) # [num_items, horizon] + ) -> float: + ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] + ql_avg_q = np.nanmean(ql, axis=3) # [N, H, D] seasonal_error = _abs_seasonal_error_per_item( - past_data=past_data, seasonality=seasonality, target_column=target_column - ) + y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality + ) # [N, D] seasonal_error = np.clip(seasonal_error, self.epsilon, None) - return self._safemean(ql_per_time_step / seasonal_error[:, None]) + scaled = ql_avg_q / seasonal_error[:, None, :] # [N, H, D] + return float(np.mean(self._safemean(scaled, axis=(0, 1)))) class WQL(Metric): @@ -325,94 +333,120 @@ def __init__(self, epsilon: float = 0.0) -> None: def compute( self, *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, - past_data: datasets.Dataset, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - target_column: str = "target", - ): - ql = _quantile_loss( - test_data=test_data, - predictions=predictions, - quantile_levels=quantile_levels, - target_column=target_column, - ) - return np.nanmean(ql) / max(self.epsilon, np.nanmean(np.abs(np.array(test_data[target_column])))) + ) -> float: + ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] + ql_per_dim = np.nanmean(ql, axis=(0, 1, 3)) # [D] + abs_true_per_dim = np.nanmean(np.abs(y_true), axis=(0, 1)) # [D] + per_dim = ql_per_dim / np.maximum(abs_true_per_dim, self.epsilon) + return float(np.mean(per_dim)) def _quantile_loss( *, - test_data: datasets.Dataset, - predictions: datasets.Dataset, + y_true: np.ndarray, + q_pred: np.ndarray, quantile_levels: list[float], - target_column: str, -): - """Compute quantile loss for each observation""" - pred_per_quantile = [] - for q in quantile_levels: - pred_per_quantile.append(np.array(predictions[str(q)])) - q_pred = np.stack(pred_per_quantile, axis=-1) # [num_series, horizon, len(quantile_levels)] - y_test = Metric._get_y_test(test_data, target_column=target_column)[..., None] # [num_series, horizon, 1] - assert y_test.shape[:-1] == q_pred.shape[:-1] - return 2 * np.abs((y_test - q_pred) * ((y_test <= q_pred) - np.array(quantile_levels))) +) -> np.ndarray: + """Compute quantile loss. + + Returns + ------- + np.ndarray [N, H, D, Q] + """ + y_true_expanded = y_true[..., None] # [N, H, D, 1] + q_arr = np.array(quantile_levels) # [Q] + return 2 * np.abs((y_true_expanded - q_pred) * ((y_true_expanded <= q_pred) - q_arr)) def _seasonal_error_per_item( - arrays: list[np.ndarray], + *, + y_past: np.ndarray, + y_past_lengths: np.ndarray, seasonality: int, aggregate_fn: Callable, ) -> np.ndarray: - """Compute seasonal error for each time series using vectorized operations. - - Uses bincount with weights to efficiently compute per-series aggregations. + """Compute seasonal error for each (item, dim) pair. + + Parameters + ---------- + y_past : np.ndarray [total_T, D] + Concatenated past observations. + y_past_lengths : np.ndarray [N] + Number of observations per item. + seasonality : int + Seasonal period. + aggregate_fn : Callable + Applied element-wise to seasonal diffs (e.g. np.abs or np.square). + + Returns + ------- + np.ndarray [N, D] """ - num_series = len(arrays) + num_series = len(y_past_lengths) + num_dims = y_past.shape[1] + if num_series == 0: - return np.array([], dtype="float64") + return np.array([], dtype="float64").reshape(0, 0) - lengths = np.array([a.size for a in arrays], dtype=np.int64) - num_diffs_per_series = np.maximum(lengths - seasonality, 0) + num_diffs_per_series = np.maximum(y_past_lengths - seasonality, 0) if num_diffs_per_series.sum() == 0: - return np.full(num_series, np.nan, dtype="float64") + return np.full((num_series, num_dims), np.nan, dtype="float64") - flat = np.concatenate(arrays).astype("float64") - series_starts = np.concatenate([[0], np.cumsum(lengths[:-1])]) + # Fast path: all items have equal length — reshape + slice instead of fancy indexing + if np.all(y_past_lengths == y_past_lengths[0]): + T = int(y_past_lengths[0]) + y_reshaped = y_past.reshape(num_series, T, num_dims) + diffs = y_reshaped[:, seasonality:, :] - y_reshaped[:, :-seasonality, :] + return np.nanmean(aggregate_fn(diffs), axis=1) - # Build indices for all (t, t-seasonality) pairs across all series total_diffs = int(num_diffs_per_series.sum()) series_ids = np.repeat(np.arange(num_series, dtype=np.int64), num_diffs_per_series) diff_offsets = np.arange(total_diffs) - np.repeat( np.cumsum(num_diffs_per_series) - num_diffs_per_series, num_diffs_per_series ) - idx_current = series_starts[series_ids] + seasonality + diff_offsets + offsets = np.empty(num_series + 1, dtype=np.int64) + offsets[0] = 0 + np.cumsum(y_past_lengths, out=offsets[1:]) + idx_current = offsets[series_ids] + seasonality + diff_offsets idx_lagged = idx_current - seasonality - diffs = flat[idx_current] - flat[idx_lagged] - errors = aggregate_fn(diffs) + diffs = y_past[idx_current] - y_past[idx_lagged] # [total_diffs, D] + errors = aggregate_fn(diffs) # [total_diffs, D] - # Compute per-series nanmean via bincount - valid = ~np.isnan(errors) - sums = np.bincount(series_ids, weights=np.where(valid, errors, 0.0), minlength=num_series) - counts = np.bincount(series_ids, weights=valid.astype("float64"), minlength=num_series) + valid = ~np.isnan(errors) # [total_diffs, D] + result = np.full((num_series, num_dims), np.nan, dtype="float64") + for d in range(num_dims): + sums = np.bincount(series_ids, weights=np.where(valid[:, d], errors[:, d], 0.0), minlength=num_series) + counts = np.bincount(series_ids, weights=valid[:, d].astype("float64"), minlength=num_series) + mask = counts > 0 + result[mask, d] = sums[mask] / counts[mask] - result = np.full(num_series, np.nan, dtype="float64") - np.divide(sums, counts, out=result, where=counts > 0) return result -def _abs_seasonal_error_per_item(past_data: datasets.Dataset, seasonality: int, target_column: str) -> np.ndarray: - """Compute mean absolute seasonal error for each time series in past_data.""" - arrays = past_data.with_format("numpy")[target_column] - return _seasonal_error_per_item(arrays, seasonality, aggregate_fn=np.abs) +def _abs_seasonal_error_per_item(*, y_past: np.ndarray, y_past_lengths: np.ndarray, seasonality: int) -> np.ndarray: + """Compute mean absolute seasonal error. Returns [N, D].""" + return _seasonal_error_per_item( + y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality, aggregate_fn=np.abs + ) -def _squared_seasonal_error_per_item(past_data: datasets.Dataset, seasonality: int, target_column: str) -> np.ndarray: - """Compute mean squared seasonal error for each time series in past_data.""" - arrays = past_data.with_format("numpy")[target_column] - return _seasonal_error_per_item(arrays, seasonality, aggregate_fn=np.square) +def _squared_seasonal_error_per_item( + *, y_past: np.ndarray, y_past_lengths: np.ndarray, seasonality: int +) -> np.ndarray: + """Compute mean squared seasonal error. Returns [N, D].""" + return _seasonal_error_per_item( + y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality, aggregate_fn=np.square + ) AVAILABLE_METRICS: dict[str, Type[Metric]] = { diff --git a/src/fev/task.py b/src/fev/task.py index b1cd54c..116fc13 100644 --- a/src/fev/task.py +++ b/src/fev/task.py @@ -9,6 +9,7 @@ import datasets import numpy as np import pandas as pd +import pyarrow.compute as pc import pydantic from pydantic_core import ArgsKwargs @@ -136,8 +137,6 @@ def compute_metrics( This is a convenience method that exists for debugging and additional evaluation. """ past_data, _, test_data = self._get_past_future_test_data() - past_data.set_format("numpy") - test_data.set_format("numpy") for target_column, predictions_for_column in predictions.items(): if len(predictions_for_column) != len(test_data): @@ -146,23 +145,59 @@ def compute_metrics( f"match the length of test data ({len(test_data)})" ) + N = len(test_data) + D = len(self.target_columns) + H = self.horizon + Q = len(quantile_levels) + + # y_true [N, H, D] — via pyarrow for fast column access + test_table = test_data.data.table + y_true = np.stack( + [pc.list_flatten(test_table.column(col)).to_numpy(zero_copy_only=False) for col in self.target_columns], + axis=1, + dtype=np.float64, + ).reshape(N, H, D) + + # y_pred [N, H, D] + pred_arrs = [] + pred_tables = {} + for col in self.target_columns: + pred_tables[col] = predictions[col].data.table + pred_arrs.append(pc.list_flatten(pred_tables[col].column(PREDICTIONS)).to_numpy(zero_copy_only=False)) + y_pred = np.stack(pred_arrs, axis=1, dtype=np.float64).reshape(N, H, D) + + # q_pred [N, H, D, Q] + if Q > 0: + q_arrs = [] + for col in self.target_columns: + for q in quantile_levels: + q_arrs.append(pc.list_flatten(pred_tables[col].column(str(q))).to_numpy(zero_copy_only=False)) + q_pred = np.stack(q_arrs, axis=1, dtype=np.float64).reshape(N, H, D, Q) + else: + q_pred = np.empty((N, H, D, 0), dtype=np.float64) + + # y_past [total_T, D] + lengths [N] + past_table = past_data.data.table + y_past_flat = np.stack( + [pc.list_flatten(past_table.column(col)).to_numpy(zero_copy_only=False) for col in self.target_columns], + axis=1, + dtype=np.float64, + ) + y_past_lengths = pc.list_value_length(past_table.column(self.target_columns[0])).to_numpy() + test_scores: dict[str, float] = {} with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) for metric in metrics: - scores = [] - for col in self.target_columns: - scores.append( - metric.compute( - test_data=test_data, - predictions=predictions[col], - past_data=past_data, - seasonality=seasonality, - quantile_levels=quantile_levels, - target_column=col, - ) - ) - test_scores[metric.name] = float(np.mean(scores)) + test_scores[metric.name] = metric.compute( + y_true=y_true, + y_pred=y_pred, + y_past=y_past_flat, + y_past_lengths=y_past_lengths, + q_pred=q_pred, + seasonality=seasonality, + quantile_levels=quantile_levels, + ) return test_scores @@ -746,21 +781,33 @@ def _to_dataset(preds: datasets.Dataset | list[dict]) -> datasets.Dataset: ) if missing_columns := set(self.target_columns) - set(predictions.keys()): raise ValueError(f"Missing predictions for columns {missing_columns} (got {sorted(predictions.keys())})") - predictions = predictions.cast(self.predictions_schema).with_format("numpy") - for target_column, predictions_for_column in predictions.items(): - self._assert_all_columns_finite(predictions_for_column) - return predictions - @staticmethod - def _assert_all_columns_finite(predictions: datasets.Dataset) -> None: - for col in predictions.column_names: - nan_row_idx, _ = np.where(~np.isfinite(np.array(predictions[col]))) - if len(nan_row_idx) > 0: + expected_columns = set(self.predictions_schema.keys()) + for target_col, pred_ds in predictions.items(): + table = pred_ds.data.table + if missing := expected_columns - set(table.column_names): + raise ValueError( + f"Predictions for '{target_col}' are missing columns {sorted(missing)}. " + f"Expected: {sorted(expected_columns)}" + ) + lengths = pc.list_value_length(table.column(PREDICTIONS)).to_numpy() + if not np.all(lengths == self.horizon): + bad_idx = int(np.argmax(lengths != self.horizon)) raise ValueError( - "Predictions contain NaN or Inf values. " - f"First invalid value encountered in column {col} for item {nan_row_idx[0]}:\n" - f"{predictions[int(nan_row_idx[0])]}" + f"Predictions for '{target_col}' have wrong length at item {bad_idx}: " + f"got {lengths[bad_idx]}, expected {self.horizon}" ) + for col in expected_columns: + flat = pc.list_flatten(table.column(col)) + if not pc.all(pc.is_finite(flat)).as_py(): + flat_np = flat.to_numpy(zero_copy_only=False) + bad_flat_idx = int(np.argmax(~np.isfinite(flat_np))) + bad_item = bad_flat_idx // self.horizon + raise ValueError( + f"Predictions contain NaN or Inf values. " + f"First invalid value in column '{col}' for target '{target_col}' at item {bad_item}." + ) + return predictions def evaluation_summary( self, diff --git a/test/test_metrics.py b/test/test_metrics.py index acaea6f..0343be4 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -75,6 +75,13 @@ def _reference_seasonal_error_per_item(arrays, seasonality, aggregate_fn): return np.array(result, dtype="float64") +def _arrays_to_flat(arrays): + """Helper to convert list of 1D arrays to flat [total_T, 1] + lengths [N].""" + lengths = np.array([len(a) for a in arrays], dtype=np.int64) + flat = np.concatenate(arrays).astype(np.float64).reshape(-1, 1) if arrays else np.empty((0, 1), dtype=np.float64) + return flat, lengths + + @pytest.mark.parametrize("aggregate_fn", [np.abs, np.square]) def test_seasonal_error_per_item(aggregate_fn): """Test vectorized impl against reference with mixed edge cases.""" @@ -88,7 +95,10 @@ def test_seasonal_error_per_item(aggregate_fn): ] seasonality = 2 - result = _seasonal_error_per_item(arrays, seasonality, aggregate_fn) + flat, lengths = _arrays_to_flat(arrays) + result = _seasonal_error_per_item( + y_past=flat, y_past_lengths=lengths, seasonality=seasonality, aggregate_fn=aggregate_fn + )[:, 0] expected = _reference_seasonal_error_per_item(arrays, seasonality, aggregate_fn) np.testing.assert_allclose(result, expected) @@ -96,6 +106,7 @@ def test_seasonal_error_per_item(aggregate_fn): def test_seasonal_error_per_item_empty(): """Test with empty input.""" - result = _seasonal_error_per_item([], 2, np.abs) - assert len(result) == 0 + flat, lengths = _arrays_to_flat([]) + result = _seasonal_error_per_item(y_past=flat, y_past_lengths=lengths, seasonality=2, aggregate_fn=np.abs) + assert result.size == 0 assert result.dtype == np.float64