diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go
index fbba8cf08..029761a46 100644
--- a/cmd/cli/commands/compose.go
+++ b/cmd/cli/commands/compose.go
@@ -66,7 +66,7 @@ func newUpCommand() *cobra.Command {
 				return err
 			}

-			if ctxSize > 0 {
+			if cmd.Flags().Changed("context-size") {
 				sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
 			}

@@ -82,12 +82,18 @@ func newUpCommand() *cobra.Command {
 			}

 			for _, model := range models {
+				configuration := inference.BackendConfiguration{
+					Speculative: speculativeConfig,
+				}
+				if cmd.Flags().Changed("context-size") {
+					// TODO is the context size the same for all models?
+					v := int32(ctxSize)
+					configuration.ContextSize = &v
+				}
+
 				if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
-					Model: model,
-					BackendConfiguration: inference.BackendConfiguration{
-						ContextSize: ctxSize,
-						Speculative: speculativeConfig,
-					},
+					Model:                model,
+					BackendConfiguration: configuration,
 				}); err != nil {
 					configErrFmtString := "failed to configure backend for model %s with context-size %d"
 					_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
diff --git a/cmd/cli/commands/configure.go b/cmd/cli/commands/configure.go
index 619b0be3b..debdd59eb 100644
--- a/cmd/cli/commands/configure.go
+++ b/cmd/cli/commands/configure.go
@@ -3,6 +3,7 @@ package commands
 import (
 	"encoding/json"
 	"fmt"
+	"strconv"

 	"github.com/docker/model-runner/cmd/cli/commands/completion"
 	"github.com/docker/model-runner/pkg/inference"
@@ -11,13 +12,45 @@ import (
 	"github.com/spf13/cobra"
 )

+// Int32PtrValue implements pflag.Value interface for *int32 pointers
+// This allows flags to have a nil default value instead of 0
+type Int32PtrValue struct {
+	ptr **int32
+}
+
+func NewInt32PtrValue(p **int32) *Int32PtrValue {
+	return &Int32PtrValue{ptr: p}
+}
+
+func (v *Int32PtrValue) String() string {
+	if v.ptr == nil || *v.ptr == nil {
+		return ""
+	}
+	return strconv.FormatInt(int64(**v.ptr), 10)
+}
+
+func (v *Int32PtrValue) Set(s string) error {
+	val, err := strconv.ParseInt(s, 10, 32)
+	if err != nil {
+		return err
+	}
+	i32 := int32(val)
+	*v.ptr = &i32
+	return nil
+}
+
+func (v *Int32PtrValue) Type() string {
+	return "int32"
+}
+
 func newConfigureCmd() *cobra.Command {
 	var opts scheduling.ConfigureRequest
 	var draftModel string
 	var numTokens int
 	var minAcceptanceRate float64
 	var hfOverrides string
-	var reasoningBudget int64
+	var contextSize *int32
+	var reasoningBudget *int32

 	c := &cobra.Command{
 		Use: "configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--reasoning-budget=] MODEL",
@@ -34,6 +67,8 @@ func newConfigureCmd() *cobra.Command {
 			return nil
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			// contextSize is nil by default, only set if user provided the flag
+			opts.ContextSize = contextSize
 			// Build the speculative config if any speculative flags are set
 			if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
 				opts.Speculative = &inference.SpeculativeDecodingConfig{
@@ -57,25 +92,24 @@ func newConfigureCmd() *cobra.Command {
 				}
 				opts.VLLM.HFOverrides = hfo
 			}
-			// Set llama.cpp-specific reasoning budget if explicitly provided
-			// Note: We check if flag was changed rather than checking value > 0
-			// because 0 is a valid value (disables reasoning) and -1 means unlimited
-			if cmd.Flags().Changed("reasoning-budget") {
+			// Set llama.cpp-specific reasoning budget if provided
+			// reasoningBudget is nil by default, only set if user provided the flag
+			if reasoningBudget != nil {
 				if opts.LlamaCpp == nil {
 					opts.LlamaCpp = &inference.LlamaCppConfig{}
 				}
-				opts.LlamaCpp.ReasoningBudget = &reasoningBudget
+				opts.LlamaCpp.ReasoningBudget = reasoningBudget
 			}
 			return desktopClient.ConfigureBackend(opts)
 		},
 		ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
 	}
-	c.Flags().Int64Var(&opts.ContextSize, "context-size", -1, "context size (in tokens)")
+	c.Flags().Var(NewInt32PtrValue(&contextSize), "context-size", "context size (in tokens)")
 	c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
 	c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
 	c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
 	c.Flags().StringVar(&hfOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
-	c.Flags().Int64Var(&reasoningBudget, "reasoning-budget", 0, "reasoning budget for reasoning models - llama.cpp only")
+	c.Flags().Var(NewInt32PtrValue(&reasoningBudget), "reasoning-budget", "reasoning budget for reasoning models - llama.cpp only")

 	return c
 }
diff --git a/cmd/cli/commands/configure_test.go b/cmd/cli/commands/configure_test.go
index 47c3d7c22..c43e5dacd 100644
--- a/cmd/cli/commands/configure_test.go
+++ b/cmd/cli/commands/configure_test.go
@@ -14,14 +14,14 @@ func TestConfigureCmdReasoningBudgetFlag(t *testing.T) {
 		t.Fatal("--reasoning-budget flag not found")
 	}

-	// Verify the default value is 0
-	if reasoningBudgetFlag.DefValue != "0" {
-		t.Errorf("Expected default reasoning-budget value to be '0', got '%s'", reasoningBudgetFlag.DefValue)
+	// Verify the default value is empty (nil pointer)
+	if reasoningBudgetFlag.DefValue != "" {
+		t.Errorf("Expected default reasoning-budget value to be '' (nil), got '%s'", reasoningBudgetFlag.DefValue)
 	}

 	// Verify the flag type
-	if reasoningBudgetFlag.Value.Type() != "int64" {
-		t.Errorf("Expected reasoning-budget flag type to be 'int64', got '%s'", reasoningBudgetFlag.Value.Type())
+	if reasoningBudgetFlag.Value.Type() != "int32" {
+		t.Errorf("Expected reasoning-budget flag type to be 'int32', got '%s'", reasoningBudgetFlag.Value.Type())
 	}
 }

@@ -30,31 +30,31 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
 		name          string
 		setValue      string
 		expectChanged bool
-		expectedValue int64
+		expectedValue string
 	}{
 		{
 			name:          "flag not set - should not be changed",
 			setValue:      "",
 			expectChanged: false,
-			expectedValue: 0,
+			expectedValue: "",
 		},
 		{
 			name:          "flag set to 0 (disable reasoning) - should be changed",
 			setValue:      "0",
 			expectChanged: true,
-			expectedValue: 0,
+			expectedValue: "0",
 		},
 		{
 			name:          "flag set to -1 (unlimited) - should be changed",
 			setValue:      "-1",
 			expectChanged: true,
-			expectedValue: -1,
+			expectedValue: "-1",
 		},
 		{
 			name:          "flag set to positive value - should be changed",
 			setValue:      "1024",
 			expectChanged: true,
-			expectedValue: 1024,
+			expectedValue: "1024",
 		},
 	}

@@ -77,13 +77,11 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
 				t.Errorf("Expected Changed() = %v, got %v", tt.expectChanged, isChanged)
 			}

-			// Verify the value
-			value, err := cmd.Flags().GetInt64("reasoning-budget")
-			if err != nil {
-				t.Fatalf("Failed to get reasoning-budget flag value: %v", err)
-			}
+			// Verify the value using String() method
+			flag := cmd.Flags().Lookup("reasoning-budget")
+			value := flag.Value.String()
 			if value != tt.expectedValue {
-				t.Errorf("Expected value = %d, got %d", tt.expectedValue, value)
+				t.Errorf("Expected value = %s, got %s", tt.expectedValue, value)
 			}
 		})
 	}
 }
@@ -120,9 +118,9 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
 		t.Fatal("--context-size flag not found")
 	}

-	// Verify the default value is -1 (indicating not set)
-	if contextSizeFlag.DefValue != "-1" {
-		t.Errorf("Expected default context-size value to be '-1', got '%s'", contextSizeFlag.DefValue)
+	// Verify the default value is empty (nil pointer)
+	if contextSizeFlag.DefValue != "" {
+		t.Errorf("Expected default context-size value to be '' (nil), got '%s'", contextSizeFlag.DefValue)
 	}

 	// Test setting the flag value
@@ -131,14 +129,10 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
 		t.Errorf("Failed to set context-size flag: %v", err)
 	}

-	// Verify the value was set
-	contextSizeValue, err := cmd.Flags().GetInt64("context-size")
-	if err != nil {
-		t.Errorf("Failed to get context-size flag value: %v", err)
-	}
-
-	if contextSizeValue != 8192 {
-		t.Errorf("Expected context-size flag value to be 8192, got %d", contextSizeValue)
+	// Verify the value was set using String() method
+	contextSizeValue := contextSizeFlag.Value.String()
+	if contextSizeValue != "8192" {
+		t.Errorf("Expected context-size flag value to be '8192', got '%s'", contextSizeValue)
 	}
 }
diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go
index 19729f59e..4b74d014f 100644
--- a/cmd/cli/commands/integration_test.go
+++ b/cmd/cli/commands/integration_test.go
@@ -216,7 +216,7 @@ func verifyModelInspect(t *testing.T, client *desktop.Client, ref, expectedID, e

 // createAndPushTestModel creates a minimal test model and pushes it to the local registry.
 // Returns the model ID, FQDNs for host and network access, and the manifest digest.
-func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize uint64) (modelID, hostFQDN, networkFQDN, digest string) {
+func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize *int32) (modelID, hostFQDN, networkFQDN, digest string) {
 	ctx := context.Background()

 	// Use the dummy GGUF file from assets
@@ -234,8 +234,8 @@ func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextS
 	require.NoError(t, err)

 	// Set context size if specified
-	if contextSize > 0 {
-		pkg = pkg.WithContextSize(contextSize)
+	if contextSize != nil {
+		pkg = pkg.WithContextSize(*contextSize)
 	}

 	// Construct the full reference with the local registry host for pushing from test host
@@ -287,7 +287,7 @@ func TestIntegration_PullModel(t *testing.T) {
 	// Create and push two test models with different organizations
 	// Model 1: custom org (test/test-model:latest)
 	modelRef1 := "test/test-model:latest"
-	modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
+	modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
 	t.Logf("Test model 1 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN1, modelID1, networkFQDN1, digest1)

 	// Generate test cases for custom org model (test/test-model)
@@ -304,7 +304,7 @@ func TestIntegration_PullModel(t *testing.T) {

 	// Model 2: default org (ai/test-model:latest)
 	modelRef2 := "ai/test-model:latest"
-	modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
+	modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))
 	t.Logf("Test model 2 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN2, modelID2, networkFQDN2, digest2)

 	// Generate test cases for default org model (ai/test-model)
@@ -420,7 +420,7 @@ func TestIntegration_InspectModel(t *testing.T) {

 	// Create and push a test model with default org (ai/inspect-test:latest)
 	modelRef := "ai/inspect-test:latest"
-	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
+	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
 	t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

 	// Pull the model using a short reference
@@ -479,7 +479,7 @@ func TestIntegration_TagModel(t *testing.T) {

 	// Create and push a test model with default org (ai/tag-test:latest)
 	modelRef := "ai/tag-test:latest"
-	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
+	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
 	t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

 	// Pull the model using a simple reference
@@ -657,7 +657,7 @@ func TestIntegration_PushModel(t *testing.T) {

 	// Create and push a test model with default org (ai/tag-test:latest)
 	modelRef := "ai/tag-test:latest"
-	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
+	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
 	t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

 	// Pull the model using a simple reference
@@ -791,7 +791,7 @@ func TestIntegration_RemoveModel(t *testing.T) {

 	// Create and push a test model with default org (ai/rm-test:latest)
 	modelRef := "ai/rm-test:latest"
-	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
+	modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
 	t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

 	// Generate all reference test cases
@@ -842,9 +842,9 @@ func TestIntegration_RemoveModel(t *testing.T) {
 	t.Run("remove multiple models", func(t *testing.T) {
 		// Create and push two different models
 		modelRef1 := "ai/rm-multi-1:latest"
-		modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
+		modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
 		modelRef2 := "ai/rm-multi-2:latest"
-		modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
+		modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))

 		// Pull both models
 		t.Logf("Pulling first model: rm-multi-1")
@@ -1014,3 +1014,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
 		})
 	})
 }
+
+func int32ptr(n int32) *int32 {
+	return &n
+}
diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go
index 48f3de95c..788d6c8cc 100644
--- a/cmd/cli/commands/package.go
+++ b/cmd/cli/commands/package.go
@@ -284,9 +284,9 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
 	distClient := initResult.distClient

 	// Set context size
-	if opts.contextSize > 0 {
+	if cmd.Flags().Changed("context-size") {
 		cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
-		pkg = pkg.WithContextSize(opts.contextSize)
+		pkg = pkg.WithContextSize(int32(opts.contextSize))
 	}

 	// Add license files
diff --git a/cmd/cli/docs/reference/docker_model_configure.yaml b/cmd/cli/docs/reference/docker_model_configure.yaml
index 09e5d46f9..728af82f1 100644
--- a/cmd/cli/docs/reference/docker_model_configure.yaml
+++ b/cmd/cli/docs/reference/docker_model_configure.yaml
@@ -6,8 +6,7 @@ pname: docker model
 plink: docker_model.yaml
 options:
     - option: context-size
-      value_type: int64
-      default_value: "-1"
+      value_type: int32
       description: context size (in tokens)
       deprecated: false
       hidden: false
@@ -25,8 +24,7 @@ options:
       kubernetes: false
      swarm: false
    - option: reasoning-budget
-      value_type: int64
-      default_value: "0"
+      value_type: int32
      description: reasoning budget for reasoning models - llama.cpp only
      deprecated: false
      hidden: false
diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go
index f905e721a..3c0c976d2 100644
--- a/cmd/mdltool/main.go
+++ b/cmd/mdltool/main.go
@@ -321,7 +321,7 @@ func cmdPackage(args []string) int {

 	if contextSize > 0 {
 		fmt.Println("Setting context size:", contextSize)
-		b = b.WithContextSize(contextSize)
+		b = b.WithContextSize(int32(contextSize))
 	}

 	if mmproj != "" {
diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go
index f1023ecbd..44aed7c99 100644
--- a/pkg/distribution/builder/builder.go
+++ b/pkg/distribution/builder/builder.go
@@ -67,7 +67,7 @@ func (b *Builder) WithLicense(path string) (*Builder, error) {
 	}, nil
 }

-func (b *Builder) WithContextSize(size uint64) *Builder {
+func (b *Builder) WithContextSize(size int32) *Builder {
 	return &Builder{
 		model:          mutate.ContextSize(b.model, size),
 		originalLayers: b.originalLayers,
diff --git a/pkg/distribution/internal/mutate/model.go b/pkg/distribution/internal/mutate/model.go
index 1d4b61d8b..1db1825ba 100644
--- a/pkg/distribution/internal/mutate/model.go
+++ b/pkg/distribution/internal/mutate/model.go
@@ -16,7 +16,7 @@ type model struct {
 	base            types.ModelArtifact
 	appended        []v1.Layer
 	configMediaType ggcr.MediaType
-	contextSize     *uint64
+	contextSize     *int32
 }

 func (m *model) Descriptor() (types.Descriptor, error) {
diff --git a/pkg/distribution/internal/mutate/mutate.go b/pkg/distribution/internal/mutate/mutate.go
index b0baaa9d6..ae652788b 100644
--- a/pkg/distribution/internal/mutate/mutate.go
+++ b/pkg/distribution/internal/mutate/mutate.go
@@ -21,7 +21,7 @@ func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArti
 	}
 }

-func ContextSize(mdl types.ModelArtifact, cs uint64) types.ModelArtifact {
+func ContextSize(mdl types.ModelArtifact, cs int32) types.ModelArtifact {
 	return &model{
 		base:        mdl,
 		contextSize: &cs,
diff --git a/pkg/distribution/internal/mutate/mutate_test.go b/pkg/distribution/internal/mutate/mutate_test.go
index 09b3f254f..f5fcb1a2d 100644
--- a/pkg/distribution/internal/mutate/mutate_test.go
+++ b/pkg/distribution/internal/mutate/mutate_test.go
@@ -108,7 +108,7 @@ func TestContextSize(t *testing.T) {
 	if cfg2.ContextSize == nil {
 		t.Fatal("Expected non-nil context")
 	}
-	if *cfg2.ContextSize != uint64(2096) {
+	if *cfg2.ContextSize != 2096 {
 		t.Fatalf("Expected context size of 2096 got %d", *cfg2.ContextSize)
 	}
 }
diff --git a/pkg/distribution/internal/store/store_test.go b/pkg/distribution/internal/store/store_test.go
index 4139dd0ae..9b5454f65 100644
--- a/pkg/distribution/internal/store/store_test.go
+++ b/pkg/distribution/internal/store/store_test.go
@@ -1040,7 +1040,7 @@ func TestWriteLightweight(t *testing.T) {
 		}

 		// Modify the model's config by changing context size
-		newContextSize := uint64(4096)
+		newContextSize := int32(4096)
 		modifiedModel = mutate.ContextSize(baseModel, newContextSize)

 		// Use WriteLightweight to write the modified model
@@ -1135,7 +1135,7 @@ func TestWriteLightweight(t *testing.T) {
 		}

 		// Create a variant with different config
-		newContextSize := uint64(8192)
+		newContextSize := int32(8192)
 		variant := mutate.ContextSize(baseModel, newContextSize)

 		// Use WriteLightweight with multiple tags
@@ -1213,7 +1213,7 @@ func TestWriteLightweight(t *testing.T) {
 		}

 		// Create a variant with different context size
-		newContextSize := uint64(2048)
+		newContextSize := int32(2048)
 		variant := mutate.ContextSize(baseModel, newContextSize)

 		// Use WriteLightweight for the variant
@@ -1271,7 +1271,7 @@ func TestWriteLightweight(t *testing.T) {

 		// Create multiple variants using WriteLightweight
 		for i := 1; i <= 3; i++ {
-			contextSize := uint64(1024 * i)
+			contextSize := int32(1024 * i)
 			variant := mutate.ContextSize(baseModel, contextSize)
 			tag := fmt.Sprintf("integrity-test:variant%d", i)
 			if err := s.WriteLightweight(variant, []string{tag}); err != nil {
diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go
index 9311e6666..62e45ebad 100644
--- a/pkg/distribution/types/config.go
+++ b/pkg/distribution/types/config.go
@@ -67,7 +67,7 @@ type Config struct {
 	Size        string            `json:"size,omitempty"`
 	GGUF        map[string]string `json:"gguf,omitempty"`
 	Safetensors map[string]string `json:"safetensors,omitempty"`
-	ContextSize *uint64           `json:"context_size,omitempty"`
+	ContextSize *int32            `json:"context_size,omitempty"`
 }

 // Descriptor provides metadata about the provenance of the model.
diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index d470bbe76..df3a5c185 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -57,12 +57,12 @@ type VLLMConfig struct {
 type LlamaCppConfig struct {
 	// ReasoningBudget sets the reasoning budget for reasoning models.
 	// Maps to llama.cpp's --reasoning-budget flag.
-	ReasoningBudget *int64 `json:"reasoning-budget,omitempty"`
+	ReasoningBudget *int32 `json:"reasoning-budget,omitempty"`
 }

 type BackendConfiguration struct {
 	// Shared configuration across all backends
-	ContextSize int64                      `json:"context-size,omitempty"`
+	ContextSize *int32                     `json:"context-size,omitempty"`
 	Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`

 	// Backend-specific configuration
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index dbff7cfbf..8c8ba6f76 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -198,7 +198,11 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
 		return inference.RequiredMemory{}, &inference.ErrGGUFParse{Err: err}
 	}

-	contextSize := GetContextSize(mdlConfig, config)
+	configuredContextSize := GetContextSize(mdlConfig, config)
+	contextSize := int32(4096) // default context size
+	if configuredContextSize != nil {
+		contextSize = *configuredContextSize
+	}

 	var ngl uint64
 	if l.gpuSupported {
@@ -240,9 +244,9 @@ func (l *llamaCpp) parseModel(ctx context.Context, model string) (*parser.GGUFFi
 }

 // estimateMemoryFromGGUF estimates memory requirements from a parsed GGUF file.
-func (l *llamaCpp) estimateMemoryFromGGUF(ggufFile *parser.GGUFFile, contextSize uint64, ngl uint64) inference.RequiredMemory {
+func (l *llamaCpp) estimateMemoryFromGGUF(ggufFile *parser.GGUFFile, contextSize int32, ngl uint64) inference.RequiredMemory {
 	estimate := ggufFile.EstimateLLaMACppRun(
-		parser.WithLLaMACppContextSize(int32(contextSize)),
+		parser.WithLLaMACppContextSize(contextSize),
 		parser.WithLLaMACppLogicalBatchSize(2048),
 		parser.WithLLaMACppOffloadLayers(ngl),
 	)
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go
index ad45e7da0..8375eff1e 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config.go
@@ -9,6 +9,8 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 )

+const UnlimitedContextSize = -1
+
 // Config is the configuration for the llama.cpp backend.
 type Config struct {
 	// Args are the base arguments that are always included.
@@ -68,11 +70,14 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 	}

 	if budget := GetReasoningBudget(config); budget != nil {
-		args = append(args, "--reasoning-budget", strconv.FormatInt(*budget, 10))
+		args = append(args, "--reasoning-budget", strconv.FormatInt(int64(*budget), 10))
 	}

 	// Add context size from model config or backend config
-	args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(bundle.RuntimeConfig(), config), 10))
+	contextSize := GetContextSize(bundle.RuntimeConfig(), config)
+	if contextSize != nil {
+		args = append(args, "--ctx-size", strconv.FormatInt(int64(*contextSize), 10))
+	}

 	// Add arguments for Multimodal projector or jinja (they are mutually exclusive)
 	if path := bundle.MMPROJPath(); path != "" {
@@ -84,20 +89,19 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 	return args, nil
 }

-func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) uint64 {
+func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
 	// Model config takes precedence
-	if modelCfg.ContextSize != nil {
-		return *modelCfg.ContextSize
+	if modelCfg.ContextSize != nil && (*modelCfg.ContextSize == UnlimitedContextSize || *modelCfg.ContextSize > 0) {
+		return modelCfg.ContextSize
 	}
-	// else use backend config
-	if backendCfg != nil && backendCfg.ContextSize > 0 {
-		return uint64(backendCfg.ContextSize)
+	// Fallback to backend config
+	if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
+		return backendCfg.ContextSize
 	}
-	// finally return default
-	return 4096 // llama.cpp default
+	return nil
 }

-func GetReasoningBudget(backendCfg *inference.BackendConfiguration) *int64 {
+func GetReasoningBudget(backendCfg *inference.BackendConfiguration) *int32 {
 	if backendCfg != nil && backendCfg.LlamaCpp != nil && backendCfg.LlamaCpp.ReasoningBudget != nil {
 		return backendCfg.LlamaCpp.ReasoningBudget
 	}
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
index b00e0cae3..5d04ad3e5 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
@@ -109,7 +109,6 @@ func TestGetArgs(t *testing.T) {
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -123,7 +122,6 @@ func TestGetArgs(t *testing.T) {
 				"--model", modelPath,
 				"--host", socket,
 				"--embeddings",
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -134,7 +132,7 @@ func TestGetArgs(t *testing.T) {
 				ggufPath: modelPath,
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 1234,
+				ContextSize: int32ptr(1234),
 			},
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
@@ -144,17 +142,66 @@ func TestGetArgs(t *testing.T) {
 				"--jinja",
 			),
 		},
+		{
+			name: "unlimited context size from backend config",
+			mode: inference.BackendModeEmbedding,
+			bundle: &fakeBundle{
+				ggufPath: modelPath,
+			},
+			config: &inference.BackendConfiguration{
+				ContextSize: int32ptr(-1),
+			},
+			expected: append(slices.Clone(baseArgs),
+				"--model", modelPath,
+				"--host", socket,
+				"--embeddings",
+				"--ctx-size", "-1",
+				"--jinja",
+			),
+		},
+		{
+			name: "0 context size from backend config ignored",
+			mode: inference.BackendModeEmbedding,
+			bundle: &fakeBundle{
+				ggufPath: modelPath,
+			},
+			config: &inference.BackendConfiguration{
+				ContextSize: int32ptr(0),
+			},
+			expected: append(slices.Clone(baseArgs),
+				"--model", modelPath,
+				"--host", socket,
+				"--embeddings",
+				"--jinja",
+			),
+		},
+		{
+			name: "invalid context size from backend config ignored",
+			mode: inference.BackendModeEmbedding,
+			bundle: &fakeBundle{
+				ggufPath: modelPath,
+			},
+			config: &inference.BackendConfiguration{
+				ContextSize: int32ptr(-2),
+			},
+			expected: append(slices.Clone(baseArgs),
+				"--model", modelPath,
+				"--host", socket,
+				"--embeddings",
+				"--jinja",
+			),
+		},
 		{
 			name: "context size from model config",
 			mode: inference.BackendModeEmbedding,
 			bundle: &fakeBundle{
 				ggufPath: modelPath,
 				config: types.Config{
-					ContextSize: uint64ptr(2096),
+					ContextSize: int32ptr(2096),
 				},
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 1234,
+				ContextSize: int32ptr(1234),
 			},
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
@@ -175,7 +222,6 @@ func TestGetArgs(t *testing.T) {
 				"--model", modelPath,
 				"--host", socket,
 				"--chat-template-file", "/path/to/bundle/template.jinja",
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -189,7 +235,6 @@ func TestGetArgs(t *testing.T) {
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
-				"--ctx-size", "4096",
 				"--mmproj", "/path/to/model.mmproj",
 			),
 		},
@@ -201,14 +246,13 @@ func TestGetArgs(t *testing.T) {
 			},
 			config: &inference.BackendConfiguration{
 				LlamaCpp: &inference.LlamaCppConfig{
-					ReasoningBudget: int64ptr(1024),
+					ReasoningBudget: int32ptr(1024),
 				},
 			},
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--reasoning-budget", "1024",
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -220,14 +264,13 @@ func TestGetArgs(t *testing.T) {
 			},
 			config: &inference.BackendConfiguration{
 				LlamaCpp: &inference.LlamaCppConfig{
-					ReasoningBudget: int64ptr(-1),
+					ReasoningBudget: int32ptr(-1),
 				},
 			},
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--reasoning-budget", "-1",
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -243,7 +286,6 @@ func TestGetArgs(t *testing.T) {
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -261,7 +303,6 @@ func TestGetArgs(t *testing.T) {
 			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
-				"--ctx-size", "4096",
 				"--jinja",
 			),
 		},
@@ -272,9 +313,9 @@ func TestGetArgs(t *testing.T) {
 			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 8192,
+				ContextSize: int32ptr(8192),
 				LlamaCpp: &inference.LlamaCppConfig{
-					ReasoningBudget: int64ptr(2048),
+					ReasoningBudget: int32ptr(2048),
 				},
 			},
 			expected: append(slices.Clone(baseArgs),
@@ -395,10 +436,6 @@ func (f *fakeBundle) RuntimeConfig() types.Config {
 	return f.config
 }

-func uint64ptr(n uint64) *uint64 {
-	return &n
-}
-
-func int64ptr(n int64) *int64 {
+func int32ptr(n int32) *int32 {
 	return &n
 }
diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go
index 97b74ce43..025f5b89c 100644
--- a/pkg/inference/backends/mlx/mlx_config.go
+++ b/pkg/inference/backends/mlx/mlx_config.go
@@ -65,15 +65,5 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 // Model config takes precedence over backend config.
 // Returns nil if neither is specified (MLX will use model defaults).
 func GetMaxTokens(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
-	// Model config takes precedence
-	if modelCfg.ContextSize != nil {
-		return modelCfg.ContextSize
-	}
-	// else use backend config
-	if backendCfg != nil && backendCfg.ContextSize > 0 {
-		val := uint64(backendCfg.ContextSize)
-		return &val
-	}
-	// Return nil to let MLX use model defaults
 	return nil
 }
diff --git a/pkg/inference/backends/mlx/mlx_config_test.go b/pkg/inference/backends/mlx/mlx_config_test.go
deleted file mode 100644
index 68d835e6b..000000000
--- a/pkg/inference/backends/mlx/mlx_config_test.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package mlx
-
-import (
-	"testing"
-
-	"github.com/docker/model-runner/pkg/distribution/types"
-	"github.com/docker/model-runner/pkg/inference"
-)
-
-type mockModelBundle struct {
-	safetensorsPath string
-	runtimeConfig   types.Config
-}
-
-func (m *mockModelBundle) GGUFPath() string {
-	return ""
-}
-
-func (m *mockModelBundle) SafetensorsPath() string {
-	return m.safetensorsPath
-}
-
-func (m *mockModelBundle) ChatTemplatePath() string {
-	return ""
-}
-
-func (m *mockModelBundle) MMPROJPath() string {
-	return ""
-}
-
-func (m *mockModelBundle) RuntimeConfig() types.Config {
-	return m.runtimeConfig
-}
-
-func (m *mockModelBundle) RootDir() string {
-	return "/path/to/bundle"
-}
-
-func TestGetArgs(t *testing.T) {
-	tests := []struct {
-		name        string
-		config      *inference.BackendConfiguration
-		bundle      *mockModelBundle
-		expected    []string
-		expectError bool
-	}{
-		{
-			name: "empty safetensors path should error",
-			bundle: &mockModelBundle{
-				safetensorsPath: "",
-			},
-			config:      nil,
-			expected:    nil,
-			expectError: true,
-		},
-		{
-			name: "basic args without context size",
-			bundle: &mockModelBundle{
-				safetensorsPath: "/path/to/model",
-			},
-			config: nil,
-			expected: []string{
-				"-m",
-				"mlx_lm.server",
-				"--model",
-				"/path/to",
-				"--host",
-				"/tmp/socket",
-			},
-		},
-		{
-			name: "with backend context size",
-			bundle: &mockModelBundle{
-				safetensorsPath: "/path/to/model",
-			},
-			config: &inference.BackendConfiguration{
-				ContextSize: 8192,
-			},
-			expected: []string{
-				"-m",
-				"mlx_lm.server",
-				"--model",
-				"/path/to",
-				"--host",
-				"/tmp/socket",
-				"--max-tokens",
-				"8192",
-			},
-		},
-		{
-			name: "with model context size (takes precedence)",
-			bundle: &mockModelBundle{
-				safetensorsPath: "/path/to/model",
-				runtimeConfig: types.Config{
-					ContextSize: ptrUint64(16384),
-				},
-			},
-			config: &inference.BackendConfiguration{
-				ContextSize: 8192,
-			},
-			expected: []string{
-				"-m",
-				"mlx_lm.server",
-				"--model",
-				"/path/to",
-				"--host",
-				"/tmp/socket",
-				"--max-tokens",
-				"16384",
-			},
-		},
-		{
-			name: "reranking mode should error",
-			bundle: &mockModelBundle{
-				safetensorsPath: "/path/to/model",
"/path/to/model", - }, - config: nil, - expected: nil, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config := NewDefaultMLXConfig() - mode := inference.BackendModeCompletion - // For the reranking test case, use reranking mode - if tt.name == "reranking mode should error" { - mode = inference.BackendModeReranking - } - args, err := config.GetArgs(tt.bundle, "/tmp/socket", mode, tt.config) - - if tt.expectError { - if err == nil { - t.Fatalf("expected error but got none") - } - return - } - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(args) != len(tt.expected) { - t.Fatalf("expected %d args, got %d\nexpected: %v\ngot: %v", len(tt.expected), len(args), tt.expected, args) - } - - for i, arg := range args { - if arg != tt.expected[i] { - t.Errorf("arg[%d]: expected %q, got %q", i, tt.expected[i], arg) - } - } - }) - } -} - -func TestGetMaxTokens(t *testing.T) { - tests := []struct { - name string - modelCfg types.Config - backendCfg *inference.BackendConfiguration - expectedValue *uint64 - }{ - { - name: "no config", - modelCfg: types.Config{}, - backendCfg: nil, - expectedValue: nil, - }, - { - name: "backend config only", - modelCfg: types.Config{}, - backendCfg: &inference.BackendConfiguration{ - ContextSize: 4096, - }, - expectedValue: ptrUint64(4096), - }, - { - name: "model config only", - modelCfg: types.Config{ - ContextSize: ptrUint64(8192), - }, - backendCfg: nil, - expectedValue: ptrUint64(8192), - }, - { - name: "model config takes precedence", - modelCfg: types.Config{ - ContextSize: ptrUint64(16384), - }, - backendCfg: &inference.BackendConfiguration{ - ContextSize: 4096, - }, - expectedValue: ptrUint64(16384), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GetMaxTokens(tt.modelCfg, tt.backendCfg) - if (result == nil) != (tt.expectedValue == nil) { - t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil) - } else if result != nil && *result != *tt.expectedValue { - t.Errorf("expected %d, got %d", *tt.expectedValue, *result) - } - }) - } -} - -func ptrUint64(v uint64) *uint64 { - return &v -} diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index fd591c475..d1e692ae2 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -54,7 +54,7 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference // Add max-model-len if specified in model config or backend config if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil { - args = append(args, "--max-model-len", strconv.FormatUint(*maxLen, 10)) + args = append(args, "--max-model-len", strconv.FormatInt(int64(*maxLen), 10)) } // If nil, vLLM will automatically derive from the model config @@ -76,15 +76,14 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference // GetMaxModelLen returns the max model length (context size) from model config or backend config. // Model config takes precedence over backend config. // Returns nil if neither is specified (vLLM will auto-derive from model). 
-func GetMaxModelLen(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
+func GetMaxModelLen(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
 	// Model config takes precedence
 	if modelCfg.ContextSize != nil {
 		return modelCfg.ContextSize
 	}
-	// else use backend config
-	if backendCfg != nil && backendCfg.ContextSize > 0 {
-		val := uint64(backendCfg.ContextSize)
-		return &val
+	// Fallback to backend config
+	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
+		return backendCfg.ContextSize
 	}
 	// Return nil to let vLLM auto-derive from model config
 	return nil
diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go
index e7b3844e1..afdd31963 100644
--- a/pkg/inference/backends/vllm/vllm_config_test.go
+++ b/pkg/inference/backends/vllm/vllm_config_test.go
@@ -72,7 +72,7 @@ func TestGetArgs(t *testing.T) {
 				safetensorsPath: "/path/to/model",
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 8192,
+				ContextSize: int32ptr(8192),
 			},
 			expected: []string{
 				"serve",
@@ -88,11 +88,11 @@ func TestGetArgs(t *testing.T) {
 			bundle: &mockModelBundle{
 				safetensorsPath: "/path/to/model",
 				runtimeConfig: types.Config{
-					ContextSize: ptrUint64(16384),
+					ContextSize: int32ptr(16384),
 				},
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 8192,
+				ContextSize: int32ptr(8192),
 			},
 			expected: []string{
 				"serve",
@@ -185,7 +185,7 @@ func TestGetArgs(t *testing.T) {
 				safetensorsPath: "/path/to/model",
 			},
 			config: &inference.BackendConfiguration{
-				ContextSize: 4096,
+				ContextSize: int32ptr(4096),
 				VLLM: &inference.VLLMConfig{
 					HFOverrides: inference.HFOverrides{
 						"model_type": "bert",
@@ -239,7 +239,7 @@ func TestGetMaxModelLen(t *testing.T) {
 		name          string
 		modelCfg      types.Config
 		backendCfg    *inference.BackendConfiguration
-		expectedValue *uint64
+		expectedValue *int32
 	}{
 		{
 			name:          "no config",
@@ -251,27 +251,27 @@ func TestGetMaxModelLen(t *testing.T) {
 			name:     "backend config only",
 			modelCfg: types.Config{},
 			backendCfg: &inference.BackendConfiguration{
-				ContextSize: 4096,
+				ContextSize: int32ptr(4096),
 			},
-			expectedValue: ptrUint64(4096),
+			expectedValue: int32ptr(4096),
 		},
 		{
 			name: "model config only",
 			modelCfg: types.Config{
-				ContextSize: ptrUint64(8192),
+				ContextSize: int32ptr(8192),
 			},
 			backendCfg:    nil,
-			expectedValue: ptrUint64(8192),
+			expectedValue: int32ptr(8192),
 		},
 		{
 			name: "model config takes precedence",
 			modelCfg: types.Config{
-				ContextSize: ptrUint64(16384),
+				ContextSize: int32ptr(16384),
 			},
 			backendCfg: &inference.BackendConfiguration{
-				ContextSize: 4096,
+				ContextSize: int32ptr(4096),
 			},
-			expectedValue: ptrUint64(16384),
+			expectedValue: int32ptr(16384),
 		},
 	}

@@ -287,6 +287,6 @@ func TestGetMaxModelLen(t *testing.T) {
 	}
 }

-func ptrUint64(v uint64) *uint64 {
-	return &v
+func int32ptr(n int32) *int32 {
+	return &n
 }
diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go
index dd01c12e1..32e8c4a30 100644
--- a/pkg/inference/models/api.go
+++ b/pkg/inference/models/api.go
@@ -29,7 +29,7 @@ type ModelPackageRequest struct {
 	// Tag is the name to give the new packaged model.
 	Tag string `json:"tag"`
 	// ContextSize specifies the context size to set for the new model.
-	ContextSize uint64 `json:"context-size,omitempty"`
+	ContextSize *int32 `json:"context-size,omitempty"`
 }

 // SimpleModel is a wrapper that allows creating a model with modified configuration
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index 2ffdc6314..a227e7e16 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -419,7 +419,7 @@ func (m *Manager) Push(model string, r *http.Request, w http.ResponseWriter) err
 	return nil
 }

-func (m *Manager) Package(ref string, tag string, contextSize uint64) error {
+func (m *Manager) Package(ref string, tag string, contextSize *int32) error {
 	// Create a builder from an existing model by getting the bundle first
 	// Since ModelArtifact interface is needed to work with the builder
 	bundle, err := m.distributionClient.GetBundle(ref)
@@ -440,8 +440,8 @@ func (m *Manager) Package(ref string, tag string, contextSize uint64) error {
 	}

 	// Apply context size if specified
-	if contextSize > 0 {
-		bldr = bldr.WithContextSize(contextSize)
+	if contextSize != nil {
+		bldr = bldr.WithContextSize(*contextSize)
 	}

 	// Get the built model artifact
diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go
index ed1072f7f..67d5cd50a 100644
--- a/pkg/inference/scheduling/http_handler.go
+++ b/pkg/inference/scheduling/http_handler.go
@@ -330,9 +330,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
 	}

 	configureRequest := ConfigureRequest{
-		BackendConfiguration: inference.BackendConfiguration{
-			ContextSize: -1,
-		},
+		BackendConfiguration: inference.BackendConfiguration{},
 	}
 	if err := json.Unmarshal(body, &configureRequest); err != nil {
 		http.Error(w, "invalid request", http.StatusBadRequest)
diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go
index ce9cea373..3e4a6ab6a 100644
--- a/pkg/ollama/http_handler.go
+++ b/pkg/ollama/http_handler.go
@@ -21,13 +21,13 @@ import (
 // Reasoning budget constants for the think parameter conversion
 const (
 	// reasoningBudgetUnlimited represents unlimited reasoning tokens (-1 for llama.cpp)
-	reasoningBudgetUnlimited int64 = -1
+	reasoningBudgetUnlimited int32 = -1
 	// reasoningBudgetDisabled disables reasoning (0 tokens)
-	reasoningBudgetDisabled int64 = 0
+	reasoningBudgetDisabled int32 = 0
 	// reasoningBudgetMedium represents a medium reasoning budget (1024 tokens)
-	reasoningBudgetMedium int64 = 1024
+	reasoningBudgetMedium int32 = 1024
 	// reasoningBudgetLow represents a low reasoning budget (256 tokens)
-	reasoningBudgetLow int64 = 256
+	reasoningBudgetLow int32 = 256
 )

 // Reasoning level string constants for the think parameter
@@ -405,15 +405,14 @@ func (h *HTTPHandler) handleChat(w http.ResponseWriter, r *http.Request) {

 // configureModel extracts and applies model configuration options.
 // Handles num_ctx from options and think parameter for reasoning budget.
-// Returns the context size for use in preloading scenarios.
-func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, options map[string]interface{}, think interface{}, userAgent string) int64 {
-	var contextSize int64
+func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, options map[string]interface{}, think interface{}, userAgent string) {
+	var contextSize int32
 	var hasContextSize bool

 	// Extract context size from options
 	if options != nil {
-		if numCtxRaw, ok := options["num_ctx"]; ok {
-			contextSize = convertToInt64(numCtxRaw)
+		if numCtxRaw, ok := options["num_ctx"]; ok && numCtxRaw != nil {
+			contextSize = convertToInt32(numCtxRaw)
 			hasContextSize = true
 		}
 	}
@@ -424,23 +423,13 @@ func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, opti
 	// Only call ConfigureRunner if we have something to configure
 	if hasContextSize || reasoningBudget != nil {
 		sanitizedModelName := utils.SanitizeForLog(modelName, -1)
-		// Build reasoning budget string for logging (show "nil" when not specified)
-		var budgetStr string
-		if reasoningBudget != nil {
-			budgetStr = fmt.Sprintf("%d", *reasoningBudget)
-		} else {
-			budgetStr = "nil"
-		}
-		sanitizedContextSize := utils.SanitizeForLog(fmt.Sprintf("%d", contextSize), -1)
-		h.log.Infof("configureModel: configuring model %s (context_size=%s, has_context_size=%t, reasoning_budget=%s)",
-			sanitizedModelName, sanitizedContextSize, hasContextSize, budgetStr)
-
+		h.log.Infof("configureModel: configuring model %s", sanitizedModelName)
 		configureRequest := scheduling.ConfigureRequest{
 			Model: modelName,
 		}
 		// Only include ContextSize if explicitly defined
 		if hasContextSize {
-			configureRequest.ContextSize = contextSize
+			configureRequest.ContextSize = &contextSize
 		}
 		// Set llama.cpp-specific reasoning budget if provided
 		if reasoningBudget != nil {
@@ -448,14 +437,12 @@ func (h *HTTPHandler) configureModel(ctx context.Context, modelName string, opti
 				ReasoningBudget: reasoningBudget,
 			}
 		}
-		_, err := h.scheduler.ConfigureRunner(ctx, nil, configureRequest, userAgent)
+		_, err := h.scheduler.ConfigureRunner(ctx, nil, configureRequest, userAgent) // TODO add backend selection?
 		if err != nil {
 			// Log the error but continue with the request
 			h.log.Warnf("configureModel: failed to configure model %s: %v", sanitizedModelName, err)
 		}
 	}
-
-	return contextSize
 }

 // isZeroKeepAlive checks if the keep-alive duration string represents zero duration.
@@ -481,33 +468,19 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) {
 		modelName = req.Model
 	}

+	// Normalize model name
+	modelName = models.NormalizeModelName(modelName)
+
 	if req.Prompt == "" && isZeroKeepAlive(req.KeepAlive) {
 		h.unloadModel(ctx, w, modelName)
 		return
 	}

 	// Configure model
-	ctxSize := h.configureModel(ctx, modelName, req.Options, req.Think, r.UserAgent()+" (Ollama API)")
+	h.configureModel(ctx, modelName, req.Options, req.Think, r.UserAgent()+" (Ollama API)")

 	if req.Prompt == "" {
-		// Empty prompt - preload the model
-		// ConfigureRunner is idempotent, so calling it again with the same context size is safe
-		configureRequest := scheduling.ConfigureRequest{
-			Model: modelName,
-			BackendConfiguration: inference.BackendConfiguration{
-				ContextSize: ctxSize, // Use extracted value (or 0 for default)
-			},
-		}
-
-		_, err := h.scheduler.ConfigureRunner(ctx, nil, configureRequest, r.UserAgent()+" (Ollama API)")
-		if err != nil {
-			sanitizedErr := utils.SanitizeForLog(err.Error(), -1)
-			sanitizedModelName := utils.SanitizeForLog(modelName, -1)
-			h.log.Warnf("handleGenerate: failed to preload model %s: %v", sanitizedModelName, sanitizedErr)
-			http.Error(w, fmt.Sprintf("Failed to preload model: %v", err), http.StatusInternalServerError)
-			return
-		}
-
+		// Empty prompt - preload the model (already configured above)
 		// Return success response in Ollama format (empty JSON object)
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(http.StatusOK)
@@ -515,9 +488,6 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	// Normalize model name
-	modelName = models.NormalizeModelName(modelName)
-
 	// Convert to OpenAI format completion request
 	openAIReq := map[string]interface{}{
 		"model": modelName,
@@ -809,13 +779,13 @@ func convertMessages(messages []Message) []map[string]interface{} {
 // - bool: true (unlimited reasoning, -1) or false (no reasoning, 0)
 // - string: "high" (-1, unlimited), "medium" (1024 tokens), "low" (256 tokens)
 // Returns nil if think is nil or invalid, otherwise returns a pointer to the reasoning_budget value.
-func convertThinkToReasoningBudget(think interface{}) *int64 {
+func convertThinkToReasoningBudget(think interface{}) *int32 {
 	if think == nil {
 		return nil
 	}

-	// Helper to create a pointer to an int64 value
-	ptr := func(v int64) *int64 { return &v }
+	// Helper to create a pointer to an int32 value
+	ptr := func(v int32) *int32 { return &v }

 	switch v := think.(type) {
 	case bool:
@@ -839,22 +809,22 @@ func convertThinkToReasoningBudget(think interface{}) *int64 {
 	}
 }

-// convertToInt64 converts various numeric types to int64
-func convertToInt64(v interface{}) int64 {
+// convertToInt32 converts various numeric types to int32
+func convertToInt32(v interface{}) int32 {
 	switch val := v.(type) {
 	case int:
-		return int64(val)
-	case int64:
+		return int32(val)
+	case int32:
 		return val
 	case float64:
-		return int64(val)
+		return int32(val)
 	case float32:
-		return int64(val)
+		return int32(val)
 	case string:
 		// Sanitize string to remove newline/carriage return before parsing
 		safeVal := utils.SanitizeForLog(val, -1)
-		if num, err := fmt.Sscanf(safeVal, "%d", new(int64)); err == nil && num == 1 {
-			var result int64
+		if num, err := fmt.Sscanf(safeVal, "%d", new(int32)); err == nil && num == 1 {
+			var result int32
 			fmt.Sscanf(safeVal, "%d", &result)
 			return result
 		}
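
Note (reviewer addition, not part of the patch): the core mechanism this change introduces is the Int32PtrValue wrapper in cmd/cli/commands/configure.go, a pflag.Value over a **int32 that leaves the target nil until the flag is actually passed. That is what lets every layer distinguish "not configured" from legitimate values like 0 (disable reasoning) and -1 (unlimited), so the old -1/0 sentinel defaults and value > 0 checks can be dropped. Below is a minimal, self-contained sketch of the same pattern wired to spf13/pflag directly; the flag-set name "demo" and the lowercase int32PtrValue type are illustrative only, not part of this PR.

package main

import (
	"fmt"
	"strconv"

	"github.com/spf13/pflag"
)

// int32PtrValue mirrors the PR's Int32PtrValue: it writes through a **int32,
// so the underlying pointer stays nil until Set is called.
type int32PtrValue struct{ ptr **int32 }

func (v *int32PtrValue) String() string {
	if v.ptr == nil || *v.ptr == nil {
		return ""
	}
	return strconv.FormatInt(int64(**v.ptr), 10)
}

func (v *int32PtrValue) Set(s string) error {
	n, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return err
	}
	val := int32(n)
	*v.ptr = &val
	return nil
}

func (v *int32PtrValue) Type() string { return "int32" }

func main() {
	var contextSize *int32
	fs := pflag.NewFlagSet("demo", pflag.ContinueOnError)
	fs.Var(&int32PtrValue{ptr: &contextSize}, "context-size", "context size (in tokens)")

	// No flag passed: the pointer stays nil, i.e. "not configured".
	_ = fs.Parse([]string{})
	fmt.Println("unset:", contextSize == nil) // unset: true

	// An explicit 0 is preserved rather than being confused with "unset".
	_ = fs.Parse([]string{"--context-size=0"})
	fmt.Println("value:", *contextSize) // value: 0
}

Because "unset" is now a nil pointer end to end (flag value, BackendConfiguration.ContextSize, GetContextSize), callers no longer need cmd.Flags().Changed checks to recover intent after parsing, though the compose and package commands above still use Changed where they only have a plain int variable.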