diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index 48676626c..019e52983 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -30,8 +30,8 @@ func (m BackendMode) String() string {
 }
 
 type BackendConfiguration struct {
-	ContextSize int64
-	RawFlags    []string
+	ContextSize int64    `json:"context_size,omitempty"`
+	RawFlags    []string `json:"flags,omitempty"`
 }
 
 // Backend is the interface implemented by inference engine backends. Backend
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index 9c6922912..769eef27d 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -13,6 +13,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/metrics"
 )
 
 const (
@@ -92,6 +93,8 @@ type loader struct {
 	timestamps []time.Time
 	// runnerConfigs maps model names to runner configurations
 	runnerConfigs map[runnerKey]inference.BackendConfiguration
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 }
 
 // newLoader creates a new loader.
@@ -99,6 +102,7 @@ func newLoader(
 	log logging.Logger,
 	backends map[string]inference.Backend,
 	modelManager *models.Manager,
+	openAIRecorder *metrics.OpenAIRecorder,
 ) *loader {
 	// Compute the number of runner slots to allocate. Because of RAM and VRAM
 	// limitations, it's unlikely that we'll ever be able to fully populate
@@ -153,6 +157,7 @@ func newLoader(
 		allocations:   make([]uint64, nSlots),
 		timestamps:    make([]time.Time, nSlots),
 		runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
+		openAIRecorder: openAIRecorder,
 	}
 	l.guard <- struct{}{}
 	return l
@@ -462,7 +467,7 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer
 	}
 	// Create the runner.
 	l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
-	runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
+	runner, err := run(l.log, backend, model, mode, slot, runnerConfig, l.openAIRecorder)
 	if err != nil {
 		l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
 			backendName, model, mode, err,
diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go
index 15de5fec4..374d15b61 100644
--- a/pkg/inference/scheduling/runner.go
+++ b/pkg/inference/scheduling/runner.go
@@ -15,6 +15,7 @@ import (
 
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/metrics"
 )
 
 const (
@@ -63,6 +64,8 @@ type runner struct {
 	proxy *httputil.ReverseProxy
 	// proxyLog is the stream used for logging by proxy.
 	proxyLog io.Closer
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 	// err is the error returned by the runner's backend, only valid after done is closed.
 	err error
 }
@@ -75,6 +78,7 @@ func run(
 	mode inference.BackendMode,
 	slot int,
 	runnerConfig *inference.BackendConfiguration,
+	openAIRecorder *metrics.OpenAIRecorder,
 ) (*runner, error) {
 	// Create a dialer / transport that target backend on the specified slot.
 	socket, err := RunnerSocketPath(slot)
@@ -124,16 +128,17 @@ func run(
 
 	runDone := make(chan struct{})
 	r := &runner{
-		log:       log,
-		backend:   backend,
-		model:     model,
-		mode:      mode,
-		cancel:    runCancel,
-		done:      runDone,
-		transport: transport,
-		client:    client,
-		proxy:     proxy,
-		proxyLog:  proxyLog,
+		log:            log,
+		backend:        backend,
+		model:          model,
+		mode:           mode,
+		cancel:         runCancel,
+		done:           runDone,
+		transport:      transport,
+		client:         client,
+		proxy:          proxy,
+		proxyLog:       proxyLog,
+		openAIRecorder: openAIRecorder,
 	}
 
 	proxy.ErrorHandler = func(w http.ResponseWriter, req *http.Request, err error) {
@@ -164,6 +169,8 @@ func run(
 		}
 	}
 
+	r.openAIRecorder.SetConfigForModel(model, runnerConfig)
+
 	// Start the backend run loop.
 	go func() {
 		if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
@@ -236,6 +243,8 @@ func (r *runner) terminate() {
 	if err := r.proxyLog.Close(); err != nil {
 		r.log.Warnf("Unable to close reverse proxy log writer: %v", err)
 	}
+
+	r.openAIRecorder.RemoveModel(r.model)
 }
 
 // ServeHTTP implements net/http.Handler.ServeHTTP. It forwards requests to the
diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go
index 9b9e3fc46..5a9d0e5b1 100644
--- a/pkg/inference/scheduling/scheduler.go
+++ b/pkg/inference/scheduling/scheduler.go
@@ -40,6 +40,8 @@ type Scheduler struct {
 	router *http.ServeMux
 	// tracker is the metrics tracker.
 	tracker *metrics.Tracker
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 	// lock is used to synchronize access to the scheduler's router.
 	lock sync.Mutex
 }
@@ -54,6 +56,8 @@ func NewScheduler(
 	allowedOrigins []string,
 	tracker *metrics.Tracker,
 ) *Scheduler {
+	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"))
+
 	// Create the scheduler.
 	s := &Scheduler{
 		log:            log,
@@ -61,9 +65,10 @@ func NewScheduler(
 		defaultBackend: defaultBackend,
 		modelManager:   modelManager,
 		installer:      newInstaller(log, backends, httpClient),
-		loader:         newLoader(log, backends, modelManager),
+		loader:         newLoader(log, backends, modelManager, openAIRecorder),
 		router:         http.NewServeMux(),
 		tracker:        tracker,
+		openAIRecorder: openAIRecorder,
 	}
 
 	// Register routes.
@@ -115,6 +120,7 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
 	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
 	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
+	m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
 	return m
 }
 
@@ -232,6 +238,14 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
 		s.tracker.TrackModel(model)
 	}
 
+	// Record the request in the OpenAI recorder.
+	recordID := s.openAIRecorder.RecordRequest(request.Model, r, body)
+	w = s.openAIRecorder.NewResponseRecorder(w)
+	defer func() {
+		// Record the response in the OpenAI recorder.
+		s.openAIRecorder.RecordResponse(recordID, request.Model, w)
+	}()
+
 	// Request a runner to execute the request and defer its release.
 	runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode)
 	if err != nil {
diff --git a/pkg/metrics/openai_recorder.go b/pkg/metrics/openai_recorder.go
new file mode 100644
index 000000000..a5e691229
--- /dev/null
+++ b/pkg/metrics/openai_recorder.go
@@ -0,0 +1,271 @@
+package metrics
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/logging"
+)
+
+// responseRecorder wraps an http.ResponseWriter, capturing the response body
+// and status code as they are written.
+type responseRecorder struct {
+	http.ResponseWriter
+	body       *bytes.Buffer
+	statusCode int
+}
+
+func (rr *responseRecorder) Write(b []byte) (int, error) {
+	rr.body.Write(b)
+	return rr.ResponseWriter.Write(b)
+}
+
+func (rr *responseRecorder) WriteHeader(statusCode int) {
+	rr.statusCode = statusCode
+	rr.ResponseWriter.WriteHeader(statusCode)
+}
+
+// Unwrap exposes the underlying http.ResponseWriter so that
+// http.ResponseController can still reach optional interfaces such as
+// http.Flusher, which streaming responses rely on.
+func (rr *responseRecorder) Unwrap() http.ResponseWriter {
+	return rr.ResponseWriter
+}
+
+// RequestResponsePair is a single recorded OpenAI API exchange.
+type RequestResponsePair struct {
+	ID         string    `json:"id"`
+	Model      string    `json:"model"`
+	Method     string    `json:"method"`
+	URL        string    `json:"url"`
+	Request    string    `json:"request"`
+	Response   string    `json:"response"`
+	Timestamp  time.Time `json:"timestamp"`
+	StatusCode int       `json:"status_code"`
+	UserAgent  string    `json:"user_agent,omitempty"`
+}
+
+// ModelData groups the recorded exchanges for a model together with the
+// backend configuration its runner was started with.
+type ModelData struct {
+	Config  inference.BackendConfiguration `json:"config"`
+	Records []*RequestResponsePair         `json:"records"`
+}
+
+// OpenAIRecorder records OpenAI API requests and responses per model.
+type OpenAIRecorder struct {
+	log     logging.Logger
+	records map[string]*ModelData
+	m       sync.RWMutex
+}
+
+// NewOpenAIRecorder creates a new OpenAIRecorder.
+func NewOpenAIRecorder(log logging.Logger) *OpenAIRecorder {
+	return &OpenAIRecorder{
+		log:     log,
+		records: make(map[string]*ModelData),
+	}
+}
+
+// SetConfigForModel associates a runner's backend configuration with a model.
+func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.BackendConfiguration) {
+	if config == nil {
+		r.log.Warnf("SetConfigForModel called with nil config for model %s", model)
+		return
+	}
+
+	r.m.Lock()
+	defer r.m.Unlock()
+
+	if r.records[model] == nil {
+		r.records[model] = &ModelData{
+			Records: make([]*RequestResponsePair, 0, 10),
+			Config:  inference.BackendConfiguration{},
+		}
+	}
+
+	r.records[model].Config = *config
+}
+
+// RecordRequest stores an incoming request and returns a record ID that the
+// caller passes back to RecordResponse once the response is complete.
+func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []byte) string {
+	r.m.Lock()
+	defer r.m.Unlock()
+
+	recordID := fmt.Sprintf("%s_%d", model, time.Now().UnixNano())
+
+	record := &RequestResponsePair{
+		ID:        recordID,
+		Model:     model,
+		Method:    req.Method,
+		URL:       req.URL.Path,
+		Request:   string(body),
+		Timestamp: time.Now(),
+		UserAgent: req.UserAgent(),
+	}
+
+	if r.records[model] == nil {
+		r.records[model] = &ModelData{
+			Records: make([]*RequestResponsePair, 0, 10),
+			Config:  inference.BackendConfiguration{},
+		}
+	}
+
+	r.records[model].Records = append(r.records[model].Records, record)
+
+	// Cap the history at the 10 most recent records per model.
+	if len(r.records[model].Records) > 10 {
+		r.records[model].Records = r.records[model].Records[1:]
+	}
+
+	return recordID
+}
+
+// NewResponseRecorder wraps w so that the response body and status code can
+// be captured for RecordResponse.
+func (r *OpenAIRecorder) NewResponseRecorder(w http.ResponseWriter) http.ResponseWriter {
+	return &responseRecorder{
+		ResponseWriter: w,
+		body:           &bytes.Buffer{},
+		statusCode:     http.StatusOK,
+	}
+}
+
+// RecordResponse attaches the captured response to the record created by
+// RecordRequest.
+func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter) {
+	rr, ok := rw.(*responseRecorder)
+	if !ok {
+		r.log.Errorf("RecordResponse expects a ResponseWriter created by NewResponseRecorder")
+		return
+	}
+
+	responseBody := rr.body.String()
+	statusCode := rr.statusCode
+
+	// Streamed responses arrive as server-sent events; collapse them into a
+	// single chat completion object before storing.
+	var response string
+	if strings.Contains(responseBody, "data: ") {
+		response = r.convertStreamingResponse(responseBody)
+	} else {
+		response = responseBody
+	}
+
+	r.m.Lock()
+	defer r.m.Unlock()
+
+	if modelData, exists := r.records[model]; exists {
+		for _, record := range modelData.Records {
+			if record.ID == id {
+				record.Response = response
+				record.StatusCode = statusCode
+				return
+			}
+		}
+		r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, model, statusCode, response)
+	} else {
+		r.log.Errorf("Model %s not found in records - %d\n%s", model, statusCode, response)
+	}
+}
+
+// convertStreamingResponse merges an SSE stream of chat completion chunks
+// into a single non-streaming chat completion JSON object: the concatenated
+// delta contents become the assistant message and the last chunk provides
+// the remaining fields. On parse failure the raw stream is returned.
+func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
+	lines := strings.Split(streamingBody, "\n")
+	var contentBuilder strings.Builder
+	var lastChunk map[string]interface{}
+
+	for _, line := range lines {
+		if strings.HasPrefix(line, "data: ") {
+			data := strings.TrimPrefix(line, "data: ")
+			if data == "[DONE]" {
+				break
+			}
+
+			var chunk map[string]interface{}
+			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+				continue
+			}
+
+			lastChunk = chunk
+
+			if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
+				if choice, ok := choices[0].(map[string]interface{}); ok {
+					if delta, ok := choice["delta"].(map[string]interface{}); ok {
+						if content, ok := delta["content"].(string); ok {
+							contentBuilder.WriteString(content)
+						}
+					}
+				}
+			}
+		}
+	}
+
+	if lastChunk == nil {
+		return streamingBody
+	}
+
+	finalResponse := make(map[string]interface{})
+	for key, value := range lastChunk {
+		finalResponse[key] = value
+	}
+
+	if choices, ok := finalResponse["choices"].([]interface{}); ok && len(choices) > 0 {
+		if choice, ok := choices[0].(map[string]interface{}); ok {
+			choice["message"] = map[string]interface{}{
+				"role":    "assistant",
+				"content": contentBuilder.String(),
+			}
+			delete(choice, "delta")
+
+			if _, ok := choice["finish_reason"]; !ok {
+				choice["finish_reason"] = "stop"
+			}
+		}
+	}
+
+	finalResponse["object"] = "chat.completion"
+
+	jsonResult, err := json.Marshal(finalResponse)
+	if err != nil {
+		return streamingBody
+	}
+
+	return string(jsonResult)
+}
+
+// GetRecordsByModelHandler returns an HTTP handler that serves the recorded
+// exchanges for the model named in the "model" query parameter.
+func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
+	return func(w http.ResponseWriter, req *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+
+		model := req.URL.Query().Get("model")
+		if model == "" {
+			http.Error(w, "A 'model' query parameter is required", http.StatusBadRequest)
+			return
+		}
+
+		// Retrieve records for the specified model.
+		records := r.GetRecordsByModel(model)
+		if records == nil {
+			// No records found for the specified model.
+			http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
+			return
+		}
+
+		// Read the config under the lock: the model's entry may be removed
+		// concurrently by RemoveModel.
+		r.m.RLock()
+		var config inference.BackendConfiguration
+		if modelData, exists := r.records[model]; exists {
+			config = modelData.Config
+		}
+		r.m.RUnlock()
+
+		if err := json.NewEncoder(w).Encode(map[string]interface{}{
+			"model":   model,
+			"records": records,
+			"count":   len(records),
+			"config":  config,
+		}); err != nil {
+			http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
+				http.StatusInternalServerError)
+			return
+		}
+	}
+}
+
+// GetRecordsByModel returns a snapshot of the recorded exchanges for model,
+// or nil if none exist.
+func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair {
+	r.m.RLock()
+	defer r.m.RUnlock()
+
+	if modelData, exists := r.records[model]; exists {
+		result := make([]*RequestResponsePair, len(modelData.Records))
+		copy(result, modelData.Records)
+		return result
+	}
+
+	return nil
+}
+
+// RemoveModel drops all recorded exchanges for a model, typically when its
+// runner is terminated.
+func (r *OpenAIRecorder) RemoveModel(model string) {
+	r.m.Lock()
+	defer r.m.Unlock()
+
+	if _, exists := r.records[model]; exists {
+		delete(r.records, model)
+		r.log.Infof("Removed records for model: %s", model)
+	} else {
+		r.log.Warnf("No records found for model: %s", model)
+	}
+}
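
For reference, a minimal client-side sketch of the new records endpoint (not part of the diff). The host/port and the assumption that inference.InferencePrefix resolves to /engines are hypothetical; adjust both for your deployment. The model name is likewise a placeholder.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	// Assumed base URL and prefix; model names may contain "/", so escape them.
	endpoint := "http://localhost:12434/engines/requests?model=" + url.QueryEscape("ai/smollm2")

	resp, err := http.Get(endpoint)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	// The handler replies 404 when no records exist for the model, otherwise
	// with a JSON object of the form:
	//   {"model": ..., "records": [...], "count": N, "config": {...}}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	fmt.Println(resp.Status)
	fmt.Println(string(body))
}
```

Note that RecordRequest evicts older entries, so the endpoint returns at most the 10 most recent exchanges per model.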
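The following standalone sketch illustrates the merging idea behind convertStreamingResponse with sample data: the last SSE chunk serves as a template, the concatenated delta contents become the assistant message, and the result is a single non-streaming chat completion. It re-implements the idea in miniature rather than calling the recorder itself.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	// A tiny, hand-written SSE stream of two chat completion chunks.
	stream := "data: {\"object\":\"chat.completion.chunk\",\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
		"data: {\"object\":\"chat.completion.chunk\",\"choices\":[{\"delta\":{\"content\":\"lo\"},\"finish_reason\":\"stop\"}]}\n" +
		"data: [DONE]\n"

	var content strings.Builder
	var last map[string]interface{}
	for _, line := range strings.Split(stream, "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			break
		}
		var chunk map[string]interface{}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue
		}
		last = chunk
		// Accumulate the streamed delta contents.
		if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
			if choice, ok := choices[0].(map[string]interface{}); ok {
				if delta, ok := choice["delta"].(map[string]interface{}); ok {
					if s, ok := delta["content"].(string); ok {
						content.WriteString(s)
					}
				}
			}
		}
	}
	if last == nil {
		return
	}

	// Rewrite the last chunk into a complete, non-streaming response.
	if choices, ok := last["choices"].([]interface{}); ok && len(choices) > 0 {
		if choice, ok := choices[0].(map[string]interface{}); ok {
			choice["message"] = map[string]interface{}{"role": "assistant", "content": content.String()}
			delete(choice, "delta")
		}
	}
	last["object"] = "chat.completion"

	out, _ := json.Marshal(last)
	fmt.Println(string(out))
	// {"choices":[{"finish_reason":"stop","message":{"content":"Hello","role":"assistant"}}],"object":"chat.completion"}
}
```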