Skip to content

Commit 6bef10a

Browse files
authored
Simplify configure (#495)
* refactor: unify reasoning budget handling and command flags
* replace reasoning budget with think flag for model configuration
* refactor: change Think flag to pointer for nil detection in model configuration
1 parent cdc3203 commit 6bef10a

File tree

9 files changed

+357
-413
lines changed

9 files changed

+357
-413
lines changed

cmd/cli/commands/compose.go

Lines changed: 7 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,13 @@ func newComposeCmd() *cobra.Command {
3434
return c
3535
}
3636

37-
// Reasoning budget constants for the think parameter conversion
38-
const (
39-
reasoningBudgetUnlimited int32 = -1
40-
reasoningBudgetDisabled int32 = 0
41-
reasoningBudgetMedium int32 = 1024
42-
reasoningBudgetLow int32 = 256
43-
)
44-
45-
// ptr is a helper function to create a pointer to int32
46-
func ptr(v int32) *int32 {
47-
return &v
48-
}
49-
5037
func newUpCommand() *cobra.Command {
5138
var models []string
5239
var ctxSize int64
5340
var backend string
5441
var draftModel string
5542
var numTokens int
5643
var minAcceptanceRate float64
57-
var mode string
58-
var think string
5944
c := &cobra.Command{
6045
Use: "up",
6146
RunE: func(cmd *cobra.Command, args []string) error {
@@ -81,7 +66,7 @@ func newUpCommand() *cobra.Command {
8166
return err
8267
}
8368

84-
if cmd.Flags().Changed("context-size") {
69+
if ctxSize > 0 {
8570
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
8671
}
8772

@@ -96,52 +81,14 @@ func newUpCommand() *cobra.Command {
9681
sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel))
9782
}
9883

99-
// Parse mode if provided
100-
var backendMode *inference.BackendMode
101-
if mode != "" {
102-
parsedMode, err := parseBackendMode(mode)
103-
if err != nil {
104-
_ = sendError(err.Error())
105-
return err
106-
}
107-
backendMode = &parsedMode
108-
sendInfo(fmt.Sprintf("Setting backend mode to %s", mode))
109-
}
110-
111-
// Parse think parameter for reasoning budget
112-
var reasoningBudget *int32
113-
if think != "" {
114-
budget, err := parseThinkToReasoningBudget(think)
115-
if err != nil {
116-
_ = sendError(err.Error())
117-
return err
118-
}
119-
reasoningBudget = budget
120-
sendInfo(fmt.Sprintf("Setting think mode to %s", think))
121-
}
122-
12384
for _, model := range models {
124-
configuration := inference.BackendConfiguration{
125-
Speculative: speculativeConfig,
126-
}
127-
if cmd.Flags().Changed("context-size") {
128-
// TODO is the context size the same for all models?
129-
v := int32(ctxSize)
130-
configuration.ContextSize = &v
131-
}
132-
133-
// Set llama.cpp-specific reasoning budget if provided
134-
if reasoningBudget != nil {
135-
if configuration.LlamaCpp == nil {
136-
configuration.LlamaCpp = &inference.LlamaCppConfig{}
137-
}
138-
configuration.LlamaCpp.ReasoningBudget = reasoningBudget
139-
}
140-
85+
size := int32(ctxSize)
14186
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
142-
Model: model,
143-
Mode: backendMode,
144-
BackendConfiguration: configuration,
87+
Model: model,
88+
BackendConfiguration: inference.BackendConfiguration{
89+
ContextSize: &size,
90+
Speculative: speculativeConfig,
91+
},
14592
}); err != nil {
14693
configErrFmtString := "failed to configure backend for model %s with context-size %d"
14794
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
@@ -171,57 +118,10 @@ func newUpCommand() *cobra.Command {
171118
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
172119
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
173120
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
174-
c.Flags().StringVar(&mode, "mode", "", "backend operation mode (completion, embedding, reranking)")
175-
c.Flags().StringVar(&think, "think", "", "enable reasoning mode for thinking models (true/false/high/medium/low)")
176121
_ = c.MarkFlagRequired("model")
177122
return c
178123
}
179124

180-
// parseBackendMode parses a string mode value into an inference.BackendMode.
181-
func parseBackendMode(mode string) (inference.BackendMode, error) {
182-
switch strings.ToLower(mode) {
183-
case "completion":
184-
return inference.BackendModeCompletion, nil
185-
case "embedding":
186-
return inference.BackendModeEmbedding, nil
187-
case "reranking":
188-
return inference.BackendModeReranking, nil
189-
default:
190-
return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode)
191-
}
192-
}
193-
194-
// parseThinkToReasoningBudget converts the think parameter string to a reasoning budget value.
195-
// Accepts: "true", "false", "high", "medium", "low"
196-
// Returns:
197-
// - nil for empty string or "true" (use server default, which is unlimited)
198-
// - -1 for "high" (explicitly set unlimited)
199-
// - 0 for "false" (disable thinking)
200-
// - 1024 for "medium"
201-
// - 256 for "low"
202-
func parseThinkToReasoningBudget(think string) (*int32, error) {
203-
if think == "" {
204-
return nil, nil
205-
}
206-
207-
switch strings.ToLower(think) {
208-
case "true":
209-
// Use nil to let the server use its default (currently unlimited)
210-
return nil, nil
211-
case "high":
212-
// Explicitly set unlimited reasoning budget
213-
return ptr(reasoningBudgetUnlimited), nil
214-
case "false":
215-
return ptr(reasoningBudgetDisabled), nil
216-
case "medium":
217-
return ptr(reasoningBudgetMedium), nil
218-
case "low":
219-
return ptr(reasoningBudgetLow), nil
220-
default:
221-
return nil, fmt.Errorf("invalid think value %q: must be one of true, false, high, medium, low", think)
222-
}
223-
}
224-
225125
func newDownCommand() *cobra.Command {
226126
c := &cobra.Command{
227127
Use: "down",

cmd/cli/commands/compose_test.go

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -71,84 +71,3 @@ func TestParseBackendMode(t *testing.T) {
7171
})
7272
}
7373
}
74-
75-
func TestParseThinkToReasoningBudget(t *testing.T) {
76-
tests := []struct {
77-
name string
78-
input string
79-
expected *int32
80-
expectError bool
81-
}{
82-
{
83-
name: "empty string returns nil",
84-
input: "",
85-
expected: nil,
86-
expectError: false,
87-
},
88-
{
89-
name: "true returns nil (use server default)",
90-
input: "true",
91-
expected: nil,
92-
expectError: false,
93-
},
94-
{
95-
name: "TRUE returns nil (case insensitive)",
96-
input: "TRUE",
97-
expected: nil,
98-
expectError: false,
99-
},
100-
{
101-
name: "false disables reasoning",
102-
input: "false",
103-
expected: ptr(reasoningBudgetDisabled),
104-
expectError: false,
105-
},
106-
{
107-
name: "high explicitly sets unlimited (-1)",
108-
input: "high",
109-
expected: ptr(reasoningBudgetUnlimited),
110-
expectError: false,
111-
},
112-
{
113-
name: "medium sets 1024 tokens",
114-
input: "medium",
115-
expected: ptr(reasoningBudgetMedium),
116-
expectError: false,
117-
},
118-
{
119-
name: "low sets 256 tokens",
120-
input: "low",
121-
expected: ptr(reasoningBudgetLow),
122-
expectError: false,
123-
},
124-
{
125-
name: "invalid value returns error",
126-
input: "invalid",
127-
expected: nil,
128-
expectError: true,
129-
},
130-
{
131-
name: "numeric string returns error",
132-
input: "1024",
133-
expected: nil,
134-
expectError: true,
135-
},
136-
}
137-
138-
for _, tt := range tests {
139-
t.Run(tt.name, func(t *testing.T) {
140-
result, err := parseThinkToReasoningBudget(tt.input)
141-
if tt.expectError {
142-
require.Error(t, err)
143-
} else {
144-
require.NoError(t, err)
145-
if tt.expected == nil {
146-
assert.Nil(t, result)
147-
} else {
148-
require.NotNil(t, result)
149-
assert.Equal(t, *tt.expected, *result)
150-
}
151-
}
152-
})
153-
}
154-
}

cmd/cli/commands/configure.go

Lines changed: 7 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,17 @@
11
package commands
22

33
import (
4-
"encoding/json"
54
"fmt"
6-
"strconv"
75

86
"github.com/docker/model-runner/cmd/cli/commands/completion"
9-
"github.com/docker/model-runner/pkg/inference"
10-
11-
"github.com/docker/model-runner/pkg/inference/scheduling"
127
"github.com/spf13/cobra"
138
)
149

15-
// Int32PtrValue implements pflag.Value interface for *int32 pointers
16-
// This allows flags to have a nil default value instead of 0
17-
type Int32PtrValue struct {
18-
ptr **int32
19-
}
20-
21-
func NewInt32PtrValue(p **int32) *Int32PtrValue {
22-
return &Int32PtrValue{ptr: p}
23-
}
24-
25-
func (v *Int32PtrValue) String() string {
26-
if v.ptr == nil || *v.ptr == nil {
27-
return ""
28-
}
29-
return strconv.FormatInt(int64(**v.ptr), 10)
30-
}
31-
32-
func (v *Int32PtrValue) Set(s string) error {
33-
val, err := strconv.ParseInt(s, 10, 32)
34-
if err != nil {
35-
return err
36-
}
37-
i32 := int32(val)
38-
*v.ptr = &i32
39-
return nil
40-
}
41-
42-
func (v *Int32PtrValue) Type() string {
43-
return "int32"
44-
}
45-
4610
func newConfigureCmd() *cobra.Command {
47-
var opts scheduling.ConfigureRequest
48-
var draftModel string
49-
var numTokens int
50-
var minAcceptanceRate float64
51-
var hfOverrides string
52-
var contextSize *int32
53-
var reasoningBudget *int32
11+
var flags ConfigureFlags
5412

5513
c := &cobra.Command{
56-
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--reasoning-budget=<n>] MODEL",
14+
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--mode=<mode>] [--think] MODEL",
5715
Short: "Configure runtime options for a model",
5816
Hidden: true,
5917
Args: func(cmd *cobra.Command, args []string) error {
@@ -63,53 +21,19 @@ func newConfigureCmd() *cobra.Command {
6321
"See 'docker model configure --help' for more information",
6422
len(args), args)
6523
}
66-
opts.Model = args[0]
6724
return nil
6825
},
6926
RunE: func(cmd *cobra.Command, args []string) error {
70-
// contextSize is nil by default, only set if user provided the flag
71-
opts.ContextSize = contextSize
72-
// Build the speculative config if any speculative flags are set
73-
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
74-
opts.Speculative = &inference.SpeculativeDecodingConfig{
75-
DraftModel: draftModel,
76-
NumTokens: numTokens,
77-
MinAcceptanceRate: minAcceptanceRate,
78-
}
79-
}
80-
// Parse and validate HuggingFace overrides if provided (vLLM-specific)
81-
if hfOverrides != "" {
82-
var hfo inference.HFOverrides
83-
if err := json.Unmarshal([]byte(hfOverrides), &hfo); err != nil {
84-
return fmt.Errorf("invalid --hf_overrides JSON: %w", err)
85-
}
86-
// Validate the overrides to prevent command injection
87-
if err := hfo.Validate(); err != nil {
88-
return err
89-
}
90-
if opts.VLLM == nil {
91-
opts.VLLM = &inference.VLLMConfig{}
92-
}
93-
opts.VLLM.HFOverrides = hfo
94-
}
95-
// Set llama.cpp-specific reasoning budget if provided
96-
// reasoningBudget is nil by default, only set if user provided the flag
97-
if reasoningBudget != nil {
98-
if opts.LlamaCpp == nil {
99-
opts.LlamaCpp = &inference.LlamaCppConfig{}
100-
}
101-
opts.LlamaCpp.ReasoningBudget = reasoningBudget
27+
model := args[0]
28+
opts, err := flags.BuildConfigureRequest(model)
29+
if err != nil {
30+
return err
10231
}
10332
return desktopClient.ConfigureBackend(opts)
10433
},
10534
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
10635
}
10736

108-
c.Flags().Var(NewInt32PtrValue(&contextSize), "context-size", "context size (in tokens)")
109-
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
110-
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
111-
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
112-
c.Flags().StringVar(&hfOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
113-
c.Flags().Var(NewInt32PtrValue(&reasoningBudget), "reasoning-budget", "reasoning budget for reasoning models - llama.cpp only")
37+
flags.RegisterFlags(c)
11438
return c
11539
}

0 commit comments

Comments (0)