Skip to content

Add a skeleton for the foundation model class#217

Merged
shchur merged 10 commits into
autogluon:masterfrom
shchur:foundation-model-api
May 15, 2026
Merged

Add a skeleton for the foundation model class#217
shchur merged 10 commits into
autogluon:masterfrom
shchur:foundation-model-api

Conversation

@shchur
Copy link
Copy Markdown
Collaborator

@shchur shchur commented May 12, 2026

Issue #, if available:

  • This PR outlines the user-facing API for the new FoundationModel class in AutoGluon Cloud. This class allows users to deploy and run inference with pretrained foundation models (Chronos, Mitra, etc.) on AWS.
  • How is this different from sagemaker.jumpstart.JumpStartModel?
    • DataFrame in / DataFrame out (no manual JSON payload construction)
    • Supports batch prediction and fine-tuning out of the box
    • Model onboarding is much easier for us in the future - any model available in AutoGluon works as soon as we add it to the config
  • How is this different from TimeSeriesCloudPredictor?
    • No fit() required before deploy() or predict() — foundation models work out of the box
    • Cleaner API: the user does need to provide nested dictionaries of predictor_init_kwargs and predictor_fit_kwargs
    • The FoundationModel class is stateless — no need to worry about the underlying state (attached artifacts, endpoints, etc)

Example usage

# ============================================================
# Time Series — Chronos
# ============================================================
from autogluon.cloud import FoundationModel

model = FoundationModel("chronos-bolt-base", role_arn="arn:aws:iam::123456789:role/SageMakerRole")

# --- Deploy an endpoint ---
endpoint = model.deploy(instance_type="ml.g5.xlarge")
predictions = endpoint.predict(df, prediction_length=24)
endpoint.delete_endpoint()

# --- Batch predict (large dataset, uses SageMaker Training job under the hood) ---
predictions = model.predict(df, prediction_length=24, quantile_levels=[0.1, 0.5, 0.9])

# --- Batch predict (async) ---
job = model.predict(df, prediction_length=24, wait=False)
# ... later ...
predictions = job.result()

# --- Fine-tune, then deploy ---
fine_tuned = model.fit(train_data=df, instance_type="ml.g5.xlarge")
endpoint = fine_tuned.deploy()

# --- Use custom/cached weights (VPC-safe) ---
model = FoundationModel(
    "chronos-bolt-base",
    hyperparameters={"model_path": "s3://my-bucket/cached-weights/"},
)
endpoint = model.deploy()

# --- Override hyperparameters at call time ---
predictions = model.predict(df, prediction_length=24, hyperparameters={"cross_learning": False})

# ============================================================
# Tabular — Mitra / TabICL
# ============================================================
model = FoundationModel("mitra-classification", role_arn="arn:aws:iam::123456789:role/SageMakerRole")

predictions = model.predict(train_data=train_df, test_data=test_df, label="class")

To Do:

  • Wiring the user-facing API to the internal backend that actually runs stuff on AWS

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Copy link
Copy Markdown
Collaborator

@AnirudhDagar AnirudhDagar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome design and I like this API! Just one thing to note, #213 introduces autogluon.cloud.init() which persists the user's role_arn, region, and S3 bucket to ~/.autogluon/cloud.yaml at first-time setup. If FoundationModel falls back to that config when those args aren't passed, usage collapses to:

# one time per account
import autogluon.cloud as agc
agc.init()

# Every session after that
from autogluon.cloud import FoundationModel

model = FoundationModel("chronos-bolt-base")            # no role_arn needed it will automatically be picked
endpoint = model.deploy(instance_type="ml.g5.xlarge")
predictions = endpoint.predict(data=df, prediction_length=24)

I can add this, once both the PRs are shipped.

@prateekdesai04
Copy link
Copy Markdown
Contributor

AG-Cloud requires SageMaker as a package, it would be good to assume that customers may have v3 as well.
v3 has quite some API changes, do these changes fit v3 and v2 both (should be the case ideally)?
Here is also a migration by the Sagemaker team - https://github.com/nargokul/sagemaker-python-sdk/blob/master/migration.md#migration-tool-mcp-server

@abdulfatir
Copy link
Copy Markdown

Thanks @shchur! This looks great and would really simplify the user experience with FMs. A few comments:

  • Is FoundationModel supposed to be a magic class which would automatically infer the model/task type? Maybe we can add it more concrete in some way, e.g., by adding a task_type kwarg or subclassing the base FoundationModel class?
  • Related to above: how do we plan to handle model-specific hyperparameters (e.g., cross_learning, n_estimators, etc.)? What exactly would model_config map to internally?
  • For the batch prediction mode, we will have SM job start overhead. Is there a way to reduce this by using another AWS service here (Lambda?)?

Comment thread src/autogluon/cloud/model/__init__.py
Comment thread src/autogluon/cloud/model/__init__.py
>>> endpoint.delete_endpoint()
"""

def __new__(cls, model_id: str, **kwargs) -> "FoundationModel":
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @abdulfatir,

  1. The FoundationModel class automatically resolves to TimeSeriesFoundationModel, TabularFoundationModel, etc based on the task associated with each model_id. The subclasses only differ in their API for predict: e.g. TimeSeriesFoundationModel takes historical data + known covariates as input, and TabularFoundationModel takes train & test DFs as input

Comment thread src/autogluon/cloud/model/registry.py Outdated
class FoundationModelConfig(TypedDict):
task: Literal["forecasting", "classification", "regression"]
model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra")
model_config: Dict[str, Any] # passed to the AG model (e.g. {"model_path": "..."})
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abdulfatir
2. model_config = hyperparameters provided to the TimeSeriesPredictor/TabularPredictor under the hood. Now that you pointed this out, I think that we should separate the training and inference configs, e.g.

"chronos-2": {
    "task": "forecasting",
    "model_name": "Chronos2",
    "inference_config": {"model_path": "amazon/chronos-2"},
    "training_config": {"model_path": "amazon/chronos-2", "fine_tune": True},
    ...
},

The user can provide kwargs to deploy(), predict(), fit() that updates the default inference/train configs stored in the registry.

Do you think we should allow the user to configure these hyperparameters inside the payload (e.g. passing cross_learning: True inside the call to endpoint.predict), or is it okay if this is fixed at the creation of the endpoint?

def __init__(
self,
model_id: str,
backend: Literal["sagemaker"] = "sagemaker",
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abdulfatir
3. We can add support for other backends like Lambda in the future - but I would decouple this from the FoundationModel API so that the user doesn't need to worry about it

@shchur shchur changed the title [WIP] Add a skeleton for the foundation model class Add a skeleton for the foundation model class May 13, 2026
@shchur shchur requested review from AnirudhDagar and melopeo May 13, 2026 15:57
@@ -0,0 +1,3 @@
from .foundation_model import FoundationModel, TabularFoundationModel, TimeSeriesFoundationModel

__all__ = ["FoundationModel", "TabularFoundationModel", "TimeSeriesFoundationModel"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question:

Are TabularFoundationModel and TimeSeriesFoundationModel part of the public API, or internal to the FoundationModel factory? They're in __all__, but examples only use FoundationModel(...). For instance, should users be able to write:

def forecast(model: TimeSeriesFoundationModel, df: pd.DataFrame) -> pd.DataFrame:
    return model.predict(df, prediction_length=24)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, with the current design the user should be able to directly create a TimeSeriesFoundationModel object bypassing the FoundationModel base class, e.g.

model = TimeSeriesFoundationModel(model_id="chronos-bolt-base")
predictions = model.predict(...)

The FoundationModel base class only exists for convenience.

Comment thread src/autogluon/cloud/model/foundation_model.py
"""Subclasses override with task-specific signature."""
...

def fit(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting fit() on FoundationModel (the base class) implies every foundation model can be fine-tuned.

In general it makes sense to support fine-tuning of a pretrained model. My question is if we are assuming all FoundationModels under consideration will have a fine-tuning capability. If a foundation model can't be fine-tuned, how are we planning to handle that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, I added a boolean flag to the ModelConfig indicating whether fine-tuning is supported.

@melopeo
Copy link
Copy Markdown
Collaborator

melopeo commented May 15, 2026

Hey @shchur , thanks for this. It looks very nice.

Question: The example calls job.result(), but RemoteJob doesn't define it. Is the plan to add it in a follow-up?

Comment on lines +104 to +107
When wait=False, returns a RemoteJob handle. Use job.get_job_status() to poll
and job.get_output_path() to get the S3 path to predictions once complete.
"""
# TODO: consider adding a .result() method to RemoteJob that downloads + parses output
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@melopeo I missed the top-level comment. Here is a temporary solution, I will figure out the best path when implementing the backend.

@shchur shchur merged commit 84c4f9e into autogluon:master May 15, 2026
12 checks passed
@shchur shchur deleted the foundation-model-api branch May 15, 2026 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants