Add a skeleton for the foundation model class #217
AnirudhDagar
left a comment
Awesome design, and I like this API! Just one thing to note: #213 introduces autogluon.cloud.init(), which persists the user's role_arn, region, and S3 bucket to ~/.autogluon/cloud.yaml at first-time setup. If FoundationModel falls back to that config when those args aren't passed, usage collapses to:
# one time per account
import autogluon.cloud as agc
agc.init()
# every session after that
from autogluon.cloud import FoundationModel
model = FoundationModel("chronos-bolt-base")  # no role_arn needed; it is picked up automatically
endpoint = model.deploy(instance_type="ml.g5.xlarge")
predictions = endpoint.predict(data=df, prediction_length=24)
I can add this once both PRs are shipped.
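A minimal sketch of the fallback described in this comment: explicitly passed arguments win, and anything missing is taken from the persisted config. The function name `resolve_settings` and the key name `s3_bucket` are assumptions for illustration, not part of either PR, and loading ~/.autogluon/cloud.yaml is left out so the example stays dependency-free.

```python
from typing import Dict, Optional

# Hypothetical key names; the comment only says that the role ARN, region,
# and S3 bucket are persisted at first-time setup.
CONFIG_KEYS = ("role_arn", "region", "s3_bucket")


def resolve_settings(
    explicit: Dict[str, Optional[str]],
    persisted: Dict[str, str],
) -> Dict[str, str]:
    """Prefer explicitly passed values and fall back to persisted ones."""
    resolved: Dict[str, str] = {}
    missing = []
    for key in CONFIG_KEYS:
        value = explicit.get(key) or persisted.get(key)
        if not value:
            missing.append(key)
        else:
            resolved[key] = value
    if missing:
        raise ValueError(
            f"Missing settings {missing}; pass them explicitly or run the one-time setup."
        )
    return resolved
```

Raising on missing keys, rather than silently continuing, keeps the failure close to the constructor call instead of surfacing later at deploy time.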
AG-Cloud requires SageMaker as a package, so it would be good to assume that customers may have v3 as well.
Thanks @shchur! This looks great and would really simplify the user experience with FMs. A few comments:
    >>> endpoint.delete_endpoint()
    """

    def __new__(cls, model_id: str, **kwargs) -> "FoundationModel":
Thanks @abdulfatir,
- The FoundationModel class automatically resolves to TimeSeriesFoundationModel, TabularFoundationModel, etc., based on the task associated with each model_id. The subclasses only differ in their API for predict: e.g., TimeSeriesFoundationModel takes historical data + known covariates as input, and TabularFoundationModel takes train & test DFs as input.
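For readers unfamiliar with this pattern, here is a minimal sketch of how a base class can resolve to a task-specific subclass in `__new__`. It is an illustration under stated assumptions, not the code in this PR: the `MODEL_TASKS` and `TASK_TO_CLASS` tables and the model ID "mitra-example" are hypothetical.

```python
from typing import Dict, Type

# Hypothetical registry: maps each model_id to its task.
MODEL_TASKS: Dict[str, str] = {
    "chronos-bolt-base": "forecasting",
    "mitra-example": "classification",  # made-up ID for illustration
}


class FoundationModel:
    def __new__(cls, model_id: str, **kwargs) -> "FoundationModel":
        target = cls
        if cls is FoundationModel:
            # Called on the base class: pick the subclass registered for the task.
            target = TASK_TO_CLASS[MODEL_TASKS[model_id]]
        return super().__new__(target)

    def __init__(self, model_id: str, **kwargs) -> None:
        self.model_id = model_id


class TimeSeriesFoundationModel(FoundationModel):
    """Would define predict() for historical data + known covariates."""


class TabularFoundationModel(FoundationModel):
    """Would define predict() for train & test DataFrames."""


# Hypothetical task-to-class table, defined after the subclasses exist.
TASK_TO_CLASS: Dict[str, Type[FoundationModel]] = {
    "forecasting": TimeSeriesFoundationModel,
    "classification": TabularFoundationModel,
    "regression": TabularFoundationModel,
}
```

Because the object returned by `__new__` is an instance of a subclass of `FoundationModel`, Python still calls `__init__` on it with the original arguments, so `FoundationModel("chronos-bolt-base")` yields a fully initialized `TimeSeriesFoundationModel`.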
class FoundationModelConfig(TypedDict):
    task: Literal["forecasting", "classification", "regression"]
    model_name: str  # AG model class name (e.g. "Chronos", "Chronos2", "Mitra")
    model_config: Dict[str, Any]  # passed to the AG model (e.g. {"model_path": "..."})
@abdulfatir
2. model_config = hyperparameters provided to the TimeSeriesPredictor/TabularPredictor under the hood. Now that you pointed this out, I think that we should separate the training and inference configs, e.g.
"chronos-2": {
    "task": "forecasting",
    "model_name": "Chronos2",
    "inference_config": {"model_path": "amazon/chronos-2"},
    "training_config": {"model_path": "amazon/chronos-2", "fine_tune": True},
    ...
},
The user can provide kwargs to deploy(), predict(), and fit() that update the default inference/training configs stored in the registry.
Do you think we should allow the user to configure these hyperparameters inside the payload (e.g. passing cross_learning: True inside the call to endpoint.predict), or is it okay if this is fixed at the creation of the endpoint?
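A small sketch of the override behaviour proposed above, assuming a simple shallow merge: user kwargs are layered over a copy of the registry default so the registry itself is never mutated. The helper name `resolve_config` is hypothetical, and the registry entry only contains the fields shown in the comment (the elided fields are left out).

```python
from copy import deepcopy
from typing import Any, Dict

REGISTRY: Dict[str, Dict[str, Any]] = {
    "chronos-2": {
        "task": "forecasting",
        "model_name": "Chronos2",
        "inference_config": {"model_path": "amazon/chronos-2"},
        "training_config": {"model_path": "amazon/chronos-2", "fine_tune": True},
    },
}


def resolve_config(model_id: str, kind: str, **overrides: Any) -> Dict[str, Any]:
    """Return the registry default for `kind`, updated with user overrides.

    `kind` is either "inference_config" or "training_config".
    """
    merged = deepcopy(REGISTRY[model_id][kind])
    merged.update(overrides)
    return merged
```

With this shape, fixing hyperparameters at endpoint creation would mean calling the helper once in deploy(), while per-request configuration would mean calling it again with the payload's overrides inside predict().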
    def __init__(
        self,
        model_id: str,
        backend: Literal["sagemaker"] = "sagemaker",
@abdulfatir
3. We can add support for other backends like Lambda in the future, but I would decouple this from the FoundationModel API so that the user doesn't need to worry about it.
@@ -0,0 +1,3 @@
from .foundation_model import FoundationModel, TabularFoundationModel, TimeSeriesFoundationModel

__all__ = ["FoundationModel", "TabularFoundationModel", "TimeSeriesFoundationModel"]
Question:
Are TabularFoundationModel and TimeSeriesFoundationModel part of the public API, or internal to the FoundationModel factory? They're in __all__, but examples only use FoundationModel(...). For instance, should users be able to write:
def forecast(model: TimeSeriesFoundationModel, df: pd.DataFrame) -> pd.DataFrame:
    return model.predict(df, prediction_length=24)
Yes, with the current design the user should be able to directly create a TimeSeriesFoundationModel object bypassing the FoundationModel base class, e.g.
model = TimeSeriesFoundationModel(model_id="chronos-bolt-base")
predictions = model.predict(...)
The FoundationModel base class only exists for convenience.
        """Subclasses override with task-specific signature."""
        ...

    def fit(
Putting fit() on FoundationModel (the base class) implies every foundation model can be fine-tuned.
In general, it makes sense to support fine-tuning of a pretrained model. My question is whether we are assuming that every FoundationModel under consideration can be fine-tuned. If a foundation model can't be fine-tuned, how are we planning to handle that?
Good question. I added a boolean flag to the ModelConfig indicating whether fine-tuning is supported.
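One way such a flag could be enforced is a guard at the top of fit() that fails fast with a clear message. This is only a sketch: the field name `supports_fine_tuning`, the exception class, and the model ID "frozen-model" are assumptions, since the thread does not name the flag.

```python
from typing import Any, Dict


class FineTuningNotSupportedError(RuntimeError):
    """Raised when fit() is called for a model that cannot be fine-tuned."""


def ensure_fine_tuning_supported(model_id: str, config: Dict[str, Any]) -> None:
    # "supports_fine_tuning" is a hypothetical name for the boolean flag.
    if not config.get("supports_fine_tuning", False):
        raise FineTuningNotSupportedError(
            f"Model '{model_id}' does not support fine-tuning; "
            "use deploy() or predict() with the pretrained weights instead."
        )
```

Defaulting the flag to False when it is absent means a newly registered model has to opt in to fine-tuning explicitly.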
Hey @shchur, thanks for this. It looks very nice. Question: The example calls job.result(), but RemoteJob doesn't define it. Is the plan to add it in a follow-up?
        When wait=False, returns a RemoteJob handle. Use job.get_job_status() to poll
        and job.get_output_path() to get the S3 path to predictions once complete.
        """
        # TODO: consider adding a .result() method to RemoteJob that downloads + parses output
@melopeo I missed the top-level comment. Here is a temporary solution; I will figure out the best path when implementing the backend.
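Until a .result() method exists, the polling described in the docstring could be wrapped in a small helper like the one below. It is a sketch under assumptions: only get_job_status() and get_output_path() come from the docstring, while the helper name and the status strings ("Completed", "Failed", "Stopped") are guesses about what the backend will report. Downloading and parsing the output is deliberately left out.

```python
import time
from typing import Callable, Protocol


class JobLike(Protocol):
    """The two RemoteJob methods mentioned in the docstring."""

    def get_job_status(self) -> str: ...

    def get_output_path(self) -> str: ...


def wait_for_output_path(
    job: JobLike,
    poll_interval: float = 30.0,
    timeout: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll the job until it finishes and return the S3 path to its output."""
    deadline = time.monotonic() + timeout
    while True:
        status = job.get_job_status()
        if status == "Completed":  # assumed status string
            return job.get_output_path()
        if status in ("Failed", "Stopped"):  # assumed status strings
            raise RuntimeError(f"Remote job ended with status '{status}'.")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Remote job did not finish within {timeout} seconds.")
        sleep(poll_interval)
```

The `sleep` parameter exists only so the loop can be exercised in tests without waiting.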
Issue #, if available:
FoundationModel class in AutoGluon Cloud. This class allows users to deploy and run inference with pretrained foundation models (Chronos, Mitra, etc.) on AWS.
sagemaker.jumpstart.JumpStartModel?
TimeSeriesCloudPredictor?
- No fit() required before deploy() or predict(): foundation models work out of the box
- predictor_init_kwargs and predictor_fit_kwargs
- FoundationModel class is stateless: no need to worry about the underlying state (attached artifacts, endpoints, etc.)
Example usage
To Do:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.