Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/autogluon/cloud/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .foundation_model import FoundationModel, TabularFoundationModel, TimeSeriesFoundationModel
Comment thread
shchur marked this conversation as resolved.

Comment thread
shchur marked this conversation as resolved.
__all__ = ["FoundationModel", "TabularFoundationModel", "TimeSeriesFoundationModel"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question:

Are TabularFoundationModel and TimeSeriesFoundationModel part of the public API, or internal to the FoundationModel factory? They're in __all__, but examples only use FoundationModel(...). For instance, should users be able to write:

def forecast(model: TimeSeriesFoundationModel, df: pd.DataFrame) -> pd.DataFrame:
    return model.predict(df, prediction_length=24)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, with the current design the user should be able to directly create a TimeSeriesFoundationModel object bypassing the FoundationModel base class, e.g.

model = TimeSeriesFoundationModel(model_id="chronos-bolt-base")
predictions = model.predict(...)

The FoundationModel base class only exists for convenience.

316 changes: 316 additions & 0 deletions src/autogluon/cloud/model/foundation_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
"""FoundationModel — deploy and predict with pretrained foundation models on AWS."""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union

import pandas as pd

from autogluon.cloud.endpoint.endpoint import Endpoint
from autogluon.cloud.job.remote_job import RemoteJob

from .registry import get_model_config


class FoundationModel:
"""
Pretrained foundation model inference on AWS.

Factory: FoundationModel("chronos-bolt-base", ...) returns the appropriate
task-specific subclass (TimeSeriesFoundationModel, TabularFoundationModel).

Examples
--------
>>> model = FoundationModel("chronos-bolt-base", role_arn="arn:...", hyperparameters={"model_path": "s3://cached/"})
>>> endpoint = model.deploy()
>>> predictions = endpoint.predict(data)
Comment thread
shchur marked this conversation as resolved.
>>> endpoint.delete_endpoint()
"""

def __new__(cls, model_id: str, **kwargs) -> "FoundationModel":

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @abdulfatir,

  1. The FoundationModel class automatically resolves to TimeSeriesFoundationModel, TabularFoundationModel, etc based on the task associated with each model_id. The subclasses only differ in their API for predict: e.g. TimeSeriesFoundationModel takes historical data + known covariates as input, and TabularFoundationModel takes train & test DFs as input

if cls is not FoundationModel:
return super().__new__(cls)
config = get_model_config(model_id)
task = config["task"]
if task == "forecasting":
return super().__new__(TimeSeriesFoundationModel)
elif task in ("classification", "regression"):
return super().__new__(TabularFoundationModel)
raise ValueError(f"Unsupported task: {task}")

def __init__(
self,
model_id: str,
backend: Literal["sagemaker"] = "sagemaker",

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abdulfatir
3. We can add support for other backends like Lambda in the future - but I would decouple this from the FoundationModel API so that the user doesn't need to worry about it

role_arn: Optional[str] = None,
region: Optional[str] = None,
s3_output_path: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
):
self.model_id = model_id
self.role_arn = role_arn
self.region = region
self.s3_output_path = s3_output_path
self._config = get_model_config(model_id)
self._hyperparameter_overrides = hyperparameters or {}
# TODO: instantiate backend via BackendFactory
self._backend_type = backend

def _get_hyperparameters(
self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
config_key = "inference_hyperparameters" if context == "inference" else "training_hyperparameters"
return self._config.get(config_key, {}) | self._hyperparameter_overrides | (overrides or {})

def deploy(
self,
instance_type: Optional[str] = None,
endpoint_name: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
wait: bool = True,
**backend_kwargs,
) -> Endpoint:
"""
Deploy model to an endpoint.

Parameters
----------
instance_type
Instance type for the endpoint.
If None, will use the default from the model registry.
endpoint_name
Custom endpoint name.
If None, will auto-generate a unique name.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
wait
Whether to block until the endpoint is ready.
**backend_kwargs
Backend-specific arguments. Use these to configure serverless, async, or
autoscaling (e.g. memory_size_in_mb, max_concurrency, initial_instance_count).

Returns
-------
Endpoint
"""
raise NotImplementedError

@abstractmethod
def predict(self, data: Union[str, pd.DataFrame], wait: bool = True, **kwargs) -> Union[pd.DataFrame, RemoteJob]:
"""Subclasses override with task-specific signature.

When wait=False, returns a RemoteJob handle. Use job.get_job_status() to poll
and job.get_output_path() to get the S3 path to predictions once complete.
"""
# TODO: consider adding a .result() method to RemoteJob that downloads + parses output
Comment on lines +104 to +107

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@melopeo I missed the top-level comment. Here is a temporary solution, I will figure out the best path when implementing the backend.

...

def fit(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting fit() on FoundationModel (the base class) implies every foundation model can be fine-tuned.

In general it makes sense to support fine-tuning of a pretrained model. My question is if we are assuming all FoundationModels under consideration will have a fine-tuning capability. If a foundation model can't be fine-tuned, how are we planning to handle that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, I added a boolean flag to the ModelConfig indicating whether fine-tuning is supported.

self,
train_data: Union[str, pd.DataFrame],
output_path: Optional[str] = None,
instance_type: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
wait: bool = True,
**kwargs,
) -> "FoundationModel":
"""
Fine-tune the model. Returns a new FoundationModel pointing to the fine-tuned artifact.

Parameters
----------
train_data
Training data as DataFrame or S3 path.
output_path
S3 path to store fine-tuned model.
If None, will auto-generate under s3_output_path.
instance_type
Instance type for the training job.
If None, will use the default from the model registry.
hyperparameters
Model hyperparameters for training. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
wait
If True, block until training completes.

Returns
-------
FoundationModel
New instance with hyperparameters pointing to the fine-tuned artifact.
"""
if not self._config.get("fine_tunable", False):
raise ValueError(f"Model '{self.model_id}' does not support fine-tuning.")
raise NotImplementedError

def cache_model_artifact(self, s3_path: str) -> str:
"""
Pre-cache model weights to S3 (for VPC-deployed endpoints).

Launches a small job that downloads weights from HuggingFace
and uploads them to S3.

Parameters
----------
s3_path
S3 path where the model weights should be cached.

Returns
-------
str
S3 path to the cached artifact.
"""
raise NotImplementedError


class TimeSeriesFoundationModel(FoundationModel):
"""Foundation model for time series forecasting (Chronos, etc.)."""

def predict(
self,
data: Union[str, pd.DataFrame],
target: str = "target",
id_column: str = "item_id",
timestamp_column: str = "timestamp",
known_covariates: Optional[Union[str, pd.DataFrame]] = None,
static_features: Optional[Union[str, pd.DataFrame]] = None,
prediction_length: int = 1,
quantile_levels: Optional[List[float]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None,
instance_type: Optional[str] = None,
wait: bool = True,
**backend_kwargs,
) -> Union[pd.DataFrame, RemoteJob]:
"""
Run batch prediction for time series.

Parameters
----------
data
Historical time series in long format (DataFrame or S3 path).
target
Name of the target column to forecast.
id_column
Name of the item ID column.
timestamp_column
Name of the timestamp column.
known_covariates
Future values of known covariates (DataFrame or S3 path).
static_features
Metadata attributes of individual items (DataFrame or S3 path).
prediction_length
Number of time steps to forecast.
quantile_levels
Quantiles to predict.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
output_path
S3 path to store predictions.
If None, will auto-generate under s3_output_path.
instance_type
Instance type for the prediction job.
If None, will use the default from the model registry.
wait
If True, block and return DataFrame. If False, return the job handle.
**backend_kwargs
Additional backend-specific arguments (e.g. job_name, custom_image_uri,
framework_version, volume_size).

Returns
-------
Union[pd.DataFrame, RemoteJob]
"""
raise NotImplementedError


class TabularFoundationModel(FoundationModel):
"""Foundation model for tabular prediction (Mitra, TabICL, etc.)."""

def predict(
self,
train_data: Union[str, pd.DataFrame],
test_data: Union[str, pd.DataFrame],
label: str = "target",
hyperparameters: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None,
instance_type: Optional[str] = None,
wait: bool = True,
**backend_kwargs,
) -> Union[pd.DataFrame, RemoteJob]:
"""
Run batch prediction for tabular tasks.

Parameters
----------
train_data
Labeled few-shot context for the foundation model.
test_data
Unlabeled data to predict on.
label
Target column name in train_data.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
output_path
S3 path to store predictions.
If None, will auto-generate under s3_output_path.
instance_type
Instance type for the prediction job.
If None, will use the default from the model registry.
wait
If True, block and return DataFrame. If False, return the job handle.
**backend_kwargs
Additional backend-specific arguments (e.g. job_name, custom_image_uri,
framework_version, volume_size).

Returns
-------
Union[pd.DataFrame, RemoteJob]
"""
raise NotImplementedError

def predict_proba(
self,
train_data: Union[str, pd.DataFrame],
test_data: Union[str, pd.DataFrame],
label: str = "target",
hyperparameters: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None,
instance_type: Optional[str] = None,
wait: bool = True,
**backend_kwargs,
) -> Union[pd.DataFrame, RemoteJob]:
"""
Run batch prediction returning class probabilities.

Parameters
----------
train_data
Labeled few-shot context for the foundation model.
test_data
Unlabeled data to predict on.
label
Target column name in train_data.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
output_path
S3 path to store predictions.
If None, will auto-generate under s3_output_path.
instance_type
Instance type for the prediction job.
If None, will use the default from the model registry.
wait
If True, block and return DataFrame. If False, return the job handle.
**backend_kwargs
Additional backend-specific arguments (e.g. job_name, custom_image_uri,
framework_version, volume_size).

Returns
-------
Union[pd.DataFrame, RemoteJob]
"""
raise NotImplementedError
75 changes: 75 additions & 0 deletions src/autogluon/cloud/model/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Foundation model registry.

Maps model_id to AG-compatible configuration for deploy / predict.
"""

from typing import Any, Dict, Literal, TypedDict


class FoundationModelConfig(TypedDict):
task: Literal["forecasting", "classification", "regression"]
model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra")
inference_hyperparameters: Dict[str, Any] # defaults for deploy() and predict()
training_hyperparameters: Dict[str, Any] # defaults for fit()
default_instance_type: str
fine_tunable: bool # whether .fit() is supported


FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = {
"chronos-bolt-tiny": {
"task": "forecasting",
"model_name": "Chronos",
"inference_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"},
"training_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
},
"chronos-bolt-small": {
"task": "forecasting",
"model_name": "Chronos",
"inference_hyperparameters": {"model_path": "amazon/chronos-bolt-small"},
"training_hyperparameters": {"model_path": "amazon/chronos-bolt-small"},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
},
"chronos-bolt-base": {
"task": "forecasting",
"model_name": "Chronos",
"inference_hyperparameters": {"model_path": "amazon/chronos-bolt-base"},
"training_hyperparameters": {"model_path": "amazon/chronos-bolt-base"},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
},
"chronos-2": {
"task": "forecasting",
"model_name": "Chronos2",
"inference_hyperparameters": {"model_path": "amazon/chronos-2"},
"training_hyperparameters": {"model_path": "amazon/chronos-2", "fine_tune": True},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": True,
},
# TODO: Replace dummy configs with real values
"mitra-classification": {
"task": "classification",
"model_name": "Mitra",
"inference_hyperparameters": {"model_path": "TODO"},
"training_hyperparameters": {"model_path": "TODO"},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
},
"mitra-regression": {
"task": "regression",
"model_name": "Mitra",
"inference_hyperparameters": {"model_path": "TODO"},
"training_hyperparameters": {"model_path": "TODO"},
"default_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
},
}


def get_model_config(model_id: str) -> FoundationModelConfig:
if model_id not in FOUNDATION_MODEL_REGISTRY:
available = list(FOUNDATION_MODEL_REGISTRY.keys())
raise ValueError(f"Unknown model_id '{model_id}'. Available models: {available}")
return FOUNDATION_MODEL_REGISTRY[model_id]
Loading