diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index ebd40571d..52c8ec56d 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -44,6 +44,16 @@ type OpenAIInferenceRequest struct { Model string `json:"model"` } +// OpenAIErrorResponse is used to format an OpenAI API compatible error response +// (see https://platform.openai.com/docs/api-reference/responses-streaming/error) +type OpenAIErrorResponse struct { + Type string `json:"type"` // always "error" + Code *string `json:"code"` + Message string `json:"message"` + Param *string `json:"param"` + SequenceNumber int `json:"sequence_number"` +} + // BackendStatus represents information about a running backend type BackendStatus struct { // BackendName is the name of the backend diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 536bd80ff..48c3895de 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -408,7 +408,11 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer select { case <-l.slots[existing].done: l.log.Warnf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, model) - goto WaitForChange + if l.references[existing] == 0 { + l.evictRunner(backendName, model, mode) + } else { + goto WaitForChange + } default: l.references[existing] += 1 l.timestamps[existing] = time.Time{} diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go index 43f28e48e..15de5fec4 100644 --- a/pkg/inference/scheduling/runner.go +++ b/pkg/inference/scheduling/runner.go @@ -2,6 +2,7 @@ package scheduling import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -143,6 +144,18 @@ func run( w.WriteHeader(http.StatusInternalServerError) select { case <-r.done: + res := OpenAIErrorResponse{ + Type: "error", + Code: nil, + Message: r.err.Error(), + Param: nil, + SequenceNumber: 1, + } + errJson, err := json.Marshal(&res) + if err == nil { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(errJson) + } return case <-time.After(30 * time.Second): }