Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 125 additions & 79 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,26 +1992,35 @@ def serve_cmd(ctx, model_path, grpc, mode, port, concurrency, keep_image, verbos
created['model'] = model_id
click.echo("done")

# 5. Model version (always created fresh — always cleaned up)
out.status("Creating model version... ", nl=False)
version_model = model.create_version(
pretrained_model_config={"local_dev": True},
method_signatures=method_signatures,
)
version_model.load_info()
version_id = version_model.model_version.id
created['model_version'] = version_id
click.echo(f"done ({version_id[:8]})")
# 5. Model version — reuse existing if model_version_id is in context
existing_version_id = None
try:
existing_version_id = ctx.obj.current.model_version_id
except AttributeError:
pass

# 6. Stale deployment cleanup (from previous crash)
deployment_id = f"local-{model_id}"
if existing_version_id:
version_id = existing_version_id
out.status(f"Model version ready ({version_id[:8]})")
else:
out.status("Creating model version... ", nl=False)
version_model = model.create_version(
pretrained_model_config={"local_dev": True},
method_signatures=method_signatures,
)
version_model.load_info()
version_id = version_model.model_version.id
created['model_version'] = version_id
click.echo(f"done ({version_id[:8]})")

# 6. Resolve deployment_id from context or default
existing_deployment_id = None
try:
nodepool.deployment(deployment_id)
nodepool.delete_deployments([deployment_id])
except Exception:
existing_deployment_id = ctx.obj.current.deployment_id
except AttributeError:
pass
deployment_id = existing_deployment_id or f"local-{model_id}"

# 7. Runner (always created fresh — always cleaned up)
worker = {
"model": {
"id": model_id,
Expand All @@ -2020,43 +2029,80 @@ def serve_cmd(ctx, model_path, grpc, mode, port, concurrency, keep_image, verbos
"app_id": app_id,
}
}
out.status("Creating runner... ", nl=False)
runner = nodepool.create_runner(
runner_config={
"runner": {
"description": f"local runner for {model_id}",
"worker": worker,
"num_replicas": 1,
}
}
)
runner_id = runner.id
created['runner'] = runner_id
click.echo("done")

# 8. Deployment (always created fresh — always cleaned up)
out.status("Creating deployment... ", nl=False)
nodepool.create_deployment(
deployment_id=deployment_id,
deployment_config={
"deployment": {
"scheduling_choice": 3,
"worker": worker,
"nodepools": [
{
"id": np_id,
"compute_cluster": {
"id": cc_id,
"user_id": user_id,
},
# 7. Runner — reuse existing if deployment already exists in context
if existing_deployment_id:
# Existing deployment implies an existing runner; find it
try:
runners = nodepool.list_runners()
runner_id = None
for r in runners:
runner_id = r.id
break
if not runner_id:
raise Exception("No runners found")
out.status(f"Runner ready ({runner_id[:8]})")
except Exception:
# Fallback: create a new runner
out.status("Creating runner... ", nl=False)
runner = nodepool.create_runner(
runner_config={
"runner": {
"description": f"local runner for {model_id}",
"worker": worker,
"num_replicas": 1,
}
],
"deploy_latest_version": True,
}
)
runner_id = runner.id
created['runner'] = runner_id
click.echo("done")
out.status(f"Deployment ready ({deployment_id})")
else:
# Stale deployment cleanup (from previous crash)
try:
nodepool.deployment(deployment_id)
nodepool.delete_deployments([deployment_id])
except Exception:
pass

out.status("Creating runner... ", nl=False)
runner = nodepool.create_runner(
runner_config={
"runner": {
"description": f"local runner for {model_id}",
"worker": worker,
"num_replicas": 1,
}
}
},
)
created['deployment'] = deployment_id
click.echo("done")
)
runner_id = runner.id
created['runner'] = runner_id
click.echo("done")

# 8. Deployment (always created fresh — always cleaned up)
out.status("Creating deployment... ", nl=False)
nodepool.create_deployment(
deployment_id=deployment_id,
deployment_config={
"deployment": {
"scheduling_choice": 3,
"worker": worker,
"nodepools": [
{
"id": np_id,
"compute_cluster": {
"id": cc_id,
"user_id": user_id,
},
}
],
"deploy_latest_version": True,
}
},
)
created['deployment'] = deployment_id
click.echo("done")

# Toolkit customization (before serving)
if config.get('toolkit', {}).get('provider') == 'ollama':
Expand Down Expand Up @@ -2089,35 +2135,35 @@ def serve_cmd(ctx, model_path, grpc, mode, port, concurrency, keep_image, verbos

def _cleanup():
out.phase_header("Stopping")
with _quiet_sdk_logger(suppress):
if 'deployment' in created:
out.status("Deleting deployment... ", nl=False)
try:
nodepool.delete_deployments([created['deployment']])
click.echo("done")
except Exception:
click.echo("failed")
if 'runner' in created:
out.status("Deleting runner... ", nl=False)
try:
nodepool.delete_runners([created['runner']])
click.echo("done")
except Exception:
click.echo("failed")
if 'model_version' in created:
out.status("Deleting model version... ", nl=False)
try:
model.delete_version(version_id=created['model_version'])
click.echo("done")
except Exception:
click.echo("failed")
if 'model' in created:
out.status("Deleting model... ", nl=False)
try:
app.delete_model(created['model'])
click.echo("done")
except Exception:
click.echo("failed")
# with _quiet_sdk_logger(suppress):
# if 'deployment' in created:
# out.status("Deleting deployment... ", nl=False)
# try:
# nodepool.delete_deployments([created['deployment']])
# click.echo("done")
# except Exception:
# click.echo("failed")
# if 'runner' in created:
# out.status("Deleting runner... ", nl=False)
# try:
# nodepool.delete_runners([created['runner']])
# click.echo("done")
# except Exception:
# click.echo("failed")
# if 'model_version' in created:
# out.status("Deleting model version... ", nl=False)
# try:
# model.delete_version(version_id=created['model_version'])
# click.echo("done")
# except Exception:
# click.echo("failed")
# if 'model' in created:
# out.status("Deleting model... ", nl=False)
# try:
# app.delete_model(created['model'])
# click.echo("done")
# except Exception:
# click.echo("failed")
out.status("Stopped.")

def _do_cleanup():
Expand Down
24 changes: 15 additions & 9 deletions clarifai/runners/models/openai_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,17 @@ def _update_old_fields(self, request_data: Dict[str, Any]) -> Dict[str, Any]:

Note: this updates the request data in place and returns it.
"""
# Sync max_tokens and max_completion_tokens, preferring max_completion_tokens if both exist
max_tokens = request_data.get('max_tokens')
# Normalize to max_completion_tokens (current OpenAI standard).
# Many backends (TRT-LLM, vLLM) reject requests containing both fields,
# so we keep only max_completion_tokens after syncing.
max_tokens = request_data.pop('max_tokens', None)
max_completion_tokens = request_data.get('max_completion_tokens')

if max_completion_tokens is not None and max_tokens is not None:
# Both exist - prefer max_completion_tokens and sync max_tokens to it
request_data['max_tokens'] = max_completion_tokens
elif max_completion_tokens is not None:
# Only max_completion_tokens exists - copy to max_tokens for older backends
request_data['max_tokens'] = max_completion_tokens
if max_completion_tokens is not None:
# max_completion_tokens takes precedence
pass
elif max_tokens is not None:
# Only max_tokens exists - copy to max_completion_tokens for newer backends
# Only max_tokens was provided — promote to max_completion_tokens
request_data['max_completion_tokens'] = max_tokens
if 'top_p' in request_data:
request_data['top_p'] = float(request_data['top_p'])
Expand All @@ -274,6 +273,11 @@ def _update_old_fields(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
m.pop('file', None)
m.pop('panelId', None)

if 'logprobs' in request_data:
request_data.pop('logprobs')
if 'top_logprobs' in request_data:
request_data.pop('top_logprobs')

# Handle the "Currently only named tools are supported." error we see from trt-llm
if 'tools' in request_data and request_data['tools'] is None:
request_data.pop('tools', None)
Expand All @@ -293,6 +297,7 @@ def openai_transport(self, msg: str) -> str:
JSON string containing the response or error
"""
try:
logger.info("openai non-streaming request started...")
request_data = from_json(msg)
request_data = self._update_old_fields(request_data)
endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT)
Expand Down Expand Up @@ -320,6 +325,7 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]:
Iterator[str]: An iterator yielding text chunks from the streaming response.
"""
try:
logger.info("openai streaming request started...")
request_data = from_json(msg)
request_data = self._update_old_fields(request_data)
endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT)
Expand Down
21 changes: 7 additions & 14 deletions tests/runners/test_openai_update_old_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,34 @@ def setup_method(self):
self.model = DummyOpenAIModel()

def test_max_tokens_only(self):
"""Test that max_tokens is copied to max_completion_tokens when only max_tokens exists."""
"""Test that max_tokens is promoted to max_completion_tokens and removed."""
request_data = {"max_tokens": 100}
result = self.model._update_old_fields(request_data)

# Both fields should now be set
assert result.get("max_tokens") == 100
assert result.get("max_completion_tokens") == 100
assert "max_tokens" not in result

def test_max_completion_tokens_only(self):
"""Test that max_completion_tokens is copied to max_tokens when only max_completion_tokens exists."""
"""Test that max_completion_tokens is kept as-is, no max_tokens added."""
request_data = {"max_completion_tokens": 200}
result = self.model._update_old_fields(request_data)

# Both fields should now be set
assert result.get("max_tokens") == 200
assert result.get("max_completion_tokens") == 200
assert "max_tokens" not in result

def test_both_fields_present(self):
"""Test that max_completion_tokens is preferred when both fields are present."""
request_data = {"max_tokens": 100, "max_completion_tokens": 200}
result = self.model._update_old_fields(request_data)

# max_completion_tokens should be kept and max_tokens should be synced to it
assert result.get("max_completion_tokens") == 200
assert result.get("max_tokens") == 200
assert "max_tokens" not in result

def test_neither_field_present(self):
"""Test that no changes are made when neither field is present."""
request_data = {"temperature": 0.7}
result = self.model._update_old_fields(request_data)

# No max_tokens fields should be added
assert "max_tokens" not in result
assert "max_completion_tokens" not in result
assert result.get("temperature") == 0.7
Expand All @@ -52,9 +48,8 @@ def test_zero_values(self):
request_data = {"max_tokens": 0}
result = self.model._update_old_fields(request_data)

# Zero is a valid value and should be synced
assert result.get("max_tokens") == 0
assert result.get("max_completion_tokens") == 0
assert "max_tokens" not in result

def test_other_fields_unchanged(self):
"""Test that other fields in request_data are not affected."""
Expand All @@ -66,10 +61,8 @@ def test_other_fields_unchanged(self):
}
result = self.model._update_old_fields(request_data)

# Check that other fields are preserved
assert result.get("temperature") == 0.7
assert result.get("top_p") == 0.9
assert result.get("model") == "gpt-4"
# Check that syncing happened
assert result.get("max_tokens") == 100
assert result.get("max_completion_tokens") == 100
assert "max_tokens" not in result
Loading