diff --git a/commands/completion/functions.go b/commands/completion/functions.go index 91a44bb9..7cd79ddf 100644 --- a/commands/completion/functions.go +++ b/commands/completion/functions.go @@ -20,7 +20,7 @@ func ModelNames(desktopClient func() *desktop.Client, limit int) cobra.Completio if limit > 0 && len(args) >= limit { return nil, cobra.ShellCompDirectiveNoFileComp } - models, err := desktopClient().List() + models, err := desktopClient().List(cmd.Context()) if err != nil { return nil, cobra.ShellCompDirectiveError } diff --git a/commands/compose.go b/commands/compose.go index 6045eb96..8c09c00f 100644 --- a/commands/compose.go +++ b/commands/compose.go @@ -1,6 +1,7 @@ package commands import ( + "context" "encoding/json" "errors" "fmt" @@ -57,7 +58,7 @@ func newUpCommand() *cobra.Command { return errors.New("unable to determine standalone runner endpoint") } - if err := downloadModelsOnlyIfNotFound(desktopClient, models); err != nil { + if err := downloadModelsOnlyIfNotFound(cmd.Context(), desktopClient, models); err != nil { return err } @@ -69,7 +70,7 @@ func newUpCommand() *cobra.Command { } for _, model := range models { - if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ + if err := desktopClient.ConfigureBackend(cmd.Context(), scheduling.ConfigureRequest{ Model: model, ContextSize: ctxSize, RawRuntimeFlags: rawRuntimeFlags, @@ -137,8 +138,8 @@ func newMetadataCommand(upCmd, downCmd *cobra.Command) *cobra.Command { return c } -func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string) error { - modelsDownloaded, err := desktopClient.List() +func downloadModelsOnlyIfNotFound(ctx context.Context, desktopClient *desktop.Client, models []string) error { + modelsDownloaded, err := desktopClient.List(ctx) if err != nil { _ = sendErrorf("Failed to get models list: %v", err) return err @@ -156,7 +157,7 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string } return false }) { - _, _, err = 
desktopClient.Pull(model, false, func(s string) { + _, _, err = desktopClient.Pull(ctx, model, false, func(s string) { _ = sendInfo(s) }) if err != nil { diff --git a/commands/configure.go b/commands/configure.go index 90b0ad4f..4f930547 100644 --- a/commands/configure.go +++ b/commands/configure.go @@ -39,7 +39,7 @@ func newConfigureCmd() *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - return desktopClient.ConfigureBackend(opts) + return desktopClient.ConfigureBackend(cmd.Context(), opts) }, ValidArgsFunction: completion.ModelNames(getDesktopClient, -1), } diff --git a/commands/df.go b/commands/df.go index c583c299..f5f72b1e 100644 --- a/commands/df.go +++ b/commands/df.go @@ -15,7 +15,7 @@ func newDFCmd() *cobra.Command { Use: "df", Short: "Show Docker Model Runner disk usage", RunE: func(cmd *cobra.Command, args []string) error { - df, err := desktopClient.DF() + df, err := desktopClient.DF(cmd.Context()) if err != nil { err = handleClientError(err, "Failed to list running models") return handleNotRunningError(err) diff --git a/commands/inspect.go b/commands/inspect.go index 8f1e81a6..9ebb80a3 100644 --- a/commands/inspect.go +++ b/commands/inspect.go @@ -1,6 +1,7 @@ package commands import ( + "context" "fmt" "github.com/docker/model-cli/commands/completion" @@ -32,7 +33,7 @@ func newInspectCmd() *cobra.Command { if openai && remote { return fmt.Errorf("--remote flag cannot be used with --openai flag") } - inspectedModel, err := inspectModel(args, openai, remote, desktopClient) + inspectedModel, err := inspectModel(cmd.Context(), args, openai, remote, desktopClient) if err != nil { return err } @@ -46,17 +47,17 @@ func newInspectCmd() *cobra.Command { return c } -func inspectModel(args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) { +func inspectModel(ctx context.Context, args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) { modelName := args[0] if 
openai { - model, err := desktopClient.InspectOpenAI(modelName) + model, err := desktopClient.InspectOpenAI(ctx, modelName) if err != nil { err = handleClientError(err, "Failed to get model "+modelName) return "", handleNotRunningError(err) } return formatter.ToStandardJSON(model) } - model, err := desktopClient.Inspect(modelName, remote) + model, err := desktopClient.Inspect(ctx, modelName, remote) if err != nil { err = handleClientError(err, "Failed to get model "+modelName) return "", handleNotRunningError(err) diff --git a/commands/install-runner.go b/commands/install-runner.go index e073008e..aedc9bf1 100644 --- a/commands/install-runner.go +++ b/commands/install-runner.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" - "github.com/docker/model-cli/pkg/types" "os" "time" + "github.com/docker/model-cli/pkg/types" + "github.com/docker/docker/api/types/container" "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" @@ -32,7 +33,7 @@ const ( // version can take several seconds. 
func waitForStandaloneRunnerAfterInstall(ctx context.Context) error { for tries := installWaitTries; tries > 0; tries-- { - if status := desktopClient.Status(); status.Error == nil && status.Running { + if status := desktopClient.Status(ctx); status.Error == nil && status.Running { return nil } select { diff --git a/commands/list.go b/commands/list.go index fbaaf58c..e3dc2c18 100644 --- a/commands/list.go +++ b/commands/list.go @@ -2,6 +2,7 @@ package commands import ( "bytes" + "context" "fmt" "os" "time" @@ -50,7 +51,7 @@ func newListCmd() *cobra.Command { if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), standaloneInstallPrinter); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } - models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey) + models, err := listModels(cmd.Context(), openai, backend, desktopClient, quiet, jsonFormat, apiKey) if err != nil { return err } @@ -67,16 +68,16 @@ func newListCmd() *cobra.Command { return c } -func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) { +func listModels(ctx context.Context, openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) { if openai || backend == "openai" { - models, err := desktopClient.ListOpenAI(backend, apiKey) + models, err := desktopClient.ListOpenAI(ctx, backend, apiKey) if err != nil { err = handleClientError(err, "Failed to list models") return "", handleNotRunningError(err) } return formatter.ToStandardJSON(models) } - models, err := desktopClient.List() + models, err := desktopClient.List(ctx) if err != nil { err = handleClientError(err, "Failed to list models") return "", handleNotRunningError(err) diff --git a/commands/package.go b/commands/package.go index 3e3cec02..167b78d8 100644 --- a/commands/package.go +++ b/commands/package.go @@ -221,7 +221,7 @@ func (t 
*modelRunnerTarget) Write(ctx context.Context, mdl types.ModelArtifact, return fmt.Errorf("get model ID: %w", err) } if t.tag.String() != "" { - if err := desktopClient.Tag(id, parseRepo(t.tag), t.tag.TagStr()); err != nil { + if err := desktopClient.Tag(ctx, id, parseRepo(t.tag), t.tag.TagStr()); err != nil { return fmt.Errorf("tag model: %w", err) } } diff --git a/commands/ps.go b/commands/ps.go index 293e0b8b..7395e613 100644 --- a/commands/ps.go +++ b/commands/ps.go @@ -17,7 +17,7 @@ func newPSCmd() *cobra.Command { Use: "ps", Short: "List running models", RunE: func(cmd *cobra.Command, args []string) error { - ps, err := desktopClient.PS() + ps, err := desktopClient.PS(cmd.Context()) if err != nil { err = handleClientError(err, "Failed to list running models") return handleNotRunningError(err) diff --git a/commands/pull.go b/commands/pull.go index a85f2024..fe9b12f4 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -47,7 +47,7 @@ func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, } else { progress = RawProgress } - response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress) + response, progressShown, err := desktopClient.Pull(cmd.Context(), model, ignoreRuntimeMemoryCheck, progress) // Add a newline before any output (success or error) if progress was shown. if progressShown { diff --git a/commands/push.go b/commands/push.go index ed94f1a6..b64c111b 100644 --- a/commands/push.go +++ b/commands/push.go @@ -34,7 +34,7 @@ func newPushCmd() *cobra.Command { } func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { - response, progressShown, err := desktopClient.Push(model, TUIProgress) + response, progressShown, err := desktopClient.Push(cmd.Context(), model, TUIProgress) // Add a newline before any output (success or error) if progress was shown. 
if progressShown { diff --git a/commands/rm.go index 95159fb2..686aaf2f 100644 --- a/commands/rm.go +++ b/commands/rm.go @@ -27,7 +27,7 @@ func newRemoveCmd() *cobra.Command { if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } - response, err := desktopClient.Remove(args, force) + response, err := desktopClient.Remove(cmd.Context(), args, force) if response != "" { cmd.Print(response) } diff --git a/commands/run.go index 24c73f6e..4e148747 100644 --- a/commands/run.go +++ b/commands/run.go @@ -2,10 +2,12 @@ package commands import ( "bufio" + "context" "errors" "fmt" "io" "os" + "os/signal" "strings" "github.com/docker/model-cli/commands/completion" @@ -136,7 +138,7 @@ func newRunCmd() *cobra.Command { // Do not validate the model in case of using OpenAI's backend, let OpenAI handle it if backend != "openai" { - _, err := desktopClient.Inspect(model, false) + _, err := desktopClient.Inspect(cmd.Context(), model, false) if err != nil { if !errors.Is(err, desktop.ErrNotFound) { return handleNotRunningError(handleClientError(err, "Failed to inspect model")) @@ -149,7 +151,7 @@ func newRunCmd() *cobra.Command { } if prompt != "" { - if err := desktopClient.Chat(backend, model, prompt, apiKey); err != nil { + if err := desktopClient.Chat(cmd.Context(), backend, model, prompt, apiKey); err != nil { return handleClientError(err, "Failed to generate a response") } cmd.Println() @@ -178,8 +180,12 @@ func newRunCmd() *cobra.Command { continue } - if err := desktopClient.Chat(backend, model, userInput, apiKey); err != nil { - cmd.PrintErr(handleClientError(err, "Failed to generate a response")) + if err := cancellableChat(cmd.Context(), desktopClient, backend, model, userInput, apiKey); err != nil { + if errors.Is(err, context.Canceled) { + cmd.Println("\nChat cancelled - Press Ctrl-C again to exit.") + } else { + 
cmd.PrintErr(handleClientError(err, "Failed to generate a response")) + } continue } @@ -208,3 +214,14 @@ func newRunCmd() *cobra.Command { return c } + +// cancellableChat sends a chat request that can be cancelled with Ctrl-C, both on Unix and Windows. +func cancellableChat(ctx context.Context, desktopClient *desktop.Client, backend, model, userInput, apiKey string) error { + // Create a NotifyContext that will handle os.Interrupt by cancelling the chat request. + // Calling stop at the end restores the previous signal handling, allowing Ctrl-C to exit the program. + // On Windows, the mapping from CTRL_C_EVENT to os.Interrupt can be seen at + // https://github.com/golang/go/blob/13bb48e6fbc35419a28747688426eb3684242fbc/src/runtime/os_windows.go#L1029 + chatContext, stop := signal.NotifyContext(ctx, os.Interrupt) + defer stop() + return desktopClient.Chat(chatContext, backend, model, userInput, apiKey) +} diff --git a/commands/status.go b/commands/status.go index a2bff45f..de1ceda6 100644 --- a/commands/status.go +++ b/commands/status.go @@ -3,9 +3,10 @@ package commands import ( "encoding/json" "fmt" - "github.com/docker/model-cli/pkg/types" "os" + "github.com/docker/model-cli/pkg/types" + "github.com/docker/cli/cli-plugins/hooks" "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" @@ -22,7 +23,7 @@ func newStatusCmd() *cobra.Command { if err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } - status := desktopClient.Status() + status := desktopClient.Status(cmd.Context()) if status.Error != nil { return handleClientError(status.Error, "Failed to get Docker Model Runner status") } diff --git a/commands/tag.go b/commands/tag.go index 9396e89f..305cdefe 100644 --- a/commands/tag.go +++ b/commands/tag.go @@ -42,7 +42,7 @@ func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target return fmt.Errorf("invalid tag: %w", err) } // Make tag request with model runner client - if 
err := desktopClient.Tag(source, parseRepo(tag), tag.TagStr()); err != nil { + if err := desktopClient.Tag(cmd.Context(), source, parseRepo(tag), tag.TagStr()); err != nil { return fmt.Errorf("failed to tag model: %w", err) } cmd.Printf("Model %q tagged successfully with %q\n", source, target) diff --git a/commands/unload.go b/commands/unload.go index 97f3faa8..1702e34b 100644 --- a/commands/unload.go +++ b/commands/unload.go @@ -17,7 +17,7 @@ func newUnloadCmd() *cobra.Command { Use: "unload " + cmdArgs, Short: "Unload running models", RunE: func(cmd *cobra.Command, models []string) error { - unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: models}) + unloadResp, err := desktopClient.Unload(cmd.Context(), desktop.UnloadRequest{All: all, Backend: backend, Models: models}) if err != nil { err = handleClientError(err, "Failed to unload models") return handleNotRunningError(err) diff --git a/desktop/desktop.go b/desktop/desktop.go index 3fff3bfd..f3f95005 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -65,9 +65,9 @@ func normalizeHuggingFaceModelName(model string) string { return model } -func (c *Client) Status() Status { +func (c *Client) Status(ctx context.Context) Status { // TODO: Query "/". 
- resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) + resp, err := c.doRequest(ctx, http.MethodGet, inference.ModelsPrefix, nil) if err != nil { err = c.handleQueryError(err, inference.ModelsPrefix) if errors.Is(err, ErrServiceUnavailable) { @@ -83,7 +83,7 @@ func (c *Client) Status() Status { defer resp.Body.Close() if resp.StatusCode == http.StatusOK { var status []byte - statusResp, err := c.doRequest(http.MethodGet, inference.InferencePrefix+"/status", nil) + statusResp, err := c.doRequest(ctx, http.MethodGet, inference.InferencePrefix+"/status", nil) if err != nil { status = []byte(fmt.Sprintf("error querying status: %v", err)) } else { @@ -106,7 +106,7 @@ func (c *Client) Status() Status { } } -func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) { +func (c *Client) Pull(ctx context.Context, model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) { model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck}) if err != nil { @@ -115,6 +115,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func createPath := inference.ModelsPrefix + "/create" resp, err := c.doRequest( + ctx, http.MethodPost, createPath, bytes.NewReader(jsonData), @@ -174,10 +175,11 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) } -func (c *Client) Push(model string, progress func(string)) (string, bool, error) { +func (c *Client) Push(ctx context.Context, model string, progress func(string)) (string, bool, error) { model = normalizeHuggingFaceModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( + ctx, http.MethodPost, pushPath, nil, // Assuming no body is needed for the 
push request @@ -225,9 +227,9 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model) } -func (c *Client) List() ([]dmrm.Model, error) { +func (c *Client) List(ctx context.Context) ([]dmrm.Model, error) { modelsRoute := inference.ModelsPrefix - body, err := c.listRaw(modelsRoute, "") + body, err := c.listRaw(ctx, modelsRoute, "") if err != nil { return []dmrm.Model{}, err } @@ -240,14 +242,14 @@ func (c *Client) List() ([]dmrm.Model, error) { return modelsJson, nil } -func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error) { +func (c *Client) ListOpenAI(ctx context.Context, backend, apiKey string) (dmrm.OpenAIModelList, error) { if backend == "" { backend = DefaultBackend } modelsRoute := fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend) // Use doRequestWithAuth to support API key authentication - resp, err := c.doRequestWithAuth(http.MethodGet, modelsRoute, nil, "openai", apiKey) + resp, err := c.doRequestWithAuth(ctx, http.MethodGet, modelsRoute, nil, "openai", apiKey) if err != nil { return dmrm.OpenAIModelList{}, c.handleQueryError(err, modelsRoute) } @@ -269,19 +271,19 @@ func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error return modelsJson, nil } -func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { +func (c *Client) Inspect(ctx context.Context, model string, remote bool) (dmrm.Model, error) { model = normalizeHuggingFaceModelName(model) if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. 
- modelId, err := c.fullModelID(model) + modelId, err := c.fullModelID(ctx, model) if err != nil { return dmrm.Model{}, fmt.Errorf("invalid model name: %s", model) } model = modelId } } - rawResponse, err := c.listRawWithQuery(fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model, remote) + rawResponse, err := c.listRawWithQuery(ctx, fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model, remote) if err != nil { return dmrm.Model{}, err } @@ -293,17 +295,17 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { return modelInspect, nil } -func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { +func (c *Client) InspectOpenAI(ctx context.Context, model string) (dmrm.OpenAIModel, error) { model = normalizeHuggingFaceModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. var err error - if model, err = c.fullModelID(model); err != nil { + if model, err = c.fullModelID(ctx, model); err != nil { return dmrm.OpenAIModel{}, fmt.Errorf("invalid model name: %s", model) } } - rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) + rawResponse, err := c.listRaw(ctx, fmt.Sprintf("%s/%s", modelsRoute, model), model) if err != nil { return dmrm.OpenAIModel{}, err } @@ -314,16 +316,16 @@ func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { return modelInspect, nil } -func (c *Client) listRaw(route string, model string) ([]byte, error) { - return c.listRawWithQuery(route, model, false) +func (c *Client) listRaw(ctx context.Context, route string, model string) ([]byte, error) { + return c.listRawWithQuery(ctx, route, model, false) } -func (c *Client) listRawWithQuery(route string, model string, remote bool) ([]byte, error) { +func (c *Client) listRawWithQuery(ctx context.Context, route string, model string, remote bool) ([]byte, error) { if remote 
{ route += "?remote=true" } - resp, err := c.doRequest(http.MethodGet, route, nil) + resp, err := c.doRequest(ctx, http.MethodGet, route, nil) if err != nil { return nil, c.handleQueryError(err, route) } @@ -343,8 +345,8 @@ func (c *Client) listRawWithQuery(route string, model string, remote bool) ([]by return body, nil } -func (c *Client) fullModelID(id string) (string, error) { - bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") +func (c *Client) fullModelID(ctx context.Context, id string) (string, error) { + bodyResponse, err := c.listRaw(ctx, inference.ModelsPrefix, "") if err != nil { return "", err } @@ -371,11 +373,11 @@ const ( chatPrinterReasoning ) -func (c *Client) Chat(backend, model, prompt, apiKey string) error { +func (c *Client) Chat(ctx context.Context, backend, model, prompt, apiKey string) error { model = normalizeHuggingFaceModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. - if expanded, err := c.fullModelID(model); err == nil { + if expanded, err := c.fullModelID(ctx, model); err == nil { model = expanded } } @@ -404,6 +406,7 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error { } resp, err := c.doRequestWithAuth( + ctx, http.MethodPost, completionsPath, bytes.NewReader(jsonData), @@ -474,13 +477,13 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error { return nil } -func (c *Client) Remove(models []string, force bool) (string, error) { +func (c *Client) Remove(ctx context.Context, models []string, force bool) (string, error) { modelRemoved := "" for _, model := range models { model = normalizeHuggingFaceModelName(model) // Check if not a model ID passed as parameter. 
if !strings.Contains(model, "/") { - if expanded, err := c.fullModelID(model); err == nil { + if expanded, err := c.fullModelID(ctx, model); err == nil { model = expanded } } @@ -492,7 +495,7 @@ func (c *Client) Remove(models []string, force bool) (string, error) { strconv.FormatBool(force), ) - resp, err := c.doRequest(http.MethodDelete, removePath, nil) + resp, err := c.doRequest(ctx, http.MethodDelete, removePath, nil) if err != nil { return modelRemoved, c.handleQueryError(err, removePath) } @@ -542,9 +545,9 @@ type BackendStatus struct { LastUsed time.Time `json:"last_used,omitempty"` } -func (c *Client) PS() ([]BackendStatus, error) { +func (c *Client) PS(ctx context.Context) ([]BackendStatus, error) { psPath := inference.InferencePrefix + "/ps" - resp, err := c.doRequest(http.MethodGet, psPath, nil) + resp, err := c.doRequest(ctx, http.MethodGet, psPath, nil) if err != nil { return []BackendStatus{}, c.handleQueryError(err, psPath) } @@ -569,9 +572,9 @@ type DiskUsage struct { DefaultBackendDiskUsage int64 `json:"default_backend_disk_usage"` } -func (c *Client) DF() (DiskUsage, error) { +func (c *Client) DF(ctx context.Context) (DiskUsage, error) { dfPath := inference.InferencePrefix + "/df" - resp, err := c.doRequest(http.MethodGet, dfPath, nil) + resp, err := c.doRequest(ctx, http.MethodGet, dfPath, nil) if err != nil { return DiskUsage{}, c.handleQueryError(err, dfPath) } @@ -602,14 +605,14 @@ type UnloadResponse struct { UnloadedRunners int `json:"unloaded_runners"` } -func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) { +func (c *Client) Unload(ctx context.Context, req UnloadRequest) (UnloadResponse, error) { unloadPath := inference.InferencePrefix + "/unload" jsonData, err := json.Marshal(req) if err != nil { return UnloadResponse{}, fmt.Errorf("error marshaling request: %w", err) } - resp, err := c.doRequest(http.MethodPost, unloadPath, bytes.NewReader(jsonData)) + resp, err := c.doRequest(ctx, http.MethodPost, unloadPath, 
bytes.NewReader(jsonData)) if err != nil { return UnloadResponse{}, c.handleQueryError(err, unloadPath) } @@ -633,14 +636,14 @@ func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) { return unloadResp, nil } -func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error { +func (c *Client) ConfigureBackend(ctx context.Context, request scheduling.ConfigureRequest) error { configureBackendPath := inference.InferencePrefix + "/_configure" jsonData, err := json.Marshal(request) if err != nil { return fmt.Errorf("error marshaling request: %w", err) } - resp, err := c.doRequest(http.MethodPost, configureBackendPath, bytes.NewReader(jsonData)) + resp, err := c.doRequest(ctx, http.MethodPost, configureBackendPath, bytes.NewReader(jsonData)) if err != nil { return c.handleQueryError(err, configureBackendPath) } @@ -658,13 +661,13 @@ func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error { } // doRequest is a helper function that performs HTTP requests and handles 503 responses -func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { - return c.doRequestWithAuth(method, path, body, "", "") +func (c *Client) doRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + return c.doRequestWithAuth(ctx, method, path, body, "", "") } // doRequestWithAuth is a helper function that performs HTTP requests with optional authentication -func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) { - req, err := http.NewRequest(method, c.modelRunner.URL(path), body) +func (c *Client) doRequestWithAuth(ctx context.Context, method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, method, c.modelRunner.URL(path), body) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -699,12 +702,12 @@ func (c *Client) 
handleQueryError(err error, path string) error { return fmt.Errorf("error querying %s: %w", path, err) } -func (c *Client) Tag(source, targetRepo, targetTag string) error { +func (c *Client) Tag(ctx context.Context, source, targetRepo, targetTag string) error { source = normalizeHuggingFaceModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { // Do an extra API call to check if the model parameter might be a model ID - if expanded, err := c.fullModelID(source); err == nil { + if expanded, err := c.fullModelID(ctx, source); err == nil { source = expanded } } @@ -717,7 +720,7 @@ func (c *Client) Tag(source, targetRepo, targetTag string) error { targetTag, ) - resp, err := c.doRequest(http.MethodPost, tagPath, nil) + resp, err := c.doRequest(ctx, http.MethodPost, tagPath, nil) if err != nil { return c.handleQueryError(err, tagPath) } diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go index db0f68ef..b64cddee 100644 --- a/desktop/desktop_test.go +++ b/desktop/desktop_test.go @@ -36,7 +36,7 @@ func TestPullHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), }, nil) - _, _, err := client.Pull(modelName, false, func(s string) {}) + _, _, err := client.Pull(t.Context(), modelName, false, func(s string) {}) assert.NoError(t, err) } @@ -63,7 +63,7 @@ func TestChatHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), }, nil) - err := client.Chat("", modelName, prompt, "") + err := client.Chat(t.Context(), "", modelName, prompt, "") assert.NoError(t, err) } @@ -97,7 +97,7 @@ func TestInspectHuggingFaceModel(t *testing.T) { }`)), }, nil) - model, err := client.Inspect(modelName, false) + model, err := client.Inspect(t.Context(), modelName, false) assert.NoError(t, err) assert.Equal(t, expectedLowercase, 
model.Tags[0]) } @@ -122,7 +122,7 @@ func TestNonHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), }, nil) - _, _, err := client.Pull(modelName, false, func(s string) {}) + _, _, err := client.Pull(t.Context(), modelName, false, func(s string) {}) assert.NoError(t, err) } @@ -145,7 +145,7 @@ func TestPushHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), }, nil) - _, _, err := client.Push(modelName, func(s string) {}) + _, _, err := client.Push(t.Context(), modelName, func(s string) {}) assert.NoError(t, err) } @@ -168,7 +168,7 @@ func TestRemoveHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), }, nil) - _, err := client.Remove([]string{modelName}, false) + _, err := client.Remove(t.Context(), []string{modelName}, false) assert.NoError(t, err) } @@ -193,7 +193,7 @@ func TestTagHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), }, nil) - assert.NoError(t, client.Tag(sourceModel, targetRepo, targetTag)) + assert.NoError(t, client.Tag(t.Context(), sourceModel, targetRepo, targetTag)) } func TestInspectOpenAIHuggingFaceModel(t *testing.T) { @@ -220,7 +220,7 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) { }`)), }, nil) - model, err := client.InspectOpenAI(modelName) + model, err := client.InspectOpenAI(t.Context(), modelName) assert.NoError(t, err) assert.Equal(t, expectedLowercase, model.ID) }