diff --git a/clarifai/cli/artifact.py b/clarifai/cli/artifact.py index 1c8aa7ab..db7dfce5 100644 --- a/clarifai/cli/artifact.py +++ b/clarifai/cli/artifact.py @@ -410,7 +410,11 @@ def delete(ctx, path, force): prompt_msg = f"Are you sure you want to delete artifact version '{version_id}'?" else: prompt_msg = f"Are you sure you want to delete artifact '{parsed['artifact_id']}'?" - if not click.confirm(prompt_msg): + + # Using prompt_yes_no for better automation support + from clarifai.utils.cli import prompt_yes_no + + if not prompt_yes_no(prompt_msg, default=False): click.echo("Operation cancelled") return diff --git a/clarifai/cli/base.py b/clarifai/cli/base.py index c6069180..d82958b0 100644 --- a/clarifai/cli/base.py +++ b/clarifai/cli/base.py @@ -73,6 +73,9 @@ def login(ctx, api_url, user_id): # Input user_id if not supplied if not user_id: + if not sys.stdin.isatty(): + logger.error("User ID is required for login in non-interactive mode.") + raise click.Abort() user_id = click.prompt('Enter your Clarifai user ID', type=str) click.echo() # Blank line for readability @@ -96,8 +99,13 @@ def login(ctx, api_url, user_id): click.echo('\n> Verifying token...') validate_context_auth(pat, user_id, api_url) - # Save context with default name - context_name = 'default' + # Context naming + default_context_name = 'default' + click.echo('\n> Let\'s save these credentials to a new context.') + click.echo('> You can have multiple contexts to easily switch between accounts or projects.\n') + context_name = click.prompt("Enter a name for this context", default=default_context_name) + + # Save context context = Context( context_name, CLARIFAI_API_BASE=api_url, @@ -119,12 +127,17 @@ def pat_display(pat): def input_or_default(prompt, default): + if default is not None: + logger.info(f"{prompt} [Using default: {default}]") + return default + if not sys.stdin.isatty(): + raise click.Abort() value = input(prompt) return value if value else default # Context management commands under 
config group -@config.command(aliases=['get-contexts', 'list-contexts', 'ls']) +@config.command(aliases=['get-contexts', 'list-contexts', 'ls', 'list']) @click.option( '-o', '--output-format', default='wide', type=click.Choice(['wide', 'name', 'json', 'yaml']) ) @@ -138,6 +151,7 @@ def get_contexts(ctx, output_format): 'USER_ID': lambda c: c.user_id, 'API_BASE': lambda c: c.api_base, 'PAT': lambda c: pat_display(c.pat), + 'HF_TOKEN': lambda c: pat_display(c.hf_token) if hasattr(c, 'hf_token') else "", } additional_columns = set() for cont in ctx.obj.contexts.values(): @@ -213,6 +227,9 @@ def create_context( click.secho(f'Error: Context "{name}" already exists', fg='red', err=True) sys.exit(1) if not user_id: + if not sys.stdin.isatty(): + click.echo("Error: user-id is required in non-interactive mode.", err=True) + sys.exit(1) user_id = input('user id: ') if not base_url: base_url = input_or_default( diff --git a/clarifai/cli/deployment.py b/clarifai/cli/deployment.py index 92a345b9..c92ddf26 100644 --- a/clarifai/cli/deployment.py +++ b/clarifai/cli/deployment.py @@ -16,24 +16,53 @@ def deployment(): @deployment.command(['c']) -@click.argument('nodepool_id') -@click.argument('deployment_id') @click.option( '--config', type=click.Path(exists=True), required=True, - help='Path to the deployment config file.', + help='Path to the deployment config YAML file.', ) @click.pass_context -def create(ctx, nodepool_id, deployment_id, config): - """Create a new Deployment with the given config file.""" +def create(ctx, config): + """ + Create a new Deployment from a config file. + + The config file is a YAML that defines the worker (model or workflow), + nodepools, autoscale settings, and visibility. 
+ + Ex: clarifai deployment create --config deployment.yaml + + Example deployment.yaml: + + \b + deployment: + id: "my-deployment" + worker: + model: + id: "model-id" + model_version: + id: "version-id" + user_id: "owner-id" + app_id: "app-id" + nodepools: + - id: "nodepool-id" + compute_cluster: + id: "cluster-id" + user_id: "cluster-owner-id" + autoscale_config: + min_replicas: 1 + max_replicas: 1 + scale_to_zero_delay_seconds: 300 + deploy_latest_version: true + + """ from clarifai.client.nodepool import Nodepool validate_context(ctx) - if not nodepool_id: - deployment_config = from_yaml(config) - nodepool_id = deployment_config['deployment']['nodepools'][0]['id'] + deployment_config = from_yaml(config) + nodepool_id = deployment_config['deployment']['nodepools'][0]['id'] + deployment_id = deployment_config['deployment']['id'] nodepool = Nodepool( nodepool_id=nodepool_id, @@ -41,47 +70,67 @@ def create(ctx, nodepool_id, deployment_id, config): pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base, ) - if deployment_id: - nodepool.create_deployment(config, deployment_id=deployment_id) - else: - nodepool.create_deployment(config) + nodepool.create_deployment(config, deployment_id=deployment_id) @deployment.command(['ls']) -@click.argument('nodepool_id', default="") +@click.option('--nodepool_id', required=False, help='Nodepool ID to list deployments for.') +@click.option( + '--compute_cluster_id', required=False, help='Compute cluster ID to list deployments for.' 
+) @click.option('--page_no', required=False, help='Page number to list.', default=1) @click.option('--per_page', required=False, help='Number of items per page.', default=16) @click.pass_context -def list(ctx, nodepool_id, page_no, per_page): +def list(ctx, nodepool_id, compute_cluster_id, page_no, per_page): """List all deployments for the nodepool.""" + from clarifai_grpc.grpc.api import resources_pb2 + from clarifai.client.compute_cluster import ComputeCluster from clarifai.client.nodepool import Nodepool from clarifai.client.user import User validate_context(ctx) if nodepool_id: + kwargs = {} + if compute_cluster_id: + kwargs['compute_cluster'] = resources_pb2.ComputeCluster(id=compute_cluster_id) nodepool = Nodepool( nodepool_id=nodepool_id, user_id=ctx.obj.current.user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base, + **kwargs, ) response = nodepool.list_deployments(page_no=page_no, per_page=per_page) else: - user = User( - user_id=ctx.obj.current.user_id, - pat=ctx.obj.current.pat, - base_url=ctx.obj.current.api_base, - ) - ccs = user.list_compute_clusters(page_no, per_page) - nps = [] - for cc in ccs: - compute_cluster = ComputeCluster( - compute_cluster_id=cc.id, + if compute_cluster_id: + ccs = [ + ComputeCluster( + compute_cluster_id=compute_cluster_id, + user_id=ctx.obj.current.user_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + ) + ] + else: + user = User( user_id=ctx.obj.current.user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base, ) + all_ccs = user.list_compute_clusters(page_no, per_page) + ccs = [ + ComputeCluster( + compute_cluster_id=cc.id, + user_id=ctx.obj.current.user_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + ) + for cc in all_ccs + ] + + nps = [] + for compute_cluster in ccs: nps.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)]) response = [] for np in nps: diff --git a/clarifai/cli/model.py b/clarifai/cli/model.py index 5d6033ec..2d760ed0 
100644 --- a/clarifai/cli/model.py +++ b/clarifai/cli/model.py @@ -651,7 +651,6 @@ def init( files_to_download[i] = f"{i + 1}. {file}" files_to_download = '\n'.join(files_to_download) logger.info(f"Files to be downloaded are:\n{files_to_download}") - input("Press Enter to continue...") if not toolkit: if folder_path != "": try: @@ -781,7 +780,6 @@ def init( # Fall back to template-based initialization if no GitHub repo or if GitHub repo failed if not github_url: logger.info("Initializing model with default templates...") - input("Press Enter to continue...") from clarifai.cli.base import input_or_default from clarifai.cli.templates.model_templates import ( @@ -908,8 +906,14 @@ def _ensure_hf_token(ctx, model_path): required=False, help='Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml.', ) +@click.option( + '--autodeploy', + is_flag=True, + default=False, + help='If provided, automatically walk through the creation of a deployment after uploading.', +) @click.pass_context -def upload(ctx, model_path, stage, skip_dockerfile, platform): +def upload(ctx, model_path, stage, skip_dockerfile, platform, autodeploy): """Upload a model to Clarifai. MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. @@ -925,6 +929,7 @@ def upload(ctx, model_path, stage, skip_dockerfile, platform): stage, skip_dockerfile, platform=platform, + autodeploy=autodeploy, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base, ) @@ -1298,10 +1303,10 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i raise except Exception as e: logger.warning(f"Failed to get compute cluster with ID '{compute_cluster_id}':\n{e}") - y = input( - f"Compute cluster not found. Do you want to create a new compute cluster {user_id}/{compute_cluster_id}? 
(y/n): " - ) - if y.lower() != 'y': + if not prompt_yes_no( + f"Compute cluster not found. Do you want to create a new compute cluster {user_id}/{compute_cluster_id}?", + default=True, + ): raise click.Abort() # Create a compute cluster with default configuration for local runner. compute_cluster = user.create_compute_cluster( @@ -1327,10 +1332,10 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i ctx.obj.to_yaml() # save to yaml file. except Exception as e: logger.warning(f"Failed to get nodepool with ID '{nodepool_id}':\n{e}") - y = input( - f"Nodepool not found. Do you want to create a new nodepool {user_id}/{compute_cluster_id}/{nodepool_id}? (y/n): " - ) - if y.lower() != 'y': + if not prompt_yes_no( + f"Nodepool not found. Do you want to create a new nodepool {user_id}/{compute_cluster_id}/{nodepool_id}?", + default=True, + ): raise click.Abort() nodepool = compute_cluster.create_nodepool( nodepool_config=DEFAULT_LOCAL_RUNNER_NODEPOOL_CONFIG, nodepool_id=nodepool_id @@ -1355,8 +1360,9 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i ctx.obj.to_yaml() # save to yaml file. except Exception as e: logger.warning(f"Failed to get app with ID '{app_id}':\n{e}") - y = input(f"App not found. Do you want to create a new app {user_id}/{app_id}? (y/n): ") - if y.lower() != 'y': + if not prompt_yes_no( + f"App not found. Do you want to create a new app {user_id}/{app_id}?", default=True + ): raise click.Abort() app = user.create_app(app_id) ctx.obj.current.CLARIFAI_APP_ID = app_id @@ -1385,10 +1391,10 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i raise Exception except Exception as e: logger.warning(f"Failed to get model with ID '{model_id}':\n{e}") - y = input( - f"Model not found. Do you want to create a new model {user_id}/{app_id}/models/{model_id}? (y/n): " - ) - if y.lower() != 'y': + if not prompt_yes_no( + f"Model not found. 
Do you want to create a new model {user_id}/{app_id}/models/{model_id}?", + default=True, + ): raise click.Abort() model = app.create_model(model_id, model_type_id=uploaded_model_type_id) @@ -1531,10 +1537,10 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i ctx.obj.to_yaml() # save to yaml file. except Exception as e: logger.warning(f"Failed to get deployment with ID {deployment_id}:\n{e}") - y = input( - f"Deployment not found. Do you want to create a new deployment {user_id}/{compute_cluster_id}/{nodepool_id}/{deployment_id}? (y/n): " - ) - if y.lower() != 'y': + if not prompt_yes_no( + f"Deployment not found. Do you want to create a new deployment {user_id}/{compute_cluster_id}/{nodepool_id}/{deployment_id}?", + default=True, + ): raise click.Abort() nodepool.create_deployment( deployment_id=deployment_id, @@ -1565,10 +1571,10 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i # The config.yaml doens't match what we created above. if 'model' in config and model_id != config['model'].get('id'): logger.info(f"Current model section of config.yaml: {config.get('model', {})}") - y = input( - "Do you want to backup config.yaml to config.yaml.bk then update the config.yaml with the new model information? 
(y/n): " - ) - if y.lower() != 'y': + if not prompt_yes_no( + "Do you want to backup config.yaml to config.yaml.bk then update the config.yaml with the new model information?", + default=True, + ): raise click.Abort() config = ModelBuilder._set_local_runner_model( config, user_id, app_id, model_id, uploaded_model_type_id diff --git a/clarifai/cli/nodepool.py b/clarifai/cli/nodepool.py index 7cbaf692..90d466ca 100644 --- a/clarifai/cli/nodepool.py +++ b/clarifai/cli/nodepool.py @@ -62,7 +62,9 @@ def create(ctx, compute_cluster_id, nodepool_id, config): @nodepool.command(['ls']) -@click.argument('compute_cluster_id', default="") +@click.option( + '--compute_cluster_id', required=False, help='Compute cluster ID to list nodepools for.' +) @click.option('--page_no', required=False, help='Page number to list.', default=1) @click.option('--per_page', required=False, help='Number of items per page.', default=128) @click.pass_context diff --git a/clarifai/client/deployment.py b/clarifai/client/deployment.py index 39895aa8..a1424a58 100644 --- a/clarifai/client/deployment.py +++ b/clarifai/client/deployment.py @@ -1,4 +1,7 @@ -from clarifai_grpc.grpc.api import resources_pb2 +from typing import Dict + +from clarifai_grpc.grpc.api import resources_pb2, service_pb2 +from clarifai_grpc.grpc.api.status import status_code_pb2 from clarifai.client.base import BaseClient from clarifai.client.lister import Lister @@ -32,8 +35,15 @@ def __init__( **kwargs: Additional keyword arguments to be passed to the deployment. 
""" self.kwargs = {**kwargs, 'id': deployment_id, 'user_id': user_id} + + # Filter kwargs to only include fields that exist in the Deployment proto + proto_fields = { + f.name for f in resources_pb2.Deployment.DESCRIPTOR.fields if f.name in self.kwargs + } + proto_kwargs = {k: self.kwargs[k] for k in proto_fields} + self.deployment_info = resources_pb2.Deployment() - dict_to_protobuf(self.deployment_info, self.kwargs) + dict_to_protobuf(self.deployment_info, proto_kwargs) self.logger = logger BaseClient.__init__( self, @@ -45,6 +55,18 @@ def __init__( ) Lister.__init__(self) + def refresh(self): + """Refresh the deployment info from the API.""" + request = service_pb2.GetDeploymentRequest( + user_app_id=self.user_app_id, deployment_id=self.id + ) + response = self._grpc_request(self.STUB.GetDeployment, request) + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(f"Failed to get deployment: {response.status.details}") + + self.deployment_info.CopyFrom(response.deployment) + return self + @staticmethod def get_runner_selector(user_id: str, deployment_id: str) -> resources_pb2.RunnerSelector: """Returns a RunnerSelector object for the given deployment_id. @@ -70,3 +92,175 @@ def __str__(self): if hasattr(self.deployment_info, param) ] return f"Deployment Details: \n{', '.join(attribute_strings)}\n" + + def logs( + self, stream: bool = False, log_type: str = "runner", page: int = 1, per_page: int = 100 + ): + """Get logs for the deployment. + + Args: + stream (bool): Whether to stream the logs or list them. + log_type (str): The type of logs to retrieve. Defaults to "runner". + Valid types are "runner" and "runner.events". + page (int): The page number to list (only for list). + per_page (int): The number of items per page (only for list). + + Yields: + LogEntry: Log entry objects. 
+ + Example: + >>> from clarifai.client.deployment import Deployment + >>> deployment = Deployment(deployment_id="deployment_id", user_id="user_id") + >>> for entry in deployment.logs(stream=True): + ... print(entry.message) + """ + if log_type not in ["runner", "runner.events"]: + raise ValueError( + f"Invalid log_type '{log_type}'. Valid types for deployment are 'runner' and 'runner.events'." + ) + + if not self.deployment_info.HasField("worker"): + self.refresh() + + request_kwargs = { + "user_app_id": self.user_app_id, + "log_type": log_type, + } + + # Add additional fields from deployment_info if they exist + if self.deployment_info.nodepools: + nodepool = self.deployment_info.nodepools[0] + request_kwargs["nodepool_id"] = nodepool.id + if nodepool.compute_cluster.id: + request_kwargs["compute_cluster_id"] = nodepool.compute_cluster.id + if nodepool.compute_cluster.user_id: + request_kwargs["compute_cluster_user_id"] = nodepool.compute_cluster.user_id + + if self.deployment_info.HasField("worker"): + worker = self.deployment_info.worker + if worker.HasField("model"): + request_kwargs["model_id"] = worker.model.id + if worker.model.model_version.id: + request_kwargs["model_version_id"] = worker.model.model_version.id + elif worker.HasField("workflow"): + request_kwargs["workflow_id"] = worker.workflow.id + if worker.workflow.workflow_version.id: + request_kwargs["workflow_version_id"] = worker.workflow.workflow_version.id + + if stream: + request = service_pb2.StreamLogEntriesRequest(**request_kwargs) + for response in self.STUB.StreamLogEntries(request): + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(f"Failed to stream logs: {response}") + for entry in response.log_entries: + yield entry + else: + request_kwargs["page"] = page + request_kwargs["per_page"] = per_page + request = service_pb2.ListLogEntriesRequest(**request_kwargs) + response = self.STUB.ListLogEntries(request) + if response.status.code != status_code_pb2.SUCCESS: + raise 
Exception(f"Failed to list logs: {response}") + for entry in response.log_entries: + yield entry + + def patch(self, action: str = "overwrite", **kwargs): + """Patch the deployment. + + Args: + action (str): The action to perform (overwrite, merge, remove). Defaults to "overwrite". + **kwargs: The fields to patch on the deployment. + """ + deployment = resources_pb2.Deployment(id=self.id) + dict_to_protobuf(deployment, kwargs) + + request = service_pb2.PatchDeploymentsRequest( + user_app_id=self.user_app_id, action=action, deployments=[deployment] + ) + response = self._grpc_request(self.STUB.PatchDeployments, request) + if response.status.code != status_code_pb2.SUCCESS: + self.logger.error(f"PatchDeployments failed. Status: {response.status}") + raise Exception(f"Failed to patch deployment: {response}") + + # Update local deployment_info if success + dict_to_protobuf(self.deployment_info, kwargs) + return response + + def runner_metrics(self) -> Dict[str, int]: + """Get the accumulated runner metrics for the deployment. + + This aggregates runner metrics across all nodepools to find the total pods + running across all of them. + + Returns: + Dict[str, int]: A dictionary with 'pods_total' and 'pods_running'. 
+ + Example: + >>> from clarifai.client.deployment import Deployment + >>> deployment = Deployment(deployment_id="deployment_id", user_id="user_id") + >>> print(deployment.runner_metrics()) + """ + if not self.deployment_info.worker.HasField( + "model" + ) and not self.deployment_info.worker.HasField("workflow"): + self.refresh() + + from clarifai.client.user import User + + user = User(user_id=self.user_app_id.user_id, pat=self.pat, base_url=self.base) + + model_version_ids = None + workflow_version_ids = None + if self.worker.HasField("model"): + model_version_ids = [self.worker.model.model_version.id] + elif self.worker.HasField("workflow"): + workflow_version_ids = [self.worker.workflow.workflow_version.id] + + pods_total = 0 + pods_running = 0 + + for np_proto in self.deployment_info.nodepools: + filter_by = { + "nodepool_id": np_proto.id, + "compute_cluster_id": np_proto.compute_cluster.id, + } + if model_version_ids: + filter_by["model_version_ids"] = model_version_ids + if workflow_version_ids: + filter_by["workflow_version_ids"] = workflow_version_ids + + runners = user.list_runners(filter_by=filter_by) + + for runner in runners: + metrics = runner.get("runner_metrics") + if metrics: + pods_total += metrics.get("pods_total", 0) + pods_running += metrics.get("pods_running", 0) + + return {"pods_total": pods_total, "pods_running": pods_running} + + def update(self, min_replicas: int = None, max_replicas: int = None): + """Update deployment replicas. + + Args: + min_replicas (int): The minimum number of replicas. + max_replicas (int): The maximum number of replicas. 
+ + Example: + >>> from clarifai.client.deployment import Deployment + >>> deployment = Deployment(deployment_id="deployment_id", user_id="user_id") + >>> deployment.update(min_replicas=1, max_replicas=2) + """ + patch_kwargs = {} + if min_replicas is not None or max_replicas is not None: + autoscale_config = {} + if min_replicas is not None: + autoscale_config["min_replicas"] = min_replicas + if max_replicas is not None: + autoscale_config["max_replicas"] = max_replicas + patch_kwargs["autoscale_config"] = autoscale_config + + if not patch_kwargs: + return + + return self.patch(action="overwrite", **patch_kwargs) diff --git a/clarifai/client/model.py b/clarifai/client/model.py index 7267ab47..b637414f 100644 --- a/clarifai/client/model.py +++ b/clarifai/client/model.py @@ -122,6 +122,7 @@ def __init__( Lister.__init__(self) self.deployment_user_id = deployment_user_id + self.deployment_id = deployment_id self.load_info(validate=True) @@ -728,6 +729,57 @@ def _set_runner_selector( # set the runner selector self._runner_selector = runner_selector + def logs( + self, stream: bool = False, log_type: str = "runner", page: int = 1, per_page: int = 100 + ): + """Get logs for the model through its deployment. + + Args: + stream (bool): Whether to stream the logs or list them. + log_type (str): The type of logs to retrieve. Defaults to "runner". Use "builder" for build logs. + page (int): The page number to list (only for list). + per_page (int): The number of items per page (only for list). + + Yields: + LogEntry: Log entry objects. + + Example: + >>> from clarifai.client.model import Model + >>> model = Model(model_id="model_id", deployment_id="deployment_id") + >>> for entry in model.logs(stream=True): + ... print(entry.message) + """ + if not self.deployment_id: + raise UserError( + "Model object must be initialized with a deployment_id or " + "from_current_context() to access logs." 
+ ) + + from clarifai.client.deployment import Deployment + + user_id = self.deployment_user_id + if not user_id: + from clarifai.client.user import User + + user_id = ( + User(pat=self.auth_helper.pat, token=self.auth_helper._token) + .get_user_info(user_id='me') + .user.id + ) + + deployment = Deployment( + deployment_id=self.deployment_id, + user_id=user_id, + base_url=self.base, + pat=self.pat, + token=self.token, + root_certificates_path=self.root_certificates_path, + ) + for entry in deployment.logs( + stream=stream, log_type=log_type, page=page, per_page=per_page + ): + yield entry + def predict_by_filepath( self, filepath: str, diff --git a/clarifai/client/nodepool.py b/clarifai/client/nodepool.py index 5fb46113..58256238 100644 --- a/clarifai/client/nodepool.py +++ b/clarifai/client/nodepool.py @@ -92,7 +92,6 @@ def _process_deployment_config(self, deployment_config: Dict[str, Any]) -> Dict[ assert ("worker" in deployment) and ( ("model" in deployment["worker"]) or ("workflow" in deployment["worker"]) ), "worker info not found in the config file" - assert "scheduling_choice" in deployment, "scheduling_choice not found in the config file" assert "nodepools" in deployment, "nodepools not found in the config file" deployment['user_id'] = ( deployment['user_id'] if 'user_id' in deployment else self.user_app_id.user_id @@ -154,12 +153,14 @@ def create_deployment( config_filepath: str = None, deployment_id: str = None, deployment_config: Dict[str, Any] = None, + wait: bool = True, ) -> Deployment: """Creates a deployment for the nodepool. Args: config_filepath (str): The path to the deployment config file. deployment_id (str): New deployment ID for the deployment to create. + wait (bool): Whether to wait for the deployment to be successful. Defaults to True. Returns: Deployment: A Deployment object for the specified deployment ID. 
@@ -186,16 +187,23 @@ def create_deployment( else: raise AssertionError("Either config_filepath or deployment_config must be provided.") - deployment_config = self._process_deployment_config(deployment_config) + # Extract min_replicas before processing config as it might be mutated into proto objects + min_replicas = int( + deployment_config.get('deployment', {}) + .get('autoscale_config', {}) + .get('min_replicas', 1) + ) + + processed_config = self._process_deployment_config(deployment_config) - if 'id' in deployment_config: + if 'id' in processed_config: if deployment_id is None: - deployment_id = deployment_config['id'] - deployment_config.pop('id') + deployment_id = processed_config['id'] + processed_config.pop('id') request = service_pb2.PostDeploymentsRequest( user_app_id=self.user_app_id, - deployments=[resources_pb2.Deployment(id=deployment_id, **deployment_config)], + deployments=[resources_pb2.Deployment(id=deployment_id, **processed_config)], ) response = self._grpc_request(self.STUB.PostDeployments, request) if response.status.code != status_code_pb2.SUCCESS: @@ -208,7 +216,65 @@ def create_deployment( response.deployments[0], preserving_proto_field_name=True, use_integers_for_enums=True ) kwargs = self.process_response_keys(dict_response, "deployment") - return Deployment.from_auth_helper(auth=self.auth_helper, **kwargs) + deployment = Deployment.from_auth_helper(auth=self.auth_helper, **kwargs) + + if wait: + if min_replicas == 0: + self.logger.warning( + "min_replicas is set to 0. Actual replicas of this model " + "will not be deployed until a prediction request is received. This saves " + "on infrastructure costs but will result in longer warmup time for the " + "first prediction." 
+ ) + else: + import threading + import time + + stop_event = threading.Event() + logs_received = threading.Event() + + def stream_logs(log_type, prefix): + while not stop_event.is_set(): + try: + for entry in deployment.logs(stream=True, log_type=log_type): + if stop_event.is_set(): + break + logs_received.set() + timestamp = entry.time.ToDatetime().isoformat() + self.logger.info(f"[{prefix}] {timestamp} {entry.message}") + except Exception: + # Stream might fail if runner is not yet provisioned; wait and retry + if not stop_event.is_set(): + time.sleep(2) + + if not stop_event.is_set(): + time.sleep(1) # Brief pause before retrying if iterator ends normally + + self.logger.info( + f"Waiting for deployment '{deployment_id}' to reach {min_replicas} running replicas..." + ) + self.logger.info("Streaming logs (runner and runner.events):\n") + + threads = [] + for log_type, prefix in [("runner", "RUNNER"), ("runner.events", "EVENT")]: + t = threading.Thread(target=stream_logs, args=(log_type, prefix), daemon=True) + t.start() + threads.append(t) + + try: + while True: + metrics = deployment.runner_metrics() + # Exit only if replicas are running AND we've seen at least one log entry + if metrics["pods_running"] >= min_replicas and logs_received.is_set(): + self.logger.info( + f"Deployment '{deployment_id}' is successful! ({metrics['pods_running']} replicas running)" + ) + break + time.sleep(5) + finally: + stop_event.set() + + return deployment def deployment(self, deployment_id: str) -> Deployment: """Returns a Deployment object for the existing deployment ID. 
diff --git a/clarifai/client/pipeline.py b/clarifai/client/pipeline.py index 9e9aa513..7e7dc7c8 100644 --- a/clarifai/client/pipeline.py +++ b/clarifai/client/pipeline.py @@ -192,6 +192,56 @@ def monitor_only(self, timeout: int = 3600, monitor_interval: int = 10) -> Dict: # Monitor the existing run return self._monitor_pipeline_run(self.pipeline_version_run_id, timeout, monitor_interval) + def logs( + self, + stream: bool = False, + log_type: str = "pipeline.version.run", + page: int = 1, + per_page: int = 100, + ): + """Get logs for the pipeline version run. + + Args: + stream (bool): Whether to stream the logs or list them. + log_type (str): The type of logs to retrieve. Defaults to "pipeline.version.run". + page (int): The page number to list (only for list). + per_page (int): The number of items per page (only for list). + + Yields: + LogEntry: Log entry objects. + + Example: + >>> from clarifai.client.pipeline import Pipeline + >>> pipeline = Pipeline(pipeline_id="pipeline_id", user_id="user_id", app_id="app_id") + >>> for entry in pipeline.logs(stream=True): + ... 
print(entry.message) + """ + request_kwargs = { + "user_app_id": self.user_app_id, + "log_type": log_type, + "pipeline_id": self.pipeline_id, + "pipeline_version_id": self.pipeline_version_id or "", + } + if self.pipeline_version_run_id: + request_kwargs["pipeline_version_run_id"] = self.pipeline_version_run_id + + if stream: + request = service_pb2.StreamLogEntriesRequest(**request_kwargs) + for response in self.STUB.StreamLogEntries(request): + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(f"Failed to stream logs: {response.status.details}") + for entry in response.log_entries: + yield entry + else: + request_kwargs["page"] = page + request_kwargs["per_page"] = per_page + request = service_pb2.ListLogEntriesRequest(**request_kwargs) + response = self.STUB.ListLogEntries(request) + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(f"Failed to list logs: {response.status.details}") + for entry in response.log_entries: + yield entry + def _monitor_pipeline_run(self, run_id: str, timeout: int, monitor_interval: int) -> Dict: """Monitor a pipeline version run until completion. 
diff --git a/clarifai/runners/dockerfile_template/Dockerfile.template b/clarifai/runners/dockerfile_template/Dockerfile.template index 3de761d4..fc348653 100644 --- a/clarifai/runners/dockerfile_template/Dockerfile.template +++ b/clarifai/runners/dockerfile_template/Dockerfile.template @@ -7,7 +7,7 @@ FROM --platform=$BUILDPLATFORM ${DOWNLOADER_IMAGE} as model-assets # Install minimal tools needed for download -RUN pip install --no-cache-dir clarifai==${CLARIFAI_VERSION} huggingface_hub +RUN pip install --no-cache-dir clarifai==${CLARIFAI_VERSION} huggingface_hub[hf_transfer] WORKDIR /home/nonroot/main diff --git a/clarifai/runners/models/model_builder.py b/clarifai/runners/models/model_builder.py index 457d8892..d30ac519 100644 --- a/clarifai/runners/models/model_builder.py +++ b/clarifai/runners/models/model_builder.py @@ -81,30 +81,43 @@ def is_related(object_class, main_class): def get_user_input(prompt, required=True, default=None): """Get user input with optional default value.""" - if default: - prompt = f"{prompt} [{default}]: " + if not sys.stdin.isatty(): + if default is not None: + logger.info(f"{prompt} [Using default: {default}]") + return default + if not required: + return "" + else: + raise UserError(f"Input required for prompt: '{prompt}' but stdin is not a TTY.") + + if default is not None: + prompt_text = f"{prompt} [{default}]: " else: - prompt = f"{prompt}: " + prompt_text = f"{prompt}: " while True: - value = input(prompt).strip() - if not value and default: - return default - if not value and required: - print("❌ This field is required. Please enter a value.") - continue + value = input(prompt_text).strip() + if not value: + if default is not None: + return default + if required: + print("❌ This field is required. 
Please enter a value.") + continue + return "" return value def get_yes_no_input(prompt, default=None): """Get yes/no input from user.""" - if default is not None: - prompt = f"{prompt} [{'Y/n' if default else 'y/N'}]: " - else: - prompt = f"{prompt} [y/n]: " + if not sys.stdin.isatty(): + if default is not None: + logger.info(f"{prompt} [Using default: {'Y/n' if default else 'y/N'}]") + return default + raise UserError(f"Input required for prompt: '{prompt}' but stdin is not a TTY.") + full_prompt = f"{prompt} [{'Y/n' if default else 'y/N' if default is not None else 'y/n'}]: " while True: - response = input(prompt).strip().lower() + response = input(full_prompt).strip().lower() if not response and default is not None: return default if response in ['y', 'yes']: @@ -119,9 +132,7 @@ def select_compute_option(user_id: str, pat: Optional[str] = None, base_url: Opt Dynamically list compute-clusters and node-pools that belong to `user_id` and return a dict with nodepool_id, compute_cluster_id, cluster_user_id. """ - user = User( - user_id=user_id, pat=pat, base_url=base_url - ) # PAT / BASE URL are picked from env-vars + user = User(user_id=user_id, pat=pat, base_url=base_url) clusters = list(user.list_compute_clusters()) if not clusters: print("❌ No compute clusters found for this user.") @@ -130,15 +141,14 @@ def select_compute_option(user_id: str, pat: Optional[str] = None, base_url: Opt for idx, cc in enumerate(clusters, 1): desc = getattr(cc, "description", "") or "No description" print(f"{idx}. 
{cc.id} – {desc}") - while True: - try: - sel = int(input("Select compute cluster (number): ")) - 1 - if 0 <= sel < len(clusters): - cluster = clusters[sel] - break - print("❌ Invalid selection.") - except ValueError: - print("❌ Please enter a number.") + + if len(clusters) > 1: + cluster_idx = int(get_user_input("Select compute cluster (enter number)", default="1")) + cluster = clusters[cluster_idx - 1] + else: + cluster = clusters[0] + logger.info(f"Selecting compute cluster: {cluster.id}") + nodepools = list(cluster.list_nodepools()) if not nodepools: print("❌ No nodepools in selected cluster.") @@ -147,15 +157,14 @@ def select_compute_option(user_id: str, pat: Optional[str] = None, base_url: Opt for idx, np in enumerate(nodepools, 1): desc = getattr(np, "description", "") or "No description" print(f"{idx}. {np.id} – {desc}") - while True: - try: - sel = int(input("Select nodepool (number): ")) - 1 - if 0 <= sel < len(nodepools): - nodepool = nodepools[sel] - break - print("❌ Invalid selection.") - except ValueError: - print("❌ Please enter a number.") + + if len(nodepools) > 1: + nodepool_idx = int(get_user_input("Select nodepool (enter number)", default="1")) + nodepool = nodepools[nodepool_idx - 1] + else: + nodepool = nodepools[0] + logger.info(f"Selecting nodepool: {nodepool.id}") + return { "nodepool_id": nodepool.id, "compute_cluster_id": cluster.id, @@ -423,19 +432,31 @@ def _validate_config_checkpoints(self): ) assert loader_type == "huggingface", "Only huggingface loader supported for now" if loader_type == "huggingface": - assert "repo_id" in self.config.get("checkpoints"), ( - "No repo_id specified in the config file" - ) - repo_id = self.config.get("checkpoints").get("repo_id") + assert "repo_id" in checkpoints, "No repo_id specified in the config file" + repo_id = checkpoints.get("repo_id") + + # Priority: 1. config.yaml, 2. HF_TOKEN env var, 3. 
User prompt + hf_token = checkpoints.get("hf_token") + if not hf_token: + hf_token = os.environ.get("HF_TOKEN") + if hf_token: + logger.info("Using HF_TOKEN from environment variable.") + elif sys.stdin.isatty(): + hf_token = get_user_input( + "Hugging Face token not found. Please enter it (optional, press enter to skip)", + required=False, + ) - # get from config.yaml otherwise fall back to HF_TOKEN env var. - hf_token = self.config.get("checkpoints").get( - "hf_token", os.environ.get("HF_TOKEN", None) - ) + # Update config file if a token was found elsewhere. + if hf_token and hf_token != checkpoints.get("hf_token"): + self.config["checkpoints"]["hf_token"] = hf_token + try: + self._save_config(os.path.join(self.folder, "config.yaml"), self.config) + logger.info("Updated config.yaml with Hugging Face token.") + except Exception as e: + logger.warning(f"Could not update config.yaml with Hugging Face token: {e}") - allowed_file_patterns = self.config.get("checkpoints").get( - 'allowed_file_patterns', None - ) + allowed_file_patterns = checkpoints.get("allowed_file_patterns", None) if isinstance(allowed_file_patterns, str): allowed_file_patterns = [allowed_file_patterns] ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None) @@ -488,8 +509,7 @@ def create_app(): logger.info(f"App {app_id} not found for user {user_id}.") if self.app_not_found_action == "prompt": - create_app_prompt = input(f"Do you want to create App `{app_id}`? (y/n): ") - if create_app_prompt.lower() == 'y': + if get_yes_no_input(f"Do you want to create App `{app_id}`?", True): create_app() return True else: @@ -1377,16 +1397,13 @@ def create_dockerfile(self, generate_dockerfile=False): ) should_create_dockerfile = False else: - logger.info("Dockerfile already exists with different content.") - response = input( - "A different Dockerfile already exists. Do you want to overwrite it with the generated one? 
" - "Type 'y' to overwrite, 'n' to keep your custom Dockerfile: " - ) - if response.lower() != 'y': - logger.info("Keeping existing custom Dockerfile.") - should_create_dockerfile = False + logger.warning("A different Dockerfile already exists.") + if get_yes_no_input( + "Do you want to overwrite the existing Dockerfile?", False + ): + should_create_dockerfile = True else: - logger.info("Overwriting existing Dockerfile with generated content.") + should_create_dockerfile = False if should_create_dockerfile: # Write Dockerfile @@ -1579,11 +1596,7 @@ def _check_git_status_and_prompt(self) -> bool: if status_result.stdout.strip(): logger.warning("Uncommitted changes detected in model path:") logger.warning(status_result.stdout) - - response = input( - "\nDo you want to continue upload with uncommitted changes? (y/N): " - ) - return response.lower() in ['y', 'yes'] + return True else: logger.info("Model path has no uncommitted changes.") return True @@ -1721,8 +1734,8 @@ def upload_model_version(self, git_info=None): if when != "upload" and not HuggingFaceLoader.validate_config( self.checkpoint_path ): - input( - "Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..." + logger.info( + "Downloading the HuggingFace model's config.json file to infer the concepts." ) loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) loader.download_config(self.checkpoint_path) @@ -1935,6 +1948,7 @@ def upload_model( stage, skip_dockerfile, platform: Optional[str] = None, + autodeploy: bool = False, pat: Optional[str] = None, base_url: Optional[str] = None, ): @@ -1945,6 +1959,7 @@ def upload_model( :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. 
Set this stage to whatever you have in config.yaml to force downloading now. :param skip_dockerfile: If True, will skip Dockerfile generation entirely. If False or not provided, intelligently handle existing Dockerfiles with user confirmation. :param platform: Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml. + :param autodeploy: If True, will automatically set up deployment for the model after upload. :param pat: Personal access token for authentication. If None, will use environment variables. :param base_url: Base URL for the API. If None, will use environment variables. """ @@ -1977,18 +1992,18 @@ upload_model( if not builder._check_git_status_and_prompt(): logger.info("Upload cancelled by user due to uncommitted changes.") return - input("Press Enter to continue...") - model_version = builder.upload_model_version(git_info) + logger.info("Ready to upload. Starting model version upload.") + builder.upload_model_version(git_info) - # Ask user if they want to deploy the model - if model_version is not None: # if it comes back None then it failed. - if get_yes_no_input("\n🔶 Do you want to deploy the model?", True): - # Setup deployment for the uploaded model - setup_deployment_for_model(builder) - else: - logger.info("Model uploaded successfully. Skipping deployment setup.") - return + if autodeploy: + # Setup deployment for the uploaded model + setup_deployment_for_model(builder) + else: + logger.info( + "Model uploaded successfully. Skipping deployment setup, you can create a deployment in UI or CLI." 
+ ) + return def deploy_model( @@ -2129,6 +2144,7 @@ def setup_deployment_for_model(builder): deployment_id = get_user_input( "Enter deployment ID", default=f"deploy-{state['model_id']}-{uuid.uuid4().hex[:6]}" ) + min_replicas = int(get_user_input("Enter minimum replicas", default="1")) max_replicas = int(get_user_input("Enter maximum replicas", default="5")) diff --git a/clarifai/runners/pipeline_steps/pipeline_step_builder.py b/clarifai/runners/pipeline_steps/pipeline_step_builder.py index 41af62be..ebe36a6c 100644 --- a/clarifai/runners/pipeline_steps/pipeline_step_builder.py +++ b/clarifai/runners/pipeline_steps/pipeline_step_builder.py @@ -616,8 +616,6 @@ def upload_pipeline_step(folder, skip_dockerfile=False): f"New pipeline step {builder.pipeline_step_id} will be created with its first version." ) - input("Press Enter to continue...") - success = builder.upload_pipeline_step_version() if success: logger.info("Pipeline step upload completed successfully!") diff --git a/clarifai/utils/cli.py b/clarifai/utils/cli.py index 846f4259..628f3a4e 100644 --- a/clarifai/utils/cli.py +++ b/clarifai/utils/cli.py @@ -632,14 +632,17 @@ def prompt_required_field(message: str, default: Optional[str] = None) -> str: Returns: str: The value entered by the user. """ + if default: + logger.info(f"{message} [Using default: {default}]") + return default + + if not sys.stdin.isatty(): + logger.error(f"{message} [Non-interactive: required field missing, aborting]") + raise click.Abort() + while True: - prompt = f"{message}" - if default: - prompt += f" [{default}]" - prompt += ": " + prompt = f"{message}: " value = input(prompt).strip() - if not value and default: - return default if value: return value click.echo("❌ This field is required. Please enter a value.") @@ -655,10 +658,14 @@ def prompt_optional_field(message: str, default: Optional[str] = None) -> Option Returns: Optional[str]: The value entered by the user. 
""" - prompt = f"{message}" - if default: - prompt += f" [{default}]" - prompt += ": " + if default is not None: + logger.info(f"{message} [Using default: {default}]") + return default + + if not sys.stdin.isatty(): + return default + + prompt = f"{message}: " value = input(prompt).strip() if not value: return default @@ -675,11 +682,15 @@ def prompt_int_field(message: str, default: Optional[int] = None) -> int: Returns: int: The value entered by the user. """ + if default is not None: + logger.info(f"{message} [Using default: {default}]") + return default + + if not sys.stdin.isatty(): + raise click.Abort() + while True: - prompt = f"{message}" - if default is not None: - prompt += f" [{default}]" - prompt += ": " + prompt = f"{message}: " raw = input(prompt).strip() if not raw and default is not None: return default @@ -699,17 +710,19 @@ def prompt_yes_no(message: str, default: Optional[bool] = None) -> bool: Returns: bool: The value entered by the user. """ - if default is True: - suffix = " [Y/n]" - elif default is False: - suffix = " [y/N]" - else: - suffix = " [y/n]" + if default is not None: + logger.info(f"{message} [Using default: {'Y/n' if default else 'y/N'}]") + return default + + if not sys.stdin.isatty(): + res = default if default is not None else True + logger.info(f"{message} [Non-interactive: using {res}]") + return res + + suffix = " [y/n]" prompt = f"{message}{suffix}: " while True: response = input(prompt).strip().lower() - if not response and default is not None: - return default if response in ("y", "yes"): return True if response in ("n", "no"): diff --git a/tests/cli/test_compute_orchestration.py b/tests/cli/test_compute_orchestration.py index 40b06f88..39b171c2 100644 --- a/tests/cli/test_compute_orchestration.py +++ b/tests/cli/test_compute_orchestration.py @@ -159,8 +159,6 @@ def test_create_deployment(self, cli_runner): [ "deployment", "create", - CREATE_NODEPOOL_ID, - CREATE_DEPLOYMENT_ID, "--config", DEPLOYMENT_CONFIG_FILE, ], @@ 
-175,13 +173,43 @@ def test_list_compute_clusters(self, cli_runner): def test_list_nodepools(self, cli_runner): cli_runner.invoke(cli, ["login", "--env", CLARIFAI_ENV]) - result = cli_runner.invoke(cli, ["nodepool", "list", CREATE_COMPUTE_CLUSTER_ID]) + result = cli_runner.invoke( + cli, ["nodepool", "list", "--compute_cluster_id", CREATE_COMPUTE_CLUSTER_ID] + ) assert result.exit_code == 0, logger.exception(result) assert "USER_ID" in result.output def test_list_deployments(self, cli_runner): cli_runner.invoke(cli, ["login", "--env", CLARIFAI_ENV]) - result = cli_runner.invoke(cli, ["deployment", "list", CREATE_NODEPOOL_ID]) + result = cli_runner.invoke( + cli, ["deployment", "list", "--nodepool_id", CREATE_NODEPOOL_ID] + ) + + assert result.exit_code == 0, logger.exception(result) + assert "USER_ID" in result.output + + def test_list_deployments_with_cluster_id(self, cli_runner): + cli_runner.invoke(cli, ["login", "--env", CLARIFAI_ENV]) + result = cli_runner.invoke( + cli, ["deployment", "list", "--compute_cluster_id", CREATE_COMPUTE_CLUSTER_ID] + ) + + assert result.exit_code == 0, logger.exception(result) + assert "USER_ID" in result.output + + def test_list_deployments_with_nodepool_and_cluster_id(self, cli_runner): + cli_runner.invoke(cli, ["login", "--env", CLARIFAI_ENV]) + result = cli_runner.invoke( + cli, + [ + "deployment", + "list", + "--nodepool_id", + CREATE_NODEPOOL_ID, + "--compute_cluster_id", + CREATE_COMPUTE_CLUSTER_ID, + ], + ) assert result.exit_code == 0, logger.exception(result) assert "USER_ID" in result.output