Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newUpCommand() *cobra.Command {
return err
}

if ctxSize > 0 {
if cmd.Flags().Changed("context-size") {
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
}

Expand All @@ -82,12 +82,18 @@ func newUpCommand() *cobra.Command {
}

for _, model := range models {
configuration := inference.BackendConfiguration{
Speculative: speculativeConfig,
}
if cmd.Flags().Changed("context-size") {
// TODO: is the context size the same for all models?
v := int32(ctxSize)
configuration.ContextSize = &v
}

if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
Model: model,
BackendConfiguration: inference.BackendConfiguration{
ContextSize: ctxSize,
Speculative: speculativeConfig,
},
Model: model,
BackendConfiguration: configuration,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
Expand Down
10 changes: 8 additions & 2 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func newConfigureCmd() *cobra.Command {
var numTokens int
var minAcceptanceRate float64
var hfOverrides string
var contextSize int64
var reasoningBudget int64

c := &cobra.Command{
Expand All @@ -34,6 +35,10 @@ func newConfigureCmd() *cobra.Command {
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
if cmd.Flags().Changed("context-size") {
v := int32(contextSize)
opts.ContextSize = &v
}
// Build the speculative config if any speculative flags are set
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
opts.Speculative = &inference.SpeculativeDecodingConfig{
Expand Down Expand Up @@ -64,14 +69,15 @@ func newConfigureCmd() *cobra.Command {
if opts.LlamaCpp == nil {
opts.LlamaCpp = &inference.LlamaCppConfig{}
}
opts.LlamaCpp.ReasoningBudget = &reasoningBudget
v := int32(reasoningBudget)
opts.LlamaCpp.ReasoningBudget = &v
}
return desktopClient.ConfigureBackend(opts)
},
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
}

c.Flags().Int64Var(&opts.ContextSize, "context-size", -1, "context size (in tokens)")
c.Flags().Int64Var(&contextSize, "context-size", 0, "context size (in tokens)")
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
Expand Down
4 changes: 2 additions & 2 deletions cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
distClient := initResult.distClient

// Set context size
if opts.contextSize > 0 {
if cmd.Flags().Changed("context-size") {
cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
pkg = pkg.WithContextSize(opts.contextSize)
pkg = pkg.WithContextSize(int32(opts.contextSize))
}
Comment on lines +287 to 290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The conversion from uint64 to int32 for contextSize could lead to an integer overflow if a user provides a value larger than math.MaxInt32. While unlikely for a context size, adding a validation check would make the code more robust.

if cmd.Flags().Changed("context-size") {
		if opts.contextSize > 2147483647 { // math.MaxInt32
			return fmt.Errorf("context size %d is too large, must be less than or equal to 2147483647", opts.contextSize)
		}
		cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
		pkg = pkg.WithContextSize(int32(opts.contextSize))
	}


// Add license files
Expand Down
2 changes: 1 addition & 1 deletion cmd/mdltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func cmdPackage(args []string) int {

if contextSize > 0 {
fmt.Println("Setting context size:", contextSize)
b = b.WithContextSize(contextSize)
b = b.WithContextSize(int32(contextSize))
}
Comment on lines 322 to 325
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

As in cmd/cli/commands/package.go, the conversion from uint64 to int32 for contextSize could overflow. It's best to add a check to ensure the value is within the valid range for an int32.

if contextSize > 0 {
		if contextSize > 2147483647 { // math.MaxInt32
			fmt.Fprintf(os.Stderr, "context size %d is too large, must be less than or equal to 2147483647\n", contextSize)
			return 1
		}
		fmt.Println("Setting context size:", contextSize)
		b = b.WithContextSize(int32(contextSize))
	}


if mmproj != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (b *Builder) WithLicense(path string) (*Builder, error) {
}, nil
}

func (b *Builder) WithContextSize(size uint64) *Builder {
func (b *Builder) WithContextSize(size int32) *Builder {
return &Builder{
model: mutate.ContextSize(b.model, size),
originalLayers: b.originalLayers,
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type model struct {
base types.ModelArtifact
appended []v1.Layer
configMediaType ggcr.MediaType
contextSize *uint64
contextSize *int32
}

func (m *model) Descriptor() (types.Descriptor, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/mutate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArti
}
}

func ContextSize(mdl types.ModelArtifact, cs uint64) types.ModelArtifact {
func ContextSize(mdl types.ModelArtifact, cs int32) types.ModelArtifact {
return &model{
base: mdl,
contextSize: &cs,
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/mutate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestContextSize(t *testing.T) {
if cfg2.ContextSize == nil {
t.Fatal("Expected non-nil context")
}
if *cfg2.ContextSize != uint64(2096) {
if *cfg2.ContextSize != 2096 {
t.Fatalf("Expected context size of 2096 got %d", *cfg2.ContextSize)
}
}
8 changes: 4 additions & 4 deletions pkg/distribution/internal/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ func TestWriteLightweight(t *testing.T) {
}

// Modify the model's config by changing context size
newContextSize := uint64(4096)
newContextSize := int32(4096)
modifiedModel := mutate.ContextSize(baseModel, newContextSize)

// Use WriteLightweight to write the modified model
Expand Down Expand Up @@ -1135,7 +1135,7 @@ func TestWriteLightweight(t *testing.T) {
}

// Create a variant with different config
newContextSize := uint64(8192)
newContextSize := int32(8192)
variant := mutate.ContextSize(baseModel, newContextSize)

// Use WriteLightweight with multiple tags
Expand Down Expand Up @@ -1213,7 +1213,7 @@ func TestWriteLightweight(t *testing.T) {
}

// Create a variant with different context size
newContextSize := uint64(2048)
newContextSize := int32(2048)
variant := mutate.ContextSize(baseModel, newContextSize)

// Use WriteLightweight for the variant
Expand Down Expand Up @@ -1271,7 +1271,7 @@ func TestWriteLightweight(t *testing.T) {

// Create multiple variants using WriteLightweight
for i := 1; i <= 3; i++ {
contextSize := uint64(1024 * i)
contextSize := int32(1024 * i)
variant := mutate.ContextSize(baseModel, contextSize)
tag := fmt.Sprintf("integrity-test:variant%d", i)
if err := s.WriteLightweight(variant, []string{tag}); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type Config struct {
Size string `json:"size,omitempty"`
GGUF map[string]string `json:"gguf,omitempty"`
Safetensors map[string]string `json:"safetensors,omitempty"`
ContextSize *uint64 `json:"context_size,omitempty"`
ContextSize *int32 `json:"context_size,omitempty"`
}

// Descriptor provides metadata about the provenance of the model.
Expand Down
4 changes: 2 additions & 2 deletions pkg/inference/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ type VLLMConfig struct {
type LlamaCppConfig struct {
// ReasoningBudget sets the reasoning budget for reasoning models.
// Maps to llama.cpp's --reasoning-budget flag.
ReasoningBudget *int64 `json:"reasoning-budget,omitempty"`
ReasoningBudget *int32 `json:"reasoning-budget,omitempty"`
}

type BackendConfiguration struct {
// Shared configuration across all backends
ContextSize int64 `json:"context-size,omitempty"`
ContextSize *int32 `json:"context-size,omitempty"`
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`

// Backend-specific configuration
Expand Down
10 changes: 7 additions & 3 deletions pkg/inference/backends/llamacpp/llamacpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@
return inference.RequiredMemory{}, &inference.ErrGGUFParse{Err: err}
}

contextSize := GetContextSize(mdlConfig, config)
configuredContextSize := GetContextSize(mdlConfig, config)
contextSize := int32(4096) // default context size
if configuredContextSize != nil {
contextSize = int32(*configuredContextSize)

Check failure on line 204 in pkg/inference/backends/llamacpp/llamacpp.go

View workflow job for this annotation

GitHub Actions / lint (darwin)

unnecessary conversion (unconvert)

Check failure on line 204 in pkg/inference/backends/llamacpp/llamacpp.go

View workflow job for this annotation

GitHub Actions / lint (windows)

unnecessary conversion (unconvert)

Check failure on line 204 in pkg/inference/backends/llamacpp/llamacpp.go

View workflow job for this annotation

GitHub Actions / lint (linux)

unnecessary conversion (unconvert)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type cast int32(*configuredContextSize) is redundant because configuredContextSize is of type *int32, so dereferencing it with * already yields an int32 value. You can remove the explicit cast to improve code clarity.

Suggested change
contextSize = int32(*configuredContextSize)
contextSize = *configuredContextSize

}

var ngl uint64
if l.gpuSupported {
Expand Down Expand Up @@ -240,9 +244,9 @@
}

// estimateMemoryFromGGUF estimates memory requirements from a parsed GGUF file.
func (l *llamaCpp) estimateMemoryFromGGUF(ggufFile *parser.GGUFFile, contextSize uint64, ngl uint64) inference.RequiredMemory {
func (l *llamaCpp) estimateMemoryFromGGUF(ggufFile *parser.GGUFFile, contextSize int32, ngl uint64) inference.RequiredMemory {
estimate := ggufFile.EstimateLLaMACppRun(
parser.WithLLaMACppContextSize(int32(contextSize)),
parser.WithLLaMACppContextSize(contextSize),
parser.WithLLaMACppLogicalBatchSize(2048),
parser.WithLLaMACppOffloadLayers(ngl),
)
Expand Down
26 changes: 15 additions & 11 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/docker/model-runner/pkg/inference"
)

const UnlimitedContextSize = -1

// Config is the configuration for the llama.cpp backend.
type Config struct {
// Args are the base arguments that are always included.
Expand Down Expand Up @@ -68,11 +70,14 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
}

if budget := GetReasoningBudget(config); budget != nil {
args = append(args, "--reasoning-budget", strconv.FormatInt(*budget, 10))
args = append(args, "--reasoning-budget", strconv.FormatInt(int64(*budget), 10))
}

// Add context size from model config or backend config
args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(bundle.RuntimeConfig(), config), 10))
contextSize := GetContextSize(bundle.RuntimeConfig(), config)
if contextSize != nil {
args = append(args, "--ctx-size", strconv.FormatInt(int64(*contextSize), 10))
}

// Add arguments for Multimodal projector or jinja (they are mutually exclusive)
if path := bundle.MMPROJPath(); path != "" {
Expand All @@ -84,20 +89,19 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
return args, nil
}

func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) uint64 {
func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
// Model config takes precedence
if modelCfg.ContextSize != nil {
return *modelCfg.ContextSize
if modelCfg.ContextSize != nil && (*modelCfg.ContextSize == UnlimitedContextSize || *modelCfg.ContextSize > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition *modelCfg.ContextSize == UnlimitedContextSize is also checked for the backendCfg. It might be cleaner to check for *modelCfg.ContextSize < 0 to handle all negative values, assuming only -1 is a special valid negative value. However, the current implementation is correct if -1 is the only special value intended.

return modelCfg.ContextSize
}
// else use backend config
if backendCfg != nil && backendCfg.ContextSize > 0 {
return uint64(backendCfg.ContextSize)
// Fallback to backend config
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
return backendCfg.ContextSize
}
// finally return default
return 4096 // llama.cpp default
return nil
}

func GetReasoningBudget(backendCfg *inference.BackendConfiguration) *int64 {
func GetReasoningBudget(backendCfg *inference.BackendConfiguration) *int32 {
if backendCfg != nil && backendCfg.LlamaCpp != nil && backendCfg.LlamaCpp.ReasoningBudget != nil {
return backendCfg.LlamaCpp.ReasoningBudget
}
Expand Down
Loading
Loading