Skip to content

Commit b4512b5

Browse files
authored
Change context size and reasoning budget types from int64/uint64 to int32 (#487)
* fix: change context size and reasoning budget types from uint64 to int32
* fix: change context size and reasoning budget types from int64 to int32
1 parent b2b0643 commit b4512b5

File tree

25 files changed

+236
-410
lines changed

25 files changed

+236
-410
lines changed

cmd/cli/commands/compose.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func newUpCommand() *cobra.Command {
6666
return err
6767
}
6868

69-
if ctxSize > 0 {
69+
if cmd.Flags().Changed("context-size") {
7070
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
7171
}
7272

@@ -82,12 +82,18 @@ func newUpCommand() *cobra.Command {
8282
}
8383

8484
for _, model := range models {
85+
configuration := inference.BackendConfiguration{
86+
Speculative: speculativeConfig,
87+
}
88+
if cmd.Flags().Changed("context-size") {
89+
// TODO is the context size the same for all models?
90+
v := int32(ctxSize)
91+
configuration.ContextSize = &v
92+
}
93+
8594
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
86-
Model: model,
87-
BackendConfiguration: inference.BackendConfiguration{
88-
ContextSize: ctxSize,
89-
Speculative: speculativeConfig,
90-
},
95+
Model: model,
96+
BackendConfiguration: configuration,
9197
}); err != nil {
9298
configErrFmtString := "failed to configure backend for model %s with context-size %d"
9399
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)

cmd/cli/commands/configure.go

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package commands
33
import (
44
"encoding/json"
55
"fmt"
6+
"strconv"
67

78
"github.com/docker/model-runner/cmd/cli/commands/completion"
89
"github.com/docker/model-runner/pkg/inference"
@@ -11,13 +12,45 @@ import (
1112
"github.com/spf13/cobra"
1213
)
1314

15+
// Int32PtrValue adapts a *int32 variable to the pflag.Value interface
// (String, Set, Type). Because the destination starts out nil and is only
// assigned by Set, callers can tell "flag not provided" (nil) apart from an
// explicit value such as 0 or -1.
type Int32PtrValue struct {
	// ptr is the address of the caller's *int32. Set stores a newly allocated
	// int32 through it; it is nil when the value is not bound to a destination.
	ptr **int32
}

// NewInt32PtrValue returns an Int32PtrValue bound to the *int32 variable that
// p points at. The variable is left untouched until Set succeeds.
func NewInt32PtrValue(p **int32) *Int32PtrValue {
	return &Int32PtrValue{ptr: p}
}

// String returns the current value in base 10, or the empty string when the
// value is not bound to a destination or the destination is still nil.
func (v *Int32PtrValue) String() string {
	if v.ptr == nil || *v.ptr == nil {
		return ""
	}
	return strconv.FormatInt(int64(**v.ptr), 10)
}

// Set parses s as a base-10 integer that fits in 32 bits and points the
// destination at a new int32 holding the result.
//
// It returns an error, leaving the destination unchanged, when s is not a
// valid integer, when it is out of the int32 range, or when the value has no
// destination to write to.
func (v *Int32PtrValue) Set(s string) error {
	// Mirror the guard in String: an unbound value must fail cleanly rather
	// than panic on the nil dereference below.
	if v.ptr == nil {
		return fmt.Errorf("int32 flag value has no destination")
	}
	val, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return err
	}
	// bitSize 32 above guarantees the conversion cannot truncate.
	i32 := int32(val)
	*v.ptr = &i32
	return nil
}

// Type returns the type name reported for flags backed by this value.
func (v *Int32PtrValue) Type() string {
	return "int32"
}
}
45+
1446
func newConfigureCmd() *cobra.Command {
1547
var opts scheduling.ConfigureRequest
1648
var draftModel string
1749
var numTokens int
1850
var minAcceptanceRate float64
1951
var hfOverrides string
20-
var reasoningBudget int64
52+
var contextSize *int32
53+
var reasoningBudget *int32
2154

2255
c := &cobra.Command{
2356
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--reasoning-budget=<n>] MODEL",
@@ -34,6 +67,8 @@ func newConfigureCmd() *cobra.Command {
3467
return nil
3568
},
3669
RunE: func(cmd *cobra.Command, args []string) error {
70+
// contextSize is nil by default, only set if user provided the flag
71+
opts.ContextSize = contextSize
3772
// Build the speculative config if any speculative flags are set
3873
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
3974
opts.Speculative = &inference.SpeculativeDecodingConfig{
@@ -57,25 +92,24 @@ func newConfigureCmd() *cobra.Command {
5792
}
5893
opts.VLLM.HFOverrides = hfo
5994
}
60-
// Set llama.cpp-specific reasoning budget if explicitly provided
61-
// Note: We check if flag was changed rather than checking value > 0
62-
// because 0 is a valid value (disables reasoning) and -1 means unlimited
63-
if cmd.Flags().Changed("reasoning-budget") {
95+
// Set llama.cpp-specific reasoning budget if provided
96+
// reasoningBudget is nil by default, only set if user provided the flag
97+
if reasoningBudget != nil {
6498
if opts.LlamaCpp == nil {
6599
opts.LlamaCpp = &inference.LlamaCppConfig{}
66100
}
67-
opts.LlamaCpp.ReasoningBudget = &reasoningBudget
101+
opts.LlamaCpp.ReasoningBudget = reasoningBudget
68102
}
69103
return desktopClient.ConfigureBackend(opts)
70104
},
71105
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
72106
}
73107

74-
c.Flags().Int64Var(&opts.ContextSize, "context-size", -1, "context size (in tokens)")
108+
c.Flags().Var(NewInt32PtrValue(&contextSize), "context-size", "context size (in tokens)")
75109
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
76110
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
77111
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
78112
c.Flags().StringVar(&hfOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
79-
c.Flags().Int64Var(&reasoningBudget, "reasoning-budget", 0, "reasoning budget for reasoning models - llama.cpp only")
113+
c.Flags().Var(NewInt32PtrValue(&reasoningBudget), "reasoning-budget", "reasoning budget for reasoning models - llama.cpp only")
80114
return c
81115
}

cmd/cli/commands/configure_test.go

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ func TestConfigureCmdReasoningBudgetFlag(t *testing.T) {
1414
t.Fatal("--reasoning-budget flag not found")
1515
}
1616

17-
// Verify the default value is 0
18-
if reasoningBudgetFlag.DefValue != "0" {
19-
t.Errorf("Expected default reasoning-budget value to be '0', got '%s'", reasoningBudgetFlag.DefValue)
17+
// Verify the default value is empty (nil pointer)
18+
if reasoningBudgetFlag.DefValue != "" {
19+
t.Errorf("Expected default reasoning-budget value to be '' (nil), got '%s'", reasoningBudgetFlag.DefValue)
2020
}
2121

2222
// Verify the flag type
23-
if reasoningBudgetFlag.Value.Type() != "int64" {
24-
t.Errorf("Expected reasoning-budget flag type to be 'int64', got '%s'", reasoningBudgetFlag.Value.Type())
23+
if reasoningBudgetFlag.Value.Type() != "int32" {
24+
t.Errorf("Expected reasoning-budget flag type to be 'int32', got '%s'", reasoningBudgetFlag.Value.Type())
2525
}
2626
}
2727

@@ -30,31 +30,31 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
3030
name string
3131
setValue string
3232
expectChanged bool
33-
expectedValue int64
33+
expectedValue string
3434
}{
3535
{
3636
name: "flag not set - should not be changed",
3737
setValue: "",
3838
expectChanged: false,
39-
expectedValue: 0,
39+
expectedValue: "",
4040
},
4141
{
4242
name: "flag set to 0 (disable reasoning) - should be changed",
4343
setValue: "0",
4444
expectChanged: true,
45-
expectedValue: 0,
45+
expectedValue: "0",
4646
},
4747
{
4848
name: "flag set to -1 (unlimited) - should be changed",
4949
setValue: "-1",
5050
expectChanged: true,
51-
expectedValue: -1,
51+
expectedValue: "-1",
5252
},
5353
{
5454
name: "flag set to positive value - should be changed",
5555
setValue: "1024",
5656
expectChanged: true,
57-
expectedValue: 1024,
57+
expectedValue: "1024",
5858
},
5959
}
6060

@@ -77,13 +77,11 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
7777
t.Errorf("Expected Changed() = %v, got %v", tt.expectChanged, isChanged)
7878
}
7979

80-
// Verify the value
81-
value, err := cmd.Flags().GetInt64("reasoning-budget")
82-
if err != nil {
83-
t.Fatalf("Failed to get reasoning-budget flag value: %v", err)
84-
}
80+
// Verify the value using String() method
81+
flag := cmd.Flags().Lookup("reasoning-budget")
82+
value := flag.Value.String()
8583
if value != tt.expectedValue {
86-
t.Errorf("Expected value = %d, got %d", tt.expectedValue, value)
84+
t.Errorf("Expected value = %s, got %s", tt.expectedValue, value)
8785
}
8886
})
8987
}
@@ -120,9 +118,9 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
120118
t.Fatal("--context-size flag not found")
121119
}
122120

123-
// Verify the default value is -1 (indicating not set)
124-
if contextSizeFlag.DefValue != "-1" {
125-
t.Errorf("Expected default context-size value to be '-1', got '%s'", contextSizeFlag.DefValue)
121+
// Verify the default value is empty (nil pointer)
122+
if contextSizeFlag.DefValue != "" {
123+
t.Errorf("Expected default context-size value to be '' (nil), got '%s'", contextSizeFlag.DefValue)
126124
}
127125

128126
// Test setting the flag value
@@ -131,14 +129,10 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
131129
t.Errorf("Failed to set context-size flag: %v", err)
132130
}
133131

134-
// Verify the value was set
135-
contextSizeValue, err := cmd.Flags().GetInt64("context-size")
136-
if err != nil {
137-
t.Errorf("Failed to get context-size flag value: %v", err)
138-
}
139-
140-
if contextSizeValue != 8192 {
141-
t.Errorf("Expected context-size flag value to be 8192, got %d", contextSizeValue)
132+
// Verify the value was set using String() method
133+
contextSizeValue := contextSizeFlag.Value.String()
134+
if contextSizeValue != "8192" {
135+
t.Errorf("Expected context-size flag value to be '8192', got '%s'", contextSizeValue)
142136
}
143137
}
144138

cmd/cli/commands/integration_test.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func verifyModelInspect(t *testing.T, client *desktop.Client, ref, expectedID, e
216216

217217
// createAndPushTestModel creates a minimal test model and pushes it to the local registry.
218218
// Returns the model ID, FQDNs for host and network access, and the manifest digest.
219-
func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize uint64) (modelID, hostFQDN, networkFQDN, digest string) {
219+
func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize *int32) (modelID, hostFQDN, networkFQDN, digest string) {
220220
ctx := context.Background()
221221

222222
// Use the dummy GGUF file from assets
@@ -234,8 +234,8 @@ func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextS
234234
require.NoError(t, err)
235235

236236
// Set context size if specified
237-
if contextSize > 0 {
238-
pkg = pkg.WithContextSize(contextSize)
237+
if contextSize != nil {
238+
pkg = pkg.WithContextSize(*contextSize)
239239
}
240240

241241
// Construct the full reference with the local registry host for pushing from test host
@@ -287,7 +287,7 @@ func TestIntegration_PullModel(t *testing.T) {
287287
// Create and push two test models with different organizations
288288
// Model 1: custom org (test/test-model:latest)
289289
modelRef1 := "test/test-model:latest"
290-
modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
290+
modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
291291
t.Logf("Test model 1 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN1, modelID1, networkFQDN1, digest1)
292292

293293
// Generate test cases for custom org model (test/test-model)
@@ -304,7 +304,7 @@ func TestIntegration_PullModel(t *testing.T) {
304304

305305
// Model 2: default org (ai/test-model:latest)
306306
modelRef2 := "ai/test-model:latest"
307-
modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
307+
modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))
308308
t.Logf("Test model 2 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN2, modelID2, networkFQDN2, digest2)
309309

310310
// Generate test cases for default org model (ai/test-model)
@@ -420,7 +420,7 @@ func TestIntegration_InspectModel(t *testing.T) {
420420

421421
// Create and push a test model with default org (ai/inspect-test:latest)
422422
modelRef := "ai/inspect-test:latest"
423-
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
423+
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
424424
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)
425425

426426
// Pull the model using a short reference
@@ -479,7 +479,7 @@ func TestIntegration_TagModel(t *testing.T) {
479479

480480
// Create and push a test model with default org (ai/tag-test:latest)
481481
modelRef := "ai/tag-test:latest"
482-
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
482+
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
483483
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)
484484

485485
// Pull the model using a simple reference
@@ -657,7 +657,7 @@ func TestIntegration_PushModel(t *testing.T) {
657657

658658
// Create and push a test model with default org (ai/tag-test:latest)
659659
modelRef := "ai/tag-test:latest"
660-
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
660+
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
661661
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)
662662

663663
// Pull the model using a simple reference
@@ -791,7 +791,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
791791

792792
// Create and push a test model with default org (ai/rm-test:latest)
793793
modelRef := "ai/rm-test:latest"
794-
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
794+
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
795795
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)
796796

797797
// Generate all reference test cases
@@ -842,9 +842,9 @@ func TestIntegration_RemoveModel(t *testing.T) {
842842
t.Run("remove multiple models", func(t *testing.T) {
843843
// Create and push two different models
844844
modelRef1 := "ai/rm-multi-1:latest"
845-
modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
845+
modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
846846
modelRef2 := "ai/rm-multi-2:latest"
847-
modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
847+
modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))
848848

849849
// Pull both models
850850
t.Logf("Pulling first model: rm-multi-1")
@@ -1014,3 +1014,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
10141014
})
10151015
})
10161016
}
1017+
1018+
// int32ptr returns a pointer to its own copy of n, so call sites can pass a
// literal where a *int32 is required (for example the contextSize argument of
// createAndPushTestModel).
func int32ptr(n int32) *int32 {
1019+
return &n
1020+
}

cmd/cli/commands/package.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,9 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
284284
distClient := initResult.distClient
285285

286286
// Set context size
287-
if opts.contextSize > 0 {
287+
if cmd.Flags().Changed("context-size") {
288288
cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
289-
pkg = pkg.WithContextSize(opts.contextSize)
289+
pkg = pkg.WithContextSize(int32(opts.contextSize))
290290
}
291291

292292
// Add license files

cmd/cli/docs/reference/docker_model_configure.yaml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ pname: docker model
66
plink: docker_model.yaml
77
options:
88
- option: context-size
9-
value_type: int64
10-
default_value: "-1"
9+
value_type: int32
1110
description: context size (in tokens)
1211
deprecated: false
1312
hidden: false
@@ -25,8 +24,7 @@ options:
2524
kubernetes: false
2625
swarm: false
2726
- option: reasoning-budget
28-
value_type: int64
29-
default_value: "0"
27+
value_type: int32
3028
description: reasoning budget for reasoning models - llama.cpp only
3129
deprecated: false
3230
hidden: false

cmd/mdltool/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func cmdPackage(args []string) int {
321321

322322
if contextSize > 0 {
323323
fmt.Println("Setting context size:", contextSize)
324-
b = b.WithContextSize(contextSize)
324+
b = b.WithContextSize(int32(contextSize))
325325
}
326326

327327
if mmproj != "" {

pkg/distribution/builder/builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (b *Builder) WithLicense(path string) (*Builder, error) {
6767
}, nil
6868
}
6969

70-
func (b *Builder) WithContextSize(size uint64) *Builder {
70+
func (b *Builder) WithContextSize(size int32) *Builder {
7171
return &Builder{
7272
model: mutate.ContextSize(b.model, size),
7373
originalLayers: b.originalLayers,

pkg/distribution/internal/mutate/model.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type model struct {
1616
base types.ModelArtifact
1717
appended []v1.Layer
1818
configMediaType ggcr.MediaType
19-
contextSize *uint64
19+
contextSize *int32
2020
}
2121

2222
func (m *model) Descriptor() (types.Descriptor, error) {

pkg/distribution/internal/mutate/mutate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArti
2121
}
2222
}
2323

24-
func ContextSize(mdl types.ModelArtifact, cs uint64) types.ModelArtifact {
24+
func ContextSize(mdl types.ModelArtifact, cs int32) types.ModelArtifact {
2525
return &model{
2626
base: mdl,
2727
contextSize: &cs,

0 commit comments

Comments
 (0)