From 413ec070258d1adfc609803b69c7d1234db932e6 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 12 May 2026 12:55:15 +0000 Subject: [PATCH 01/10] Add a skeleton for the foundation model API class --- src/autogluon/cloud/model/__init__.py | 3 + src/autogluon/cloud/model/foundation_model.py | 233 ++++++++++++++++++ src/autogluon/cloud/model/registry.py | 65 +++++ 3 files changed, 301 insertions(+) create mode 100644 src/autogluon/cloud/model/__init__.py create mode 100644 src/autogluon/cloud/model/foundation_model.py create mode 100644 src/autogluon/cloud/model/registry.py diff --git a/src/autogluon/cloud/model/__init__.py b/src/autogluon/cloud/model/__init__.py new file mode 100644 index 0000000..798d7a3 --- /dev/null +++ b/src/autogluon/cloud/model/__init__.py @@ -0,0 +1,3 @@ +from .foundation_model import FoundationModel, TabularFoundationModel, TimeSeriesFoundationModel + +__all__ = ["FoundationModel", "TabularFoundationModel", "TimeSeriesFoundationModel"] diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py new file mode 100644 index 0000000..1792c8f --- /dev/null +++ b/src/autogluon/cloud/model/foundation_model.py @@ -0,0 +1,233 @@ +"""FoundationModel — deploy and predict with pretrained foundation models on AWS.""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Dict, List, Literal, Optional, Union + +import pandas as pd + +from autogluon.cloud.endpoint.endpoint import Endpoint +from autogluon.cloud.job.remote_job import RemoteJob + +from .registry import get_model_config + + +class FoundationModel: + """ + Pretrained foundation model inference on AWS. + + Factory: FoundationModel("chronos-bolt-base", ...) returns the appropriate + task-specific subclass (TimeSeriesFoundationModel, TabularFoundationModel). + + Examples + -------- + >>> model = FoundationModel("chronos-bolt-base", role_arn="arn:...") + >>> endpoint = model.deploy() + >>> predictions = endpoint.predict(data) + >>> endpoint.delete_endpoint() + """ + + def __new__(cls, model_id: str, **kwargs) -> "FoundationModel": + if cls is not FoundationModel: + return super().__new__(cls) + config = get_model_config(model_id) + task = config["task"] + if task == "timeseries": + return super().__new__(TimeSeriesFoundationModel) + elif task == "tabular": + return super().__new__(TabularFoundationModel) + raise ValueError(f"Unsupported task: {task}") + + def __init__( + self, + model_id: str, + backend: Literal["sagemaker"] = "sagemaker", + role_arn: Optional[str] = None, + region: Optional[str] = None, + s3_output_path: Optional[str] = None, + model_config: Optional[Dict[str, Any]] = None, + ): + self.model_id = model_id + self.role_arn = role_arn + self.region = region + self.s3_output_path = s3_output_path + self._config = get_model_config(model_id) + # Merge user overrides on top of registry defaults + self.model_config = {**self._config.get("model_config", {}), **(model_config or {})} + # TODO: instantiate backend via BackendFactory + self._backend_type = backend + + def deploy( + self, + instance_type: Optional[str] = None, + mode: Literal["realtime", "serverless", "async"] = "realtime", + endpoint_name: Optional[str] = None, + model_artifact_path: Optional[str] = None, + model_config: Optional[Dict[str, Any]] = None, + wait: bool = True, + ) -> Endpoint: + """ + Deploy model to an endpoint. + + Parameters + ---------- + instance_type + Instance type for the endpoint. + If None, will use the default from the model registry. + mode + Endpoint type. + endpoint_name + Custom endpoint name. + If None, will auto-generate a unique name. + model_artifact_path + S3 path to pre-cached model weights (for VPC / fast cold start). + If None, weights are downloaded from HuggingFace on cold start. + model_config + Override default inference config (prediction_length, quantile_levels, etc.) + wait + Whether to block until the endpoint is ready. + + Returns + ------- + Endpoint + """ + raise NotImplementedError + + @abstractmethod + def predict(self, data: Union[str, pd.DataFrame], wait: bool = True, **kwargs) -> Union[pd.DataFrame, RemoteJob]: + """Subclasses override with task-specific signature.""" + ... + + def fit( + self, + train_data: Union[str, pd.DataFrame], + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + wait: bool = True, + **kwargs, + ) -> "FoundationModel": + """ + Fine-tune the model. Returns a new FoundationModel pointing to the fine-tuned artifact. + + Parameters + ---------- + train_data + Training data as DataFrame or S3 path. + output_path + S3 path to store fine-tuned model. + If None, will auto-generate under s3_output_path. + instance_type + Instance type for the training job. + If None, will use the default from the model registry. + wait + If True, block until training completes. + + Returns + ------- + FoundationModel + New instance with model_config pointing to the fine-tuned artifact. + """ + raise NotImplementedError + + def cache_model_artifact(self, s3_path: str) -> str: + """ + Pre-cache model weights to S3 for VPC or production use. + + Launches a small job that downloads weights from HuggingFace + and writes them to S3, avoiding large local downloads. + + Parameters + ---------- + s3_path + S3 path where the model weights should be cached. + + Returns + ------- + str + S3 path to the cached artifact. + """ + raise NotImplementedError + + +class TimeSeriesFoundationModel(FoundationModel): + """Foundation model for time series forecasting (Chronos, etc.).""" + + def predict( + self, + data: Union[str, pd.DataFrame], + known_covariates: Optional[Union[str, pd.DataFrame]] = None, + prediction_length: Optional[int] = None, + quantile_levels: Optional[List[float]] = None, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + wait: bool = True, + ) -> Union[pd.DataFrame, RemoteJob]: + """ + Run batch prediction for time series. + + Parameters + ---------- + data + Historical time series in long format (DataFrame or S3 path). + known_covariates + Future values of known covariates (DataFrame or S3 path). + prediction_length + Number of time steps to forecast. + If None, will use the default from the model registry. + quantile_levels + Quantiles to predict. + If None, will use the default from the model registry. + output_path + S3 path to store predictions. + If None, will auto-generate under s3_output_path. + instance_type + Instance type for the prediction job. + If None, will use the default from the model registry. + wait + If True, block and return DataFrame. If False, return the job handle. + + Returns + ------- + Union[pd.DataFrame, RemoteJob] + """ + raise NotImplementedError + + +class TabularFoundationModel(FoundationModel): + """Foundation model for tabular prediction (Mitra, TabICL, etc.).""" + + def predict( + self, + train_data: Union[str, pd.DataFrame], + test_data: Union[str, pd.DataFrame], + label: str = "target", + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + wait: bool = True, + ) -> Union[pd.DataFrame, RemoteJob]: + """ + Run batch prediction for tabular tasks. + + Parameters + ---------- + train_data + Labeled few-shot context for the foundation model. + test_data + Unlabeled data to predict on. + label + Target column name in train_data. + output_path + S3 path to store predictions. + If None, will auto-generate under s3_output_path. + instance_type + Instance type for the prediction job. + If None, will use the default from the model registry. + wait + If True, block and return DataFrame. If False, return the job handle. + + Returns + ------- + Union[pd.DataFrame, RemoteJob] + """ + raise NotImplementedError diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py new file mode 100644 index 0000000..4b61d20 --- /dev/null +++ b/src/autogluon/cloud/model/registry.py @@ -0,0 +1,65 @@ +"""Foundation model registry. + +Maps model_id to AG-compatible configuration for deploy / predict. +""" + +from typing import Any, Dict, Literal, TypedDict + + +class FoundationModelConfig(TypedDict): + task: Literal["timeseries", "tabular"] + model_name: str # AG model class name (e.g. "Chronos", "Chronos2") + model_config: Dict[str, Any] # passed to the AG model (e.g. {"model_path": "..."}) + default_instance_type: str + default_inference_config: Dict[str, Any] # default prediction kwargs + + +FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = { + "chronos-bolt-tiny": { + "task": "timeseries", + "model_name": "Chronos", + "model_config": {"model_path": "amazon/chronos-bolt-tiny"}, + "default_instance_type": "ml.g5.xlarge", + "default_inference_config": { + "prediction_length": 64, + "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }, + }, + "chronos-bolt-small": { + "task": "timeseries", + "model_name": "Chronos", + "model_config": {"model_path": "amazon/chronos-bolt-small"}, + "default_instance_type": "ml.g5.xlarge", + "default_inference_config": { + "prediction_length": 64, + "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }, + }, + "chronos-bolt-base": { + "task": "timeseries", + "model_name": "Chronos", + "model_config": {"model_path": "amazon/chronos-bolt-base"}, + "default_instance_type": "ml.g5.xlarge", + "default_inference_config": { + "prediction_length": 64, + "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }, + }, + "chronos-2": { + "task": "timeseries", + "model_name": "Chronos2", + "model_config": {"model_path": "amazon/chronos-2"}, + "default_instance_type": "ml.g5.xlarge", + "default_inference_config": { + "prediction_length": 64, + "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }, + }, +} + + +def get_model_config(model_id: str) -> FoundationModelConfig: + if model_id not in FOUNDATION_MODEL_REGISTRY: + available = list(FOUNDATION_MODEL_REGISTRY.keys()) + raise ValueError(f"Unknown model_id '{model_id}'. Available models: {available}") + return FOUNDATION_MODEL_REGISTRY[model_id] From 0e2d6138145cc5c07a2f17d8a1b97b3280ad3b3c Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 12 May 2026 12:57:00 +0000 Subject: [PATCH 02/10] Fix dict union --- src/autogluon/cloud/model/foundation_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 1792c8f..9cda06e 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -54,7 +54,7 @@ def __init__( self.s3_output_path = s3_output_path self._config = get_model_config(model_id) # Merge user overrides on top of registry defaults - self.model_config = {**self._config.get("model_config", {}), **(model_config or {})} + self.model_config = self._config.get("model_config", {}) | (model_config or {}) # TODO: instantiate backend via BackendFactory self._backend_type = backend From fb92f32b0e70a7de2a6870cb1c50434fd3f8827b Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 12 May 2026 13:18:03 +0000 Subject: [PATCH 03/10] Update API --- src/autogluon/cloud/model/foundation_model.py | 51 ++++++++++++++++++- src/autogluon/cloud/model/registry.py | 26 +++++++--- 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 9cda06e..d01b0d1 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -33,9 +33,9 @@ def __new__(cls, model_id: str, **kwargs) -> "FoundationModel": return super().__new__(cls) config = get_model_config(model_id) task = config["task"] - if task == "timeseries": + if task == "forecasting": return super().__new__(TimeSeriesFoundationModel) - elif task == "tabular": + elif task in ("classification", "regression"): return super().__new__(TabularFoundationModel) raise ValueError(f"Unsupported task: {task}") @@ -156,7 +156,11 @@ class TimeSeriesFoundationModel(FoundationModel): def predict( self, data: Union[str, pd.DataFrame], + target: str = "target", + id_column: str = "item_id", + timestamp_column: str = "timestamp", known_covariates: Optional[Union[str, pd.DataFrame]] = None, + static_features: Optional[Union[str, pd.DataFrame]] = None, prediction_length: Optional[int] = None, quantile_levels: Optional[List[float]] = None, output_path: Optional[str] = None, @@ -170,8 +174,16 @@ def predict( ---------- data Historical time series in long format (DataFrame or S3 path). + target + Name of the target column to forecast. + id_column + Name of the item ID column. + timestamp_column + Name of the timestamp column. known_covariates Future values of known covariates (DataFrame or S3 path). + static_features + Metadata attributes of individual items (DataFrame or S3 path). prediction_length Number of time steps to forecast. If None, will use the default from the model registry. @@ -231,3 +243,38 @@ def predict( Union[pd.DataFrame, RemoteJob] """ raise NotImplementedError + + def predict_proba( + self, + train_data: Union[str, pd.DataFrame], + test_data: Union[str, pd.DataFrame], + label: str = "target", + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + wait: bool = True, + ) -> Union[pd.DataFrame, RemoteJob]: + """ + Run batch prediction returning class probabilities. + + Parameters + ---------- + train_data + Labeled few-shot context for the foundation model. + test_data + Unlabeled data to predict on. + label + Target column name in train_data. + output_path + S3 path to store predictions. + If None, will auto-generate under s3_output_path. + instance_type + Instance type for the prediction job. + If None, will use the default from the model registry. + wait + If True, block and return DataFrame. If False, return the job handle. + + Returns + ------- + Union[pd.DataFrame, RemoteJob] + """ + raise NotImplementedError diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 4b61d20..88ce371 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -7,8 +7,8 @@ class FoundationModelConfig(TypedDict): - task: Literal["timeseries", "tabular"] - model_name: str # AG model class name (e.g. "Chronos", "Chronos2") + task: Literal["forecasting", "classification", "regression"] + model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra") model_config: Dict[str, Any] # passed to the AG model (e.g. {"model_path": "..."}) default_instance_type: str default_inference_config: Dict[str, Any] # default prediction kwargs @@ -16,7 +16,7 @@ class FoundationModelConfig(TypedDict): FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = { "chronos-bolt-tiny": { - "task": "timeseries", + "task": "forecasting", "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-tiny"}, "default_instance_type": "ml.g5.xlarge", @@ -26,7 +26,7 @@ class FoundationModelConfig(TypedDict): }, }, "chronos-bolt-small": { - "task": "timeseries", + "task": "forecasting", "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-small"}, "default_instance_type": "ml.g5.xlarge", @@ -36,7 +36,7 @@ class FoundationModelConfig(TypedDict): }, }, "chronos-bolt-base": { - "task": "timeseries", + "task": "forecasting", "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-base"}, "default_instance_type": "ml.g5.xlarge", @@ -46,7 +46,7 @@ class FoundationModelConfig(TypedDict): }, }, "chronos-2": { - "task": "timeseries", + "task": "forecasting", "model_name": "Chronos2", "model_config": {"model_path": "amazon/chronos-2"}, "default_instance_type": "ml.g5.xlarge", @@ -55,6 +55,20 @@ class FoundationModelConfig(TypedDict): "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], }, }, + "mitra-classification": { + "task": "classification", + "model_name": "Mitra", + "model_config": {"model_path": "TODO"}, + "default_instance_type": "ml.m5.xlarge", + "default_inference_config": {}, + }, + "mitra-regression": { + "task": "regression", + "model_name": "Mitra", + "model_config": {"model_path": "TODO"}, + "default_instance_type": "ml.m5.xlarge", + "default_inference_config": {}, + }, } From 371ff0ddf801b4121401b986f206de5185997a31 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 12 May 2026 13:52:08 +0000 Subject: [PATCH 04/10] Update structure --- src/autogluon/cloud/model/foundation_model.py | 31 +++++++++++++------ src/autogluon/cloud/model/registry.py | 19 ------------ 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index d01b0d1..10814cb 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -64,8 +64,8 @@ def deploy( mode: Literal["realtime", "serverless", "async"] = "realtime", endpoint_name: Optional[str] = None, model_artifact_path: Optional[str] = None, - model_config: Optional[Dict[str, Any]] = None, wait: bool = True, + **backend_kwargs, ) -> Endpoint: """ Deploy model to an endpoint. @@ -81,12 +81,13 @@ def deploy( Custom endpoint name. If None, will auto-generate a unique name. model_artifact_path - S3 path to pre-cached model weights (for VPC / fast cold start). - If None, weights are downloaded from HuggingFace on cold start. - model_config - Override default inference config (prediction_length, quantile_levels, etc.) + S3 path to pre-cached model weights (for VPC use). + If None, weights are downloaded from HuggingFace at startup. wait Whether to block until the endpoint is ready. + **backend_kwargs + Additional backend-specific arguments (e.g. framework_version, custom_image_uri, + volume_size). Returns ------- @@ -132,10 +133,10 @@ def fit( def cache_model_artifact(self, s3_path: str) -> str: """ - Pre-cache model weights to S3 for VPC or production use. + Pre-cache model weights to S3 (for VPC-deployed endpoints). Launches a small job that downloads weights from HuggingFace - and writes them to S3, avoiding large local downloads. + and uploads them to S3. Parameters ---------- @@ -161,11 +162,12 @@ def predict( timestamp_column: str = "timestamp", known_covariates: Optional[Union[str, pd.DataFrame]] = None, static_features: Optional[Union[str, pd.DataFrame]] = None, - prediction_length: Optional[int] = None, + prediction_length: int = 64, quantile_levels: Optional[List[float]] = None, output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, + **backend_kwargs, ) -> Union[pd.DataFrame, RemoteJob]: """ Run batch prediction for time series. @@ -186,10 +188,8 @@ def predict( Metadata attributes of individual items (DataFrame or S3 path). prediction_length Number of time steps to forecast. - If None, will use the default from the model registry. quantile_levels Quantiles to predict. - If None, will use the default from the model registry. output_path S3 path to store predictions. If None, will auto-generate under s3_output_path. @@ -198,6 +198,9 @@ def predict( If None, will use the default from the model registry. wait If True, block and return DataFrame. If False, return the job handle. + **backend_kwargs + Additional backend-specific arguments (e.g. job_name, custom_image_uri, + framework_version, volume_size). Returns ------- @@ -217,6 +220,7 @@ def predict( output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, + **backend_kwargs, ) -> Union[pd.DataFrame, RemoteJob]: """ Run batch prediction for tabular tasks. @@ -237,6 +241,9 @@ def predict( If None, will use the default from the model registry. wait If True, block and return DataFrame. If False, return the job handle. + **backend_kwargs + Additional backend-specific arguments (e.g. job_name, custom_image_uri, + framework_version, volume_size). Returns ------- @@ -252,6 +259,7 @@ def predict_proba( output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, + **backend_kwargs, ) -> Union[pd.DataFrame, RemoteJob]: """ Run batch prediction returning class probabilities. @@ -272,6 +280,9 @@ def predict_proba( If None, will use the default from the model registry. wait If True, block and return DataFrame. If False, return the job handle. + **backend_kwargs + Additional backend-specific arguments (e.g. job_name, custom_image_uri, + framework_version, volume_size). Returns ------- diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 88ce371..10fcc8a 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -11,7 +11,6 @@ class FoundationModelConfig(TypedDict): model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra") model_config: Dict[str, Any] # passed to the AG model (e.g. {"model_path": "..."}) default_instance_type: str - default_inference_config: Dict[str, Any] # default prediction kwargs FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = { @@ -20,54 +19,36 @@ class FoundationModelConfig(TypedDict): "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-tiny"}, "default_instance_type": "ml.g5.xlarge", - "default_inference_config": { - "prediction_length": 64, - "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], - }, }, "chronos-bolt-small": { "task": "forecasting", "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-small"}, "default_instance_type": "ml.g5.xlarge", - "default_inference_config": { - "prediction_length": 64, - "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], - }, }, "chronos-bolt-base": { "task": "forecasting", "model_name": "Chronos", "model_config": {"model_path": "amazon/chronos-bolt-base"}, "default_instance_type": "ml.g5.xlarge", - "default_inference_config": { - "prediction_length": 64, - "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], - }, }, "chronos-2": { "task": "forecasting", "model_name": "Chronos2", "model_config": {"model_path": "amazon/chronos-2"}, "default_instance_type": "ml.g5.xlarge", - "default_inference_config": { - "prediction_length": 64, - "quantile_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], - }, }, "mitra-classification": { "task": "classification", "model_name": "Mitra", "model_config": {"model_path": "TODO"}, "default_instance_type": "ml.m5.xlarge", - "default_inference_config": {}, }, "mitra-regression": { "task": "regression", "model_name": "Mitra", "model_config": {"model_path": "TODO"}, "default_instance_type": "ml.m5.xlarge", - "default_inference_config": {}, }, } From 901cba3437e1485b424bb64ecf6e7a56ce1fe711 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 12 May 2026 13:53:36 +0000 Subject: [PATCH 05/10] Change default pred_len --- src/autogluon/cloud/model/foundation_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 10814cb..aead5ce 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -162,7 +162,7 @@ def predict( timestamp_column: str = "timestamp", known_covariates: Optional[Union[str, pd.DataFrame]] = None, static_features: Optional[Union[str, pd.DataFrame]] = None, - prediction_length: int = 64, + prediction_length: int = 1, quantile_levels: Optional[List[float]] = None, output_path: Optional[str] = None, instance_type: Optional[str] = None, From ad18b97109d801a202372ffcba3fa3550ad0bba9 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Wed, 13 May 2026 12:08:28 +0000 Subject: [PATCH 06/10] Split training and infernece hyperparameters --- src/autogluon/cloud/model/foundation_model.py | 39 ++++++++++++++----- src/autogluon/cloud/model/registry.py | 21 ++++++---- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index aead5ce..b3a8d35 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -22,7 +22,7 @@ class FoundationModel: Examples -------- - >>> model = FoundationModel("chronos-bolt-base", role_arn="arn:...") + >>> model = FoundationModel("chronos-bolt-base", role_arn="arn:...", hyperparameters={"model_path": "s3://cached/"}) >>> endpoint = model.deploy() >>> predictions = endpoint.predict(data) >>> endpoint.delete_endpoint() @@ -46,24 +46,29 @@ def __init__( role_arn: Optional[str] = None, region: Optional[str] = None, s3_output_path: Optional[str] = None, - model_config: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ): self.model_id = model_id self.role_arn = role_arn self.region = region self.s3_output_path = s3_output_path self._config = get_model_config(model_id) - # Merge user overrides on top of registry defaults - self.model_config = self._config.get("model_config", {}) | (model_config or {}) + self._hyperparameter_overrides = hyperparameters or {} # TODO: instantiate backend via BackendFactory self._backend_type = backend + def _get_hyperparameters( + self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + config_key = "inference_hyperparameters" if context == "inference" else "training_hyperparameters" + return self._config.get(config_key, {}) | self._hyperparameter_overrides | (overrides or {}) + def deploy( self, instance_type: Optional[str] = None, mode: Literal["realtime", "serverless", "async"] = "realtime", endpoint_name: Optional[str] = None, - model_artifact_path: Optional[str] = None, + hyperparameters: Optional[Dict[str, Any]] = None, wait: bool = True, **backend_kwargs, ) -> Endpoint: @@ -80,9 +85,9 @@ def deploy( endpoint_name Custom endpoint name. If None, will auto-generate a unique name. - model_artifact_path - S3 path to pre-cached model weights (for VPC use). - If None, weights are downloaded from HuggingFace at startup. + hyperparameters + Model hyperparameters for inference. Overrides values passed to the constructor. + Available hyperparameters for each model are listed in the AutoGluon documentation. wait Whether to block until the endpoint is ready. **backend_kwargs @@ -105,6 +110,7 @@ def fit( train_data: Union[str, pd.DataFrame], output_path: Optional[str] = None, instance_type: Optional[str] = None, + hyperparameters: Optional[Dict[str, Any]] = None, wait: bool = True, **kwargs, ) -> "FoundationModel": @@ -121,13 +127,16 @@ def fit( instance_type Instance type for the training job. If None, will use the default from the model registry. + hyperparameters + Model hyperparameters for training. Overrides values passed to the constructor. + Available hyperparameters for each model are listed in the AutoGluon documentation. wait If True, block until training completes. Returns ------- FoundationModel - New instance with model_config pointing to the fine-tuned artifact. + New instance with hyperparameters pointing to the fine-tuned artifact. """ raise NotImplementedError @@ -164,6 +173,7 @@ def predict( static_features: Optional[Union[str, pd.DataFrame]] = None, prediction_length: int = 1, quantile_levels: Optional[List[float]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, @@ -190,6 +200,9 @@ def predict( Number of time steps to forecast. quantile_levels Quantiles to predict. + hyperparameters + Model hyperparameters for inference. Overrides values passed to the constructor. + Available hyperparameters for each model are listed in the AutoGluon documentation. output_path S3 path to store predictions. If None, will auto-generate under s3_output_path. @@ -217,6 +230,7 @@ def predict( train_data: Union[str, pd.DataFrame], test_data: Union[str, pd.DataFrame], label: str = "target", + hyperparameters: Optional[Dict[str, Any]] = None, output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, @@ -233,6 +247,9 @@ def predict( Unlabeled data to predict on. label Target column name in train_data. + hyperparameters + Model hyperparameters for inference. Overrides values passed to the constructor. + Available hyperparameters for each model are listed in the AutoGluon documentation. output_path S3 path to store predictions. If None, will auto-generate under s3_output_path. @@ -256,6 +273,7 @@ def predict_proba( train_data: Union[str, pd.DataFrame], test_data: Union[str, pd.DataFrame], label: str = "target", + hyperparameters: Optional[Dict[str, Any]] = None, output_path: Optional[str] = None, instance_type: Optional[str] = None, wait: bool = True, @@ -272,6 +290,9 @@ def predict_proba( Unlabeled data to predict on. label Target column name in train_data. + hyperparameters + Model hyperparameters for inference. Overrides values passed to the constructor. + Available hyperparameters for each model are listed in the AutoGluon documentation. output_path S3 path to store predictions. If None, will auto-generate under s3_output_path. diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 10fcc8a..7502c8f 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -9,7 +9,8 @@ class FoundationModelConfig(TypedDict): task: Literal["forecasting", "classification", "regression"] model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra") - model_config: Dict[str, Any] # passed to the AG model (e.g. {"model_path": "..."}) + inference_hyperparameters: Dict[str, Any] # defaults for deploy() and predict() + training_hyperparameters: Dict[str, Any] # defaults for fit() default_instance_type: str @@ -17,37 +18,43 @@ class FoundationModelConfig(TypedDict): "chronos-bolt-tiny": { "task": "forecasting", "model_name": "Chronos", - "model_config": {"model_path": "amazon/chronos-bolt-tiny"}, + "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, + "training_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, "default_instance_type": "ml.g5.xlarge", }, "chronos-bolt-small": { "task": "forecasting", "model_name": "Chronos", - "model_config": {"model_path": "amazon/chronos-bolt-small"}, + "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, + "training_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, "default_instance_type": "ml.g5.xlarge", }, "chronos-bolt-base": { "task": "forecasting", "model_name": "Chronos", - "model_config": {"model_path": "amazon/chronos-bolt-base"}, + "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, + "training_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, "default_instance_type": "ml.g5.xlarge", }, "chronos-2": { "task": "forecasting", "model_name": "Chronos2", - "model_config": {"model_path": "amazon/chronos-2"}, + "inference_hyperparameters": {"model_path": "amazon/chronos-2"}, + "training_hyperparameters": {"model_path": "amazon/chronos-2", "fine_tune": True}, "default_instance_type": "ml.g5.xlarge", }, "mitra-classification": { "task": "classification", "model_name": "Mitra", - "model_config": {"model_path": "TODO"}, + "inference_hyperparameters": {"model_path": "TODO"}, + "training_hyperparameters": {"model_path": "TODO"}, "default_instance_type": "ml.m5.xlarge", }, "mitra-regression": { "task": "regression", "model_name": "Mitra", - "model_config": {"model_path": "TODO"}, + "inference_hyperparameters": {"model_path": "TODO"}, + "training_hyperparameters": {"model_path": "TODO"}, "default_instance_type": "ml.m5.xlarge", }, } From 9141907f828855e3981852056851595ad169cbaf Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Wed, 13 May 2026 12:19:09 +0000 Subject: [PATCH 07/10] Remove deploy mode --- src/autogluon/cloud/model/foundation_model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index b3a8d35..4247126 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -66,7 +66,6 @@ def _get_hyperparameters( def deploy( self, instance_type: Optional[str] = None, - mode: Literal["realtime", "serverless", "async"] = "realtime", endpoint_name: Optional[str] = None, hyperparameters: Optional[Dict[str, Any]] = None, wait: bool = True, @@ -80,8 +79,6 @@ def deploy( instance_type Instance type for the endpoint. If None, will use the default from the model registry. - mode - Endpoint type. endpoint_name Custom endpoint name. If None, will auto-generate a unique name. @@ -91,8 +88,8 @@ def deploy( wait Whether to block until the endpoint is ready. **backend_kwargs - Additional backend-specific arguments (e.g. framework_version, custom_image_uri, - volume_size). + Backend-specific arguments. Use these to configure serverless, async, or + autoscaling (e.g. memory_size_in_mb, max_concurrency, initial_instance_count). Returns ------- From e4a725e94ce8fbf9447315efff19829ebefdcf13 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Wed, 13 May 2026 12:22:09 +0000 Subject: [PATCH 08/10] Update dummy configs --- src/autogluon/cloud/model/registry.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 7502c8f..128c73d 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -43,19 +43,20 @@ class FoundationModelConfig(TypedDict): "training_hyperparameters": {"model_path": "amazon/chronos-2", "fine_tune": True}, "default_instance_type": "ml.g5.xlarge", }, + # TODO: Replace dummy configs with real values "mitra-classification": { "task": "classification", "model_name": "Mitra", "inference_hyperparameters": {"model_path": "TODO"}, "training_hyperparameters": {"model_path": "TODO"}, - "default_instance_type": "ml.m5.xlarge", + "default_instance_type": "ml.g5.xlarge", }, "mitra-regression": { "task": "regression", "model_name": "Mitra", "inference_hyperparameters": {"model_path": "TODO"}, "training_hyperparameters": {"model_path": "TODO"}, - "default_instance_type": "ml.m5.xlarge", + "default_instance_type": "ml.g5.xlarge", }, } From 26e0741666e8877e842e448a84fa620add311614 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 15 May 2026 09:19:45 +0000 Subject: [PATCH 09/10] Add flag checking if fine-tuning is enabled for the model --- src/autogluon/cloud/model/foundation_model.py | 2 ++ src/autogluon/cloud/model/registry.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 4247126..5053bb1 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -135,6 +135,8 @@ def fit( FoundationModel New instance with hyperparameters pointing to the fine-tuned artifact. """ + if not self._config.get("fine_tunable", False): + raise ValueError(f"Model '{self.model_id}' does not support fine-tuning.") raise NotImplementedError def cache_model_artifact(self, s3_path: str) -> str: diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 128c73d..8f6e9b9 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -12,6 +12,7 @@ class FoundationModelConfig(TypedDict): inference_hyperparameters: Dict[str, Any] # defaults for deploy() and predict() training_hyperparameters: Dict[str, Any] # defaults for fit() default_instance_type: str + fine_tunable: bool # whether .fit() is supported FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = { @@ -21,6 +22,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, "training_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": False, }, "chronos-bolt-small": { "task": "forecasting", @@ -28,6 +30,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, "training_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": False, }, "chronos-bolt-base": { "task": "forecasting", @@ -35,6 +38,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, "training_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": False, }, "chronos-2": { "task": "forecasting", @@ -42,6 +46,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "amazon/chronos-2"}, "training_hyperparameters": {"model_path": "amazon/chronos-2", "fine_tune": True}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": True, }, # TODO: Replace dummy configs with real values "mitra-classification": { @@ -50,6 +55,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "TODO"}, "training_hyperparameters": {"model_path": "TODO"}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": False, }, "mitra-regression": { "task": "regression", @@ -57,6 +63,7 @@ class FoundationModelConfig(TypedDict): "inference_hyperparameters": {"model_path": "TODO"}, "training_hyperparameters": {"model_path": "TODO"}, "default_instance_type": "ml.g5.xlarge", + "fine_tunable": False, }, } From 8f90da3877c073adf0f96f941af66db4ec6bf62d Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 15 May 2026 09:26:33 +0000 Subject: [PATCH 10/10] Remove reference to RemoteJob.result() --- src/autogluon/cloud/model/foundation_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 5053bb1..d4ebabc 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -99,7 +99,12 @@ def deploy( @abstractmethod def predict(self, data: Union[str, pd.DataFrame], wait: bool = True, **kwargs) -> Union[pd.DataFrame, RemoteJob]: - """Subclasses override with task-specific signature.""" + """Subclasses override with task-specific signature. + + When wait=False, returns a RemoteJob handle. Use job.get_job_status() to poll + and job.get_output_path() to get the S3 path to predictions once complete. + """ + # TODO: consider adding a .result() method to RemoteJob that downloads + parses output ... def fit(