-
Notifications
You must be signed in to change notification settings - Fork 17
Add a skeleton for the foundation model class #217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
413ec07
0e2d613
fb92f32
371ff0d
901cba3
ad18b97
9141907
e4a725e
26e0741
8f90da3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .foundation_model import FoundationModel, TabularFoundationModel, TimeSeriesFoundationModel | ||
|
|
||
|
shchur marked this conversation as resolved.
|
||
| __all__ = ["FoundationModel", "TabularFoundationModel", "TimeSeriesFoundationModel"] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: Are TabularFoundationModel and TimeSeriesFoundationModel part of the public API, or internal to the FoundationModel factory? They're in def forecast(model: TimeSeriesFoundationModel, df: pd.DataFrame) -> pd.DataFrame:
return model.predict(df, prediction_length=24)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, with the current design the user should be able to directly create a model = TimeSeriesFoundationModel(model_id="chronos-bolt-base")
predictions = model.predict(...)The |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,316 @@ | ||
| """FoundationModel — deploy and predict with pretrained foundation models on AWS.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import abstractmethod | ||
| from typing import Any, Dict, List, Literal, Optional, Union | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from autogluon.cloud.endpoint.endpoint import Endpoint | ||
| from autogluon.cloud.job.remote_job import RemoteJob | ||
|
|
||
| from .registry import get_model_config | ||
|
|
||
|
|
||
| class FoundationModel: | ||
| """ | ||
| Pretrained foundation model inference on AWS. | ||
|
|
||
| Factory: FoundationModel("chronos-bolt-base", ...) returns the appropriate | ||
| task-specific subclass (TimeSeriesFoundationModel, TabularFoundationModel). | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> model = FoundationModel("chronos-bolt-base", role_arn="arn:...", hyperparameters={"model_path": "s3://cached/"}) | ||
| >>> endpoint = model.deploy() | ||
| >>> predictions = endpoint.predict(data) | ||
|
shchur marked this conversation as resolved.
|
||
| >>> endpoint.delete_endpoint() | ||
| """ | ||
|
|
||
| def __new__(cls, model_id: str, **kwargs) -> "FoundationModel": | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @abdulfatir,
|
||
| if cls is not FoundationModel: | ||
| return super().__new__(cls) | ||
| config = get_model_config(model_id) | ||
| task = config["task"] | ||
| if task == "forecasting": | ||
| return super().__new__(TimeSeriesFoundationModel) | ||
| elif task in ("classification", "regression"): | ||
| return super().__new__(TabularFoundationModel) | ||
| raise ValueError(f"Unsupported task: {task}") | ||
|
|
||
def __init__(
    self,
    model_id: str,
    backend: Literal["sagemaker"] = "sagemaker",
    role_arn: Optional[str] = None,
    region: Optional[str] = None,
    s3_output_path: Optional[str] = None,
    hyperparameters: Optional[Dict[str, Any]] = None,
):
    """
    Record the model identity, deployment settings and hyperparameter overrides.

    Parameters
    ----------
    model_id
        Registry key of the pretrained model; its registry entry is loaded
        via ``get_model_config`` and kept on the instance.
    backend
        Name of the execution backend. Only stored for now (see TODO below).
    role_arn
        Stored on the instance as ``role_arn``.
    region
        Stored on the instance as ``region``.
    s3_output_path
        Stored on the instance as ``s3_output_path``.
    hyperparameters
        Overrides applied on top of the registry defaults. The mapping is
        copied, so later changes to the caller's object do not affect this
        model.
    """
    self.model_id = model_id
    self.role_arn = role_arn
    self.region = region
    self.s3_output_path = s3_output_path
    self._config = get_model_config(model_id)
    # Defensive copy: do not alias the caller's dict, otherwise mutating it
    # after construction would silently change this model's overrides.
    self._hyperparameter_overrides = dict(hyperparameters) if hyperparameters else {}
    # TODO: instantiate backend via BackendFactory
    self._backend_type = backend
|
|
||
| def _get_hyperparameters( | ||
| self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None | ||
| ) -> Dict[str, Any]: | ||
| config_key = "inference_hyperparameters" if context == "inference" else "training_hyperparameters" | ||
| return self._config.get(config_key, {}) | self._hyperparameter_overrides | (overrides or {}) | ||
|
|
||
| def deploy( | ||
| self, | ||
| instance_type: Optional[str] = None, | ||
| endpoint_name: Optional[str] = None, | ||
| hyperparameters: Optional[Dict[str, Any]] = None, | ||
| wait: bool = True, | ||
| **backend_kwargs, | ||
| ) -> Endpoint: | ||
| """ | ||
| Deploy model to an endpoint. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| instance_type | ||
| Instance type for the endpoint. | ||
| If None, will use the default from the model registry. | ||
| endpoint_name | ||
| Custom endpoint name. | ||
| If None, will auto-generate a unique name. | ||
| hyperparameters | ||
| Model hyperparameters for inference. Overrides values passed to the constructor. | ||
| Available hyperparameters for each model are listed in the AutoGluon documentation. | ||
| wait | ||
| Whether to block until the endpoint is ready. | ||
| **backend_kwargs | ||
| Backend-specific arguments. Use these to configure serverless, async, or | ||
| autoscaling (e.g. memory_size_in_mb, max_concurrency, initial_instance_count). | ||
|
|
||
| Returns | ||
| ------- | ||
| Endpoint | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| @abstractmethod | ||
| def predict(self, data: Union[str, pd.DataFrame], wait: bool = True, **kwargs) -> Union[pd.DataFrame, RemoteJob]: | ||
| """Subclasses override with task-specific signature. | ||
|
|
||
| When wait=False, returns a RemoteJob handle. Use job.get_job_status() to poll | ||
| and job.get_output_path() to get the S3 path to predictions once complete. | ||
| """ | ||
| # TODO: consider adding a .result() method to RemoteJob that downloads + parses output | ||
|
Comment on lines
+104
to
+107
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @melopeo I missed the top-level comment. Here is a temporary solution; I will figure out the best path when implementing the backend. |
||
| ... | ||
|
|
||
| def fit( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Putting fit() on FoundationModel (the base class) implies that every foundation model can be fine-tuned. In general, it makes sense to support fine-tuning of a pretrained model. My question is whether we are assuming that all FoundationModels under consideration will have a fine-tuning capability. If a foundation model can't be fine-tuned, how are we planning to handle that?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question, I added a boolean flag to the |
||
| self, | ||
| train_data: Union[str, pd.DataFrame], | ||
| output_path: Optional[str] = None, | ||
| instance_type: Optional[str] = None, | ||
| hyperparameters: Optional[Dict[str, Any]] = None, | ||
| wait: bool = True, | ||
| **kwargs, | ||
| ) -> "FoundationModel": | ||
| """ | ||
| Fine-tune the model. Returns a new FoundationModel pointing to the fine-tuned artifact. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| train_data | ||
| Training data as DataFrame or S3 path. | ||
| output_path | ||
| S3 path to store fine-tuned model. | ||
| If None, will auto-generate under s3_output_path. | ||
| instance_type | ||
| Instance type for the training job. | ||
| If None, will use the default from the model registry. | ||
| hyperparameters | ||
| Model hyperparameters for training. Overrides values passed to the constructor. | ||
| Available hyperparameters for each model are listed in the AutoGluon documentation. | ||
| wait | ||
| If True, block until training completes. | ||
|
|
||
| Returns | ||
| ------- | ||
| FoundationModel | ||
| New instance with hyperparameters pointing to the fine-tuned artifact. | ||
| """ | ||
| if not self._config.get("fine_tunable", False): | ||
| raise ValueError(f"Model '{self.model_id}' does not support fine-tuning.") | ||
| raise NotImplementedError | ||
|
|
||
| def cache_model_artifact(self, s3_path: str) -> str: | ||
| """ | ||
| Pre-cache model weights to S3 (for VPC-deployed endpoints). | ||
|
|
||
| Launches a small job that downloads weights from HuggingFace | ||
| and uploads them to S3. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| s3_path | ||
| S3 path where the model weights should be cached. | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| S3 path to the cached artifact. | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
class TimeSeriesFoundationModel(FoundationModel):
    """Time series forecasting variant of FoundationModel (Chronos, etc.)."""

    def predict(
        self,
        data: Union[str, pd.DataFrame],
        target: str = "target",
        id_column: str = "item_id",
        timestamp_column: str = "timestamp",
        known_covariates: Optional[Union[str, pd.DataFrame]] = None,
        static_features: Optional[Union[str, pd.DataFrame]] = None,
        prediction_length: int = 1,
        quantile_levels: Optional[List[float]] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        output_path: Optional[str] = None,
        instance_type: Optional[str] = None,
        wait: bool = True,
        **backend_kwargs,
    ) -> Union[pd.DataFrame, RemoteJob]:
        """
        Batch-forecast time series (not implemented yet).

        Parameters
        ----------
        data
            Long-format historical time series, as a DataFrame or an S3 path.
        target
            Column holding the values to forecast.
        id_column
            Column holding the item ID.
        timestamp_column
            Column holding the timestamp.
        known_covariates
            Future values of the known covariates, as a DataFrame or an S3 path.
        static_features
            Per-item metadata attributes, as a DataFrame or an S3 path.
        prediction_length
            How many time steps to forecast.
        quantile_levels
            Which quantiles to predict.
        hyperparameters
            Inference hyperparameters; these take precedence over the ones
            passed to the constructor. The AutoGluon documentation lists what
            each model accepts.
        output_path
            S3 destination for the predictions.
            Auto-generated under s3_output_path when None.
        instance_type
            Instance type to run the prediction job on.
            Defaults to the model registry's value when None.
        wait
            When True, block and return a DataFrame; when False, return the
            job handle.
        **backend_kwargs
            Extra arguments specific to the backend (e.g. job_name,
            custom_image_uri, framework_version, volume_size).

        Returns
        -------
        Union[pd.DataFrame, RemoteJob]

        Raises
        ------
        NotImplementedError
            Always; this method is currently a placeholder.
        """
        raise NotImplementedError
|
|
||
|
|
||
class TabularFoundationModel(FoundationModel):
    """Tabular prediction variant of FoundationModel (Mitra, TabICL, etc.)."""

    def predict(
        self,
        train_data: Union[str, pd.DataFrame],
        test_data: Union[str, pd.DataFrame],
        label: str = "target",
        hyperparameters: Optional[Dict[str, Any]] = None,
        output_path: Optional[str] = None,
        instance_type: Optional[str] = None,
        wait: bool = True,
        **backend_kwargs,
    ) -> Union[pd.DataFrame, RemoteJob]:
        """
        Batch-predict on tabular data (not implemented yet).

        Parameters
        ----------
        train_data
            Labeled rows serving as few-shot context for the foundation model.
        test_data
            Unlabeled rows to generate predictions for.
        label
            Name of the target column in train_data.
        hyperparameters
            Inference hyperparameters; these take precedence over the ones
            passed to the constructor. The AutoGluon documentation lists what
            each model accepts.
        output_path
            S3 destination for the predictions.
            Auto-generated under s3_output_path when None.
        instance_type
            Instance type to run the prediction job on.
            Defaults to the model registry's value when None.
        wait
            When True, block and return a DataFrame; when False, return the
            job handle.
        **backend_kwargs
            Extra arguments specific to the backend (e.g. job_name,
            custom_image_uri, framework_version, volume_size).

        Returns
        -------
        Union[pd.DataFrame, RemoteJob]

        Raises
        ------
        NotImplementedError
            Always; this method is currently a placeholder.
        """
        raise NotImplementedError

    def predict_proba(
        self,
        train_data: Union[str, pd.DataFrame],
        test_data: Union[str, pd.DataFrame],
        label: str = "target",
        hyperparameters: Optional[Dict[str, Any]] = None,
        output_path: Optional[str] = None,
        instance_type: Optional[str] = None,
        wait: bool = True,
        **backend_kwargs,
    ) -> Union[pd.DataFrame, RemoteJob]:
        """
        Batch-predict class probabilities on tabular data (not implemented yet).

        Parameters
        ----------
        train_data
            Labeled rows serving as few-shot context for the foundation model.
        test_data
            Unlabeled rows to generate predictions for.
        label
            Name of the target column in train_data.
        hyperparameters
            Inference hyperparameters; these take precedence over the ones
            passed to the constructor. The AutoGluon documentation lists what
            each model accepts.
        output_path
            S3 destination for the predictions.
            Auto-generated under s3_output_path when None.
        instance_type
            Instance type to run the prediction job on.
            Defaults to the model registry's value when None.
        wait
            When True, block and return a DataFrame; when False, return the
            job handle.
        **backend_kwargs
            Extra arguments specific to the backend (e.g. job_name,
            custom_image_uri, framework_version, volume_size).

        Returns
        -------
        Union[pd.DataFrame, RemoteJob]

        Raises
        ------
        NotImplementedError
            Always; this method is currently a placeholder.
        """
        raise NotImplementedError
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| """Foundation model registry. | ||
|
|
||
| Maps model_id to AG-compatible configuration for deploy / predict. | ||
| """ | ||
|
|
||
| from typing import Any, Dict, Literal, TypedDict | ||
|
|
||
|
|
||
class FoundationModelConfig(TypedDict):
    """Schema of a single entry of ``FOUNDATION_MODEL_REGISTRY``."""

    # Task the model solves; one of the three literal values below.
    task: Literal["forecasting", "classification", "regression"]
    model_name: str  # AG model class name (e.g. "Chronos", "Chronos2", "Mitra")
    inference_hyperparameters: Dict[str, Any]  # defaults for deploy() and predict()
    training_hyperparameters: Dict[str, Any]  # defaults for fit()
    # Instance type used when the caller passes none (registry entries use e.g. "ml.g5.xlarge").
    default_instance_type: str
    fine_tunable: bool  # whether .fit() is supported
|
|
||
|
|
||
# Supported pretrained models, keyed by the public model_id.
# Entries are built with the TypedDict constructor, which returns plain dicts.
FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = {
    "chronos-bolt-tiny": FoundationModelConfig(
        task="forecasting",
        model_name="Chronos",
        inference_hyperparameters={"model_path": "amazon/chronos-bolt-tiny"},
        training_hyperparameters={"model_path": "amazon/chronos-bolt-tiny"},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=False,
    ),
    "chronos-bolt-small": FoundationModelConfig(
        task="forecasting",
        model_name="Chronos",
        inference_hyperparameters={"model_path": "amazon/chronos-bolt-small"},
        training_hyperparameters={"model_path": "amazon/chronos-bolt-small"},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=False,
    ),
    "chronos-bolt-base": FoundationModelConfig(
        task="forecasting",
        model_name="Chronos",
        inference_hyperparameters={"model_path": "amazon/chronos-bolt-base"},
        training_hyperparameters={"model_path": "amazon/chronos-bolt-base"},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=False,
    ),
    "chronos-2": FoundationModelConfig(
        task="forecasting",
        model_name="Chronos2",
        inference_hyperparameters={"model_path": "amazon/chronos-2"},
        training_hyperparameters={"model_path": "amazon/chronos-2", "fine_tune": True},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=True,
    ),
    # TODO: Replace dummy configs with real values
    "mitra-classification": FoundationModelConfig(
        task="classification",
        model_name="Mitra",
        inference_hyperparameters={"model_path": "TODO"},
        training_hyperparameters={"model_path": "TODO"},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=False,
    ),
    "mitra-regression": FoundationModelConfig(
        task="regression",
        model_name="Mitra",
        inference_hyperparameters={"model_path": "TODO"},
        training_hyperparameters={"model_path": "TODO"},
        default_instance_type="ml.g5.xlarge",
        fine_tunable=False,
    ),
}
|
|
||
|
|
||
def get_model_config(model_id: str) -> FoundationModelConfig:
    """Return the registry entry stored under ``model_id``.

    Raises
    ------
    ValueError
        If ``model_id`` is not a key of ``FOUNDATION_MODEL_REGISTRY``. The
        message lists the available model ids.
    """
    entry = FOUNDATION_MODEL_REGISTRY.get(model_id)
    if entry is not None:
        return entry
    known_ids = list(FOUNDATION_MODEL_REGISTRY)
    raise ValueError(f"Unknown model_id '{model_id}'. Available models: {known_ids}")
Uh oh!
There was an error while loading. Please reload this page.