Skip to content

Commit 8740d23

Browse files
authored
Compose mode and think flags (#492)
* feat: add reasoning budget handling and backend mode parsing to compose command * add docs * feat: add helper function for creating int32 pointers and refactor reasoning budget assignment
1 parent b4512b5 commit 8740d23

File tree

3 files changed

+268
-0
lines changed

3 files changed

+268
-0
lines changed

cmd/cli/commands/compose.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,28 @@ func newComposeCmd() *cobra.Command {
3434
return c
3535
}
3636

37+
// Reasoning budget constants for the think parameter conversion
38+
const (
39+
reasoningBudgetUnlimited int32 = -1
40+
reasoningBudgetDisabled int32 = 0
41+
reasoningBudgetMedium int32 = 1024
42+
reasoningBudgetLow int32 = 256
43+
)
44+
45+
// ptr is a helper function to create a pointer to int32
46+
func ptr(v int32) *int32 {
47+
return &v
48+
}
49+
3750
func newUpCommand() *cobra.Command {
3851
var models []string
3952
var ctxSize int64
4053
var backend string
4154
var draftModel string
4255
var numTokens int
4356
var minAcceptanceRate float64
57+
var mode string
58+
var think string
4459
c := &cobra.Command{
4560
Use: "up",
4661
RunE: func(cmd *cobra.Command, args []string) error {
@@ -81,6 +96,30 @@ func newUpCommand() *cobra.Command {
8196
sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel))
8297
}
8398

99+
// Parse mode if provided
100+
var backendMode *inference.BackendMode
101+
if mode != "" {
102+
parsedMode, err := parseBackendMode(mode)
103+
if err != nil {
104+
_ = sendError(err.Error())
105+
return err
106+
}
107+
backendMode = &parsedMode
108+
sendInfo(fmt.Sprintf("Setting backend mode to %s", mode))
109+
}
110+
111+
// Parse think parameter for reasoning budget
112+
var reasoningBudget *int32
113+
if think != "" {
114+
budget, err := parseThinkToReasoningBudget(think)
115+
if err != nil {
116+
_ = sendError(err.Error())
117+
return err
118+
}
119+
reasoningBudget = budget
120+
sendInfo(fmt.Sprintf("Setting think mode to %s", think))
121+
}
122+
84123
for _, model := range models {
85124
configuration := inference.BackendConfiguration{
86125
Speculative: speculativeConfig,
@@ -91,8 +130,17 @@ func newUpCommand() *cobra.Command {
91130
configuration.ContextSize = &v
92131
}
93132

133+
// Set llama.cpp-specific reasoning budget if provided
134+
if reasoningBudget != nil {
135+
if configuration.LlamaCpp == nil {
136+
configuration.LlamaCpp = &inference.LlamaCppConfig{}
137+
}
138+
configuration.LlamaCpp.ReasoningBudget = reasoningBudget
139+
}
140+
94141
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
95142
Model: model,
143+
Mode: backendMode,
96144
BackendConfiguration: configuration,
97145
}); err != nil {
98146
configErrFmtString := "failed to configure backend for model %s with context-size %d"
@@ -123,10 +171,57 @@ func newUpCommand() *cobra.Command {
123171
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
124172
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
125173
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
174+
c.Flags().StringVar(&mode, "mode", "", "backend operation mode (completion, embedding, reranking)")
175+
c.Flags().StringVar(&think, "think", "", "enable reasoning mode for thinking models (true/false/high/medium/low)")
126176
_ = c.MarkFlagRequired("model")
127177
return c
128178
}
129179

180+
// parseBackendMode parses a string mode value into an inference.BackendMode.
181+
func parseBackendMode(mode string) (inference.BackendMode, error) {
182+
switch strings.ToLower(mode) {
183+
case "completion":
184+
return inference.BackendModeCompletion, nil
185+
case "embedding":
186+
return inference.BackendModeEmbedding, nil
187+
case "reranking":
188+
return inference.BackendModeReranking, nil
189+
default:
190+
return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode)
191+
}
192+
}
193+
194+
// parseThinkToReasoningBudget converts the think parameter string to a reasoning budget value.
195+
// Accepts: "true", "false", "high", "medium", "low"
196+
// Returns:
197+
// - nil for empty string or "true" (use server default, which is unlimited)
198+
// - -1 for "high" (explicitly set unlimited)
199+
// - 0 for "false" (disable thinking)
200+
// - 1024 for "medium"
201+
// - 256 for "low"
202+
func parseThinkToReasoningBudget(think string) (*int32, error) {
203+
if think == "" {
204+
return nil, nil
205+
}
206+
207+
switch strings.ToLower(think) {
208+
case "true":
209+
// Use nil to let the server use its default (currently unlimited)
210+
return nil, nil
211+
case "high":
212+
// Explicitly set unlimited reasoning budget
213+
return ptr(reasoningBudgetUnlimited), nil
214+
case "false":
215+
return ptr(reasoningBudgetDisabled), nil
216+
case "medium":
217+
return ptr(reasoningBudgetMedium), nil
218+
case "low":
219+
return ptr(reasoningBudgetLow), nil
220+
default:
221+
return nil, fmt.Errorf("invalid think value %q: must be one of true, false, high, medium, low", think)
222+
}
223+
}
224+
130225
func newDownCommand() *cobra.Command {
131226
c := &cobra.Command{
132227
Use: "down",

cmd/cli/commands/compose_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
package commands
2+
3+
import (
4+
"testing"
5+
6+
"github.com/docker/model-runner/pkg/inference"
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestParseBackendMode(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
input string
15+
expected inference.BackendMode
16+
expectError bool
17+
}{
18+
{
19+
name: "completion mode lowercase",
20+
input: "completion",
21+
expected: inference.BackendModeCompletion,
22+
expectError: false,
23+
},
24+
{
25+
name: "completion mode uppercase",
26+
input: "COMPLETION",
27+
expected: inference.BackendModeCompletion,
28+
expectError: false,
29+
},
30+
{
31+
name: "completion mode mixed case",
32+
input: "Completion",
33+
expected: inference.BackendModeCompletion,
34+
expectError: false,
35+
},
36+
{
37+
name: "embedding mode",
38+
input: "embedding",
39+
expected: inference.BackendModeEmbedding,
40+
expectError: false,
41+
},
42+
{
43+
name: "reranking mode",
44+
input: "reranking",
45+
expected: inference.BackendModeReranking,
46+
expectError: false,
47+
},
48+
{
49+
name: "invalid mode",
50+
input: "invalid",
51+
expected: inference.BackendModeCompletion, // default on error
52+
expectError: true,
53+
},
54+
{
55+
name: "empty string",
56+
input: "",
57+
expected: inference.BackendModeCompletion, // default on error
58+
expectError: true,
59+
},
60+
}
61+
62+
for _, tt := range tests {
63+
t.Run(tt.name, func(t *testing.T) {
64+
result, err := parseBackendMode(tt.input)
65+
if tt.expectError {
66+
require.Error(t, err)
67+
} else {
68+
require.NoError(t, err)
69+
assert.Equal(t, tt.expected, result)
70+
}
71+
})
72+
}
73+
}
74+
75+
func TestParseThinkToReasoningBudget(t *testing.T) {
76+
tests := []struct {
77+
name string
78+
input string
79+
expected *int32
80+
expectError bool
81+
}{
82+
{
83+
name: "empty string returns nil",
84+
input: "",
85+
expected: nil,
86+
expectError: false,
87+
},
88+
{
89+
name: "true returns nil (use server default)",
90+
input: "true",
91+
expected: nil,
92+
expectError: false,
93+
},
94+
{
95+
name: "TRUE returns nil (case insensitive)",
96+
input: "TRUE",
97+
expected: nil,
98+
expectError: false,
99+
},
100+
{
101+
name: "false disables reasoning",
102+
input: "false",
103+
expected: ptr(reasoningBudgetDisabled),
104+
expectError: false,
105+
},
106+
{
107+
name: "high explicitly sets unlimited (-1)",
108+
input: "high",
109+
expected: ptr(reasoningBudgetUnlimited),
110+
expectError: false,
111+
},
112+
{
113+
name: "medium sets 1024 tokens",
114+
input: "medium",
115+
expected: ptr(reasoningBudgetMedium),
116+
expectError: false,
117+
},
118+
{
119+
name: "low sets 256 tokens",
120+
input: "low",
121+
expected: ptr(reasoningBudgetLow),
122+
expectError: false,
123+
},
124+
{
125+
name: "invalid value returns error",
126+
input: "invalid",
127+
expected: nil,
128+
expectError: true,
129+
},
130+
{
131+
name: "numeric string returns error",
132+
input: "1024",
133+
expected: nil,
134+
expectError: true,
135+
},
136+
}
137+
138+
for _, tt := range tests {
139+
t.Run(tt.name, func(t *testing.T) {
140+
result, err := parseThinkToReasoningBudget(tt.input)
141+
if tt.expectError {
142+
require.Error(t, err)
143+
} else {
144+
require.NoError(t, err)
145+
if tt.expected == nil {
146+
assert.Nil(t, result)
147+
} else {
148+
require.NotNil(t, result)
149+
assert.Equal(t, *tt.expected, *result)
150+
}
151+
}
152+
})
153+
}
154+
}

cmd/cli/docs/reference/docker_model_compose_up.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ options:
2323
experimentalcli: false
2424
kubernetes: false
2525
swarm: false
26+
- option: mode
27+
value_type: string
28+
description: backend operation mode (completion, embedding, reranking)
29+
deprecated: false
30+
hidden: false
31+
experimental: false
32+
experimentalcli: false
33+
kubernetes: false
34+
swarm: false
2635
- option: model
2736
value_type: stringArray
2837
default_value: '[]'
@@ -62,6 +71,16 @@ options:
6271
experimentalcli: false
6372
kubernetes: false
6473
swarm: false
74+
- option: think
75+
value_type: string
76+
description: |
77+
enable reasoning mode for thinking models (true/false/high/medium/low)
78+
deprecated: false
79+
hidden: false
80+
experimental: false
81+
experimentalcli: false
82+
kubernetes: false
83+
swarm: false
6584
inherited_options:
6685
- option: project-name
6786
value_type: string

0 commit comments

Comments
 (0)