pkg/inference/backend.go (2 additions, 2 deletions)

```diff
@@ -30,8 +30,8 @@ func (m BackendMode) String() string {
 }
 
 type BackendConfiguration struct {
-	ContextSize int64
-	RawFlags    []string
+	ContextSize int64    `json:"context_size,omitempty"`
+	RawFlags    []string `json:"flags,omitempty"`
 }
 
 // Backend is the interface implemented by inference engine backends. Backend
```
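With the new struct tags, BackendConfiguration serializes to snake_case JSON and zero values are omitted. A minimal standalone sketch of the resulting wire format, using only the struct shown above (the flag value is purely illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the newly tagged struct from pkg/inference/backend.go.
type BackendConfiguration struct {
	ContextSize int64    `json:"context_size,omitempty"`
	RawFlags    []string `json:"flags,omitempty"`
}

func main() {
	full, _ := json.Marshal(BackendConfiguration{
		ContextSize: 4096,
		RawFlags:    []string{"--no-mmap"},
	})
	fmt.Println(string(full)) // {"context_size":4096,"flags":["--no-mmap"]}

	empty, _ := json.Marshal(BackendConfiguration{})
	fmt.Println(string(empty)) // {} (both fields dropped via omitempty)
}
```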
pkg/inference/scheduling/loader.go (6 additions, 1 deletion)

```diff
@@ -13,6 +13,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/metrics"
 )
 
 const (
@@ -92,13 +93,16 @@ type loader struct {
 	timestamps []time.Time
 	// runnerConfigs maps model names to runner configurations
 	runnerConfigs map[runnerKey]inference.BackendConfiguration
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 }
 
 // newLoader creates a new loader.
 func newLoader(
 	log logging.Logger,
 	backends map[string]inference.Backend,
 	modelManager *models.Manager,
+	openAIRecorder *metrics.OpenAIRecorder,
 ) *loader {
 	// Compute the number of runner slots to allocate. Because of RAM and VRAM
 	// limitations, it's unlikely that we'll ever be able to fully populate
@@ -153,6 +157,7 @@ func newLoader(
 		allocations:   make([]uint64, nSlots),
 		timestamps:    make([]time.Time, nSlots),
 		runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
+		openAIRecorder: openAIRecorder,
 	}
 	l.guard <- struct{}{}
 	return l
@@ -462,7 +467,7 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer
 	}
 	// Create the runner.
 	l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
-	runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
+	runner, err := run(l.log, backend, model, mode, slot, runnerConfig, l.openAIRecorder)
 	if err != nil {
 		l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
 			backendName, model, mode, err,
```
pkg/inference/scheduling/runner.go (19 additions, 10 deletions)

```diff
@@ -15,6 +15,7 @@ import (
 
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/metrics"
 )
 
 const (
@@ -63,6 +64,8 @@ type runner struct {
 	proxy *httputil.ReverseProxy
 	// proxyLog is the stream used for logging by proxy.
 	proxyLog io.Closer
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 	// err is the error returned by the runner's backend, only valid after done is closed.
 	err error
 }
@@ -75,6 +78,7 @@ func run(
 	mode inference.BackendMode,
 	slot int,
 	runnerConfig *inference.BackendConfiguration,
+	openAIRecorder *metrics.OpenAIRecorder,
 ) (*runner, error) {
 	// Create a dialer / transport that target backend on the specified slot.
 	socket, err := RunnerSocketPath(slot)
@@ -124,16 +128,17 @@
 	runDone := make(chan struct{})
 
 	r := &runner{
-		log:       log,
-		backend:   backend,
-		model:     model,
-		mode:      mode,
-		cancel:    runCancel,
-		done:      runDone,
-		transport: transport,
-		client:    client,
-		proxy:     proxy,
-		proxyLog:  proxyLog,
+		log:            log,
+		backend:        backend,
+		model:          model,
+		mode:           mode,
+		cancel:         runCancel,
+		done:           runDone,
+		transport:      transport,
+		client:         client,
+		proxy:          proxy,
+		proxyLog:       proxyLog,
+		openAIRecorder: openAIRecorder,
 	}
 
 	proxy.ErrorHandler = func(w http.ResponseWriter, req *http.Request, err error) {
@@ -164,6 +169,8 @@
 		}
 	}
 
+	r.openAIRecorder.SetConfigForModel(model, runnerConfig)
+
 	// Start the backend run loop.
 	go func() {
 		if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
@@ -236,6 +243,8 @@ func (r *runner) terminate() {
 	if err := r.proxyLog.Close(); err != nil {
 		r.log.Warnf("Unable to close reverse proxy log writer: %v", err)
 	}
+
+	r.openAIRecorder.RemoveModel(r.model)
 }
 
 // ServeHTTP implements net/http.Handler.ServeHTTP. It forwards requests to the
```
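The runner now notifies the recorder at both ends of its lifecycle: SetConfigForModel when a backend runner starts, and RemoveModel when it terminates. The recorder itself is not part of this diff; below is a minimal sketch of what those two hooks plausibly do, assuming a mutex-guarded per-model map (method names match the diff, internals are assumptions):

```go
package metrics

import (
	"sync"

	"github.com/docker/model-runner/pkg/inference"
)

// Illustrative only: the real OpenAIRecorder in pkg/metrics is not shown in
// this diff. This sketch keeps per-model state under a mutex so concurrent
// runners can register and deregister safely.
type OpenAIRecorder struct {
	mu      sync.Mutex
	configs map[string]*inference.BackendConfiguration
}

// SetConfigForModel remembers the configuration a model was loaded with.
func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.BackendConfiguration) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.configs == nil {
		r.configs = make(map[string]*inference.BackendConfiguration)
	}
	r.configs[model] = config
}

// RemoveModel drops state for a model whose runner has terminated.
func (r *OpenAIRecorder) RemoveModel(model string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.configs, model)
}
```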
pkg/inference/scheduling/scheduler.go (15 additions, 1 deletion)

```diff
@@ -40,6 +40,8 @@ type Scheduler struct {
 	router *http.ServeMux
 	// tracker is the metrics tracker.
 	tracker *metrics.Tracker
+	// openAIRecorder is used to record OpenAI API inference requests and responses.
+	openAIRecorder *metrics.OpenAIRecorder
 	// lock is used to synchronize access to the scheduler's router.
 	lock sync.Mutex
 }
@@ -54,16 +56,19 @@ func NewScheduler(
 	allowedOrigins []string,
 	tracker *metrics.Tracker,
 ) *Scheduler {
+	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"))
+
 	// Create the scheduler.
 	s := &Scheduler{
 		log:            log,
 		backends:       backends,
 		defaultBackend: defaultBackend,
 		modelManager:   modelManager,
 		installer:      newInstaller(log, backends, httpClient),
-		loader:         newLoader(log, backends, modelManager),
+		loader:         newLoader(log, backends, modelManager, openAIRecorder),
 		router:         http.NewServeMux(),
 		tracker:        tracker,
+		openAIRecorder: openAIRecorder,
 	}
 
 	// Register routes.
@@ -115,6 +120,7 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
 	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
 	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
+	m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
 	return m
 }
 
@@ -232,6 +238,14 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
 		s.tracker.TrackModel(model)
 	}
 
+	// Record the request in the OpenAI recorder.
+	recordID := s.openAIRecorder.RecordRequest(request.Model, r, body)
+	w = s.openAIRecorder.NewResponseRecorder(w)
+	defer func() {
+		// Record the response in the OpenAI recorder.
+		s.openAIRecorder.RecordResponse(recordID, request.Model, w)
+	}()
+
 	// Request a runner to execute the request and defer its release.
 	runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode)
 	if err != nil {
```
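In handleOpenAIInference, NewResponseRecorder wraps the http.ResponseWriter so the deferred RecordResponse call can read back what the handler wrote. The wrapper's implementation is not in this diff; here is a minimal sketch of the pattern it implies (type and field names are assumptions, only the NewResponseRecorder signature is taken from the diff):

```go
package metrics

import (
	"bytes"
	"net/http"
)

// Hypothetical wrapper: captures status and body while still writing them
// through to the underlying ResponseWriter. The real recorder may differ,
// e.g. it also has to cope with streamed (chunked/SSE) responses.
type responseRecorder struct {
	http.ResponseWriter
	status int
	body   bytes.Buffer
}

func (rr *responseRecorder) WriteHeader(status int) {
	rr.status = status
	rr.ResponseWriter.WriteHeader(status)
}

func (rr *responseRecorder) Write(p []byte) (int, error) {
	rr.body.Write(p) // keep a copy for later recording
	return rr.ResponseWriter.Write(p)
}

// NewResponseRecorder returns a ResponseWriter that records what is written.
func (r *OpenAIRecorder) NewResponseRecorder(w http.ResponseWriter) http.ResponseWriter {
	return &responseRecorder{ResponseWriter: w, status: http.StatusOK}
}
```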