diff --git a/.gitignore b/.gitignore index 50268021..cbb8dfa8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .local .claude/ .specify/ +.idea/ specs/ design/ internal/airport_test_service/ diff --git a/client/client.go b/client/client.go index 9023797d..514ac59f 100644 --- a/client/client.go +++ b/client/client.go @@ -105,6 +105,17 @@ func WithTimeout(timeout time.Duration) Option { } } +func WithTraceID(traceID string) Option { + return func(c *ClientConfig) { + if traceID != "" { + c.Transport = &traceIDTransport{ + traceID: traceID, + transport: c.Transport, + } + } + } +} + func WithTransport(transport http.RoundTripper) Option { return func(c *ClientConfig) { c.Transport = transport @@ -759,6 +770,19 @@ func (t *noTimezoneTransport) RoundTrip(req *http.Request) (*http.Response, erro return t.transport.RoundTrip(req) } +type traceIDTransport struct { + traceID string + transport http.RoundTripper +} + +func (t *traceIDTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("X-Trace-Id", t.traceID) + if t.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return t.transport.RoundTrip(req) +} + // hasTimezoneTransport walks the transport chain and returns true if a // timezoneTransport or noTimezoneTransport is already present. 
func hasTimezoneTransport(rt http.RoundTripper) bool { @@ -776,6 +800,8 @@ func hasTimezoneTransport(rt http.RoundTripper) bool { rt = t.transport case *tokenTransport: rt = t.transport + case *traceIDTransport: + rt = t.transport default: return false } diff --git a/engine.go b/engine.go index b393da80..eab6d04c 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "net/http/pprof" + "strings" "time" adminui "github.com/hugr-lab/query-engine/pkg/admin-ui" @@ -27,6 +28,7 @@ import ( mcpserver "github.com/hugr-lab/query-engine/pkg/mcp" permissions "github.com/hugr-lab/query-engine/pkg/perm" "github.com/hugr-lab/query-engine/pkg/planner" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" "github.com/vektah/gqlparser/v2/ast" @@ -34,7 +36,8 @@ import ( ) type Service struct { - config Config + config Config + logLevel *slog.LevelVar router *http.ServeMux adminUI http.HandlerFunc @@ -100,11 +103,18 @@ type Info struct { } func New(config Config) (*Service, error) { + level := &slog.LevelVar{} + if config.Debug { + level.Set(slog.LevelDebug) + } else { + level.Set(slog.LevelInfo) + } return &Service{ - config: config, - router: http.NewServeMux(), - cache: cache.New(config.Cache), - s3: storage.New(), + config: config, + logLevel: level, + router: http.NewServeMux(), + cache: cache.New(config.Cache), + s3: storage.New(), }, nil } @@ -415,6 +425,8 @@ func (s *Service) endpoints() { s.router.Handle("/gis/", mw(http.StripPrefix("/gis", s.gis))) } + s.router.Handle("/admin/log-level", mw(http.HandlerFunc(s.logLevelHandler))) + if s.config.MCPEnabled { mcpSrv := mcpserver.New(s, nil, s.config.Debug) s.router.Handle("/mcp", mw(mcpSrv.Handler())) @@ -473,12 +485,30 @@ func (s *Service) parseRequest(r *http.Request) (req types.Request, err error) { func (s *Service) ProcessQuery(ctx context.Context, req types.Request) types.Response { start := time.Now() + logger := trace.LoggerFromContext(ctx) + + if ti := 
trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "query.parse") + } op, err := s.schema.ParseQuery(ctx, req.Query, req.Variables, req.OperationName) + trace.EndSpan(ctx) if err != nil { return types.ErrResponse(err) } parseDuration := time.Since(start) + hasTraceDirective := op.Definition.Directives.ForName(base.TraceDirectiveName) != nil + if hasTraceDirective { + d := op.Definition.Directives.ForName(base.TraceDirectiveName) + ti := trace.FromContext(ctx) + if ti == nil { + ti = trace.NewTraceInfo(trace.TraceIDFromContext(ctx), parseLogLevel(d)) + ctx = trace.ContextWithTrace(ctx, ti) + } else { + ti.SetLevel(parseLogLevel(d)) + } + } + var hints []types.QueryHint if req.ValidateOnly { hints = append(hints, types.ValidateOnlyHint()) @@ -498,16 +528,28 @@ func (s *Service) ProcessQuery(ctx context.Context, req types.Request) types.Res if data != nil { res.Data = data } - if ext != nil { - if op.Definition.Directives.ForName(base.StatsDirectiveName) != nil { - opStats, ok := ext["stats"].(map[string]any) - if !ok { - opStats = make(map[string]any) - } - opStats["parse_time"] = parseDuration.String() - opStats["total_time"] = time.Since(start).String() - ext["stats"] = opStats + + if ext == nil { + ext = make(map[string]any) + } + if op.Definition.Directives.ForName(base.StatsDirectiveName) != nil { + opStats, ok := ext["stats"].(map[string]any) + if !ok { + opStats = make(map[string]any) } + opStats["parse_time"] = parseDuration.String() + opStats["total_time"] = time.Since(start).String() + ext["stats"] = opStats + } + + if ti := trace.FromContext(ctx); ti != nil { + logger.Debug("trace.spans", "spans", ti.Result()) + if hasTraceDirective { + ext["trace"] = ti.Result() + } + } + + if len(ext) > 0 { res.Extensions = ext } return res @@ -568,3 +610,64 @@ func (s *Service) Commit(ctx context.Context) error { func (s *Service) Rollback(ctx context.Context) error { return s.db.Rollback(ctx) } + +func (s *Service) logLevelHandler(w http.ResponseWriter, r 
*http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodGet: + json.NewEncoder(w).Encode(map[string]string{ + "level": s.logLevel.Level().String(), + }) + case http.MethodPost: + var body struct { + Level string `json:"level"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "invalid request body"}) + return + } + var lvl slog.Level + switch strings.ToUpper(body.Level) { + case "DEBUG": + lvl = slog.LevelDebug + case "INFO": + lvl = slog.LevelInfo + case "WARN": + lvl = slog.LevelWarn + case "ERROR": + lvl = slog.LevelError + default: + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "invalid level, use: DEBUG, INFO, WARN, ERROR"}) + return + } + s.logLevel.Set(lvl) + slog.Info("log level changed", "level", lvl.String()) + json.NewEncoder(w).Encode(map[string]string{ + "level": lvl.String(), + }) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func parseLogLevel(d *ast.Directive) slog.Level { + if d == nil { + return slog.LevelDebug + } + arg := d.Arguments.ForName("level") + if arg == nil || arg.Value == nil { + return slog.LevelDebug + } + switch strings.ToUpper(arg.Value.Raw) { + case "ERROR": + return slog.LevelError + case "WARN": + return slog.LevelWarn + case "INFO": + return slog.LevelInfo + default: + return slog.LevelDebug + } +} diff --git a/hugr-tracing-concept.md b/hugr-tracing-concept.md new file mode 100644 index 00000000..7ebc875d --- /dev/null +++ b/hugr-tracing-concept.md @@ -0,0 +1,429 @@ +# Hugr Query Engine: Request-level Tracing & Structured Logging + +## Problem + +Today hugr does not make it clear what happened while handling a specific request. Logging is a mix of `log.Printf` (when `Debug=true`) and scattered `slog.Warn/Error` calls with no request correlation. 
On HTTP data source errors (non-200 responses), logs omit the response body and which GraphQL request caused the call. In a cluster, logs from different nodes cannot be correlated. + +## Solution + +Lightweight built-in tracing with no external dependencies (not OpenTelemetry), based on `context.Context` (already threaded through the call chain) and the standard library `log/slog`. + +### Core: two context levels + +`context.Context` holds two independent values: + +1. **Logger with trace_id** — always created in middleware. Wraps `slog.Default()` with an attached `trace_id`. Cheap; correlates logs per request even in production. + +2. **TraceInfo (with spans)** — created when tracing is active: either global level <= DEBUG (`POST /admin/log-level`), or the `@trace` directive on the request. Holds a span tree. + +Each span is tied to the goroutine’s context (not a single shared pointer), which works correctly with `AllowParallel=true` — parallel fields build their own subtrees independently, while a mutex only protects `append` to the parent span’s `Children` array. + +The context is available at every layer, including DuckDB UDF callbacks (duckdb-go keeps `ctx` from `QueryContext` and passes it to `RowContextExecutor`). + +When tracing is off (production, no global debug), `TraceInfo` is not created. Instrumentation overhead at each site is one `ctx.Value() == nil` check (~5 ns). + +### Data structures + +```go +package trace + +// TraceInfo is the root structure, one per request. Stored in context.Context. +type TraceInfo struct { + TraceID string // unique request UUID + Level slog.Level // log level for this request + Root *Span // root span (virtual; holds children) + mu sync.Mutex // protects Children append under parallel fields + StartTime time.Time +} + +// Span is one unit of work. Organized as a tree via Children. 
+type Span struct { + Name string `json:"name"` + Duration string `json:"duration,omitempty"` + Children []*Span `json:"children,omitempty"` + attrs map[string]any // server log only (not serialized to client) + start time.Time // not serialized + parent *Span // back-reference to parent +} +``` + +The logger with `trace_id` is stored separately in the context (via `ContextWithLogger`), not inside `TraceInfo`, so every request can have a `trace_id` logger even when `TraceInfo` is not created (production). + +The current span is bound to the goroutine’s `context.Context`, not a shared pointer in `TraceInfo`. That keeps parallel field execution (`AllowParallel=true`) correct: each goroutine builds its subtree via its own `ctx`, and the mutex only guards appending a child span to the parent. + +### `pkg/trace` API + +```go +// ContextWithLogger puts a trace_id logger in the context (middleware, always). +func ContextWithLogger(ctx context.Context, logger *slog.Logger, traceID string) context.Context + +// ContextWithTrace puts TraceInfo in the context (when tracing is active). +func ContextWithTrace(ctx context.Context, info *TraceInfo) context.Context + +// FromContext returns TraceInfo or nil (tracing off). +func FromContext(ctx context.Context) *TraceInfo + +// TraceIDFromContext returns the trace ID from context, or "" (for propagation in a cluster). +func TraceIDFromContext(ctx context.Context) string + +// LoggerFromContext returns the trace_id logger (always available) or slog.Default(). +func LoggerFromContext(ctx context.Context) *slog.Logger + +// StartSpan creates a child span under the current span (from ctx) and returns a new ctx. +// If tracing is off — returns ctx unchanged. +func StartSpan(ctx context.Context, name string, attrs ...any) context.Context + +// EndSpan ends the current span (from ctx), records duration. No-op if tracing is off. 
+func EndSpan(ctx context.Context) +``` + +Typical hugr usage — check plus two lines: + +```go +if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "planner.plan", "field", field.Name) + defer trace.EndSpan(ctx) +} +``` + +Checking `FromContext` before `StartSpan` avoids variadic allocations when tracing is off (zero cost in production). `LoggerFromContext` is always available — with `trace_id`, regardless of whether tracing is active. + +### Three control modes + +| Mode | Activation | Logger with trace_id | Spans | Spans in extensions | +|---|---|---|---|---| +| Production (quiet) | Default | Yes | No | No | +| Global debug | `POST /admin/log-level` (no restart) | Yes | **Yes** (server log) | No | +| Per-request trace | `@trace` directive on the GraphQL request | Yes | **Yes** (server log) | **Yes** (to client) | + +`POST /admin/log-level` changes the global level via `slog.LevelVar` (atomic) — instantly, no restart. At DEBUG, middleware creates `TraceInfo` with spans for every request — the span tree is written to the server log but **not** sent to the client. The `@trace(level: DEBUG)` directive additionally emits spans in the response `extensions`. + +### @trace directive + +```graphql +directive @trace(level: LogLevel = DEBUG) on QUERY | MUTATION +enum LogLevel @system { ERROR WARN INFO DEBUG } +``` + +Example: + +```graphql +query GetData @trace(level: DEBUG) { + catalog1 { users { id name } } +} +``` + +After parsing the request (`ParseQuery`), the engine checks for `@trace` and, if present, creates `TraceInfo` (production without global debug) or adjusts the level (if `TraceInfo` already exists from global debug). 
+ +### Integration points: flow from request to HTTP/Airport + +``` +HTTP Request +│ +▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (1) traceMiddleware │ +│ • Generates TraceID (UUID) or reads X-Trace-Id from header │ +│ • Creates logger with trace_id (always) │ +│ • If global level <= DEBUG — creates TraceInfo with spans │ +│ • Sets X-Trace-Id on the response header │ +└──────────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (2) ProcessQuery │ +│ • span "query.parse" (if TraceInfo exists) │ +│ • Parses @trace directive → creates TraceInfo (if missing) │ +│ or switches level (if already present from global debug) │ +│ • logger.Debug("trace.spans") — writes tree to server log │ +│ • ext["trace"] = ti.Result() — only with @trace │ +└──────────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (3) processDataQuery (inside dataFunc) │ +│ • span "planner.plan" — build query plan │ +│ • logger.Debug("planner.sql", "sql", plan.Log()) │ +│ • logger.Debug("query.user", "user", ..., "role", ...) 
│ +│ • span "db.execute" — run SQL in DuckDB │ +│ • logger.Error on planning and execution errors │ +└──────────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (4) db.Pool / SQL execution │ +│ • ctx stored in duckdb-go contextStore by connId │ +│ • On UDF call — ctx is retrieved and passed to callback │ +└──────────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (5) Data Sources │ +│ │ +│ HTTP (pkg/data-sources/service.go → HttpRequest): │ +│ • span "http.request" │ +│ • logger.Warn("http.request.error") — on request error │ +│ • On non-200: read body preview (up to 512 bytes) │ +│ • logger.Warn("http.response.error", │ +│ "status", code, "body_preview", preview) │ +│ • logger.Warn("http.decode.error") — on JSON decode error │ +│ │ +│ OAuth (sources/http/client.go): │ +│ • logger.Warn("oauth.token.request.error") — request error │ +│ • logger.Warn("oauth.token.unauthorized") — 401 response │ +│ • logger.Warn("oauth.token.error") — non-200 with body preview │ +│ │ +│ Airport (DuckDB ATTACH → gRPC): │ +│ • Errors surface as DuckDB errors, logged with trace_id │ +│ • gRPC metadata is not controlled from Go │ +└──────────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────┐ +│ (6) Response │ +│ • X-Trace-Id header on HTTP response (always) │ +│ • Spans to server log (global debug and @trace) │ +│ • Spans to client extensions — only with @trace │ +│ • extensions: names and timings only (security) │ +└──────────────────────────────────────────────────────────────────┘ +``` + +At each point, `trace.LoggerFromContext(ctx)` is used — a logger with `trace_id`, available whether or not tracing is active. `slog.Default()` is only the fallback when the context carries no logger (for example, code paths that do not pass through `traceMiddleware`). + +### Implemented code + +**1. 
Middleware (middlewares.go) — logger with trace_id + TraceInfo when debug:** + +```go +func traceMiddleware(level *slog.LevelVar) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + traceID := r.Header.Get("X-Trace-Id") + if traceID == "" { + traceID = uuid.NewString() + } + w.Header().Set("X-Trace-Id", traceID) + + logger := slog.Default().With("trace_id", traceID) + ctx := trace.ContextWithLogger(r.Context(), logger, traceID) + + if level.Level() <= slog.LevelDebug { + info := trace.NewTraceInfo(traceID, level.Level()) + ctx = trace.ContextWithTrace(ctx, info) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} +``` + +**2. ProcessQuery (engine.go) — @trace handling and assembling the result:** + +```go +func (s *Service) ProcessQuery(ctx context.Context, req types.Request) types.Response { + start := time.Now() + logger := trace.LoggerFromContext(ctx) + + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "query.parse") + } + op, err := s.schema.ParseQuery(ctx, req.Query, req.Variables, req.OperationName) + trace.EndSpan(ctx) + if err != nil { + return types.ErrResponse(err) + } + parseDuration := time.Since(start) + + hasTraceDirective := op.Definition.Directives.ForName(base.TraceDirectiveName) != nil + if hasTraceDirective { + d := op.Definition.Directives.ForName(base.TraceDirectiveName) + ti := trace.FromContext(ctx) + if ti == nil { + ti = trace.NewTraceInfo(trace.TraceIDFromContext(ctx), parseLogLevel(d)) + ctx = trace.ContextWithTrace(ctx, ti) + } else { + ti.SetLevel(parseLogLevel(d)) + } + } + + // ... ValidateOnly, ProcessOperation, @stats ... + + if ti := trace.FromContext(ctx); ti != nil { + logger.Debug("trace.spans", "spans", ti.Result()) + if hasTraceDirective { + ext["trace"] = ti.Result() + } + } + + if len(ext) > 0 { + res.Extensions = ext + } + return res +} +``` + +**3. 
processDataQuery (query.go) — spans and structured logging:** + +```go +func (s *Service) processDataQuery(ctx context.Context, provider catalog.Provider, + query base.QueryRequest, vars map[string]any) (data any, ext map[string]any, err error) { + // ... + logger := trace.LoggerFromContext(ctx) + + dataFunc := func() (any, error) { + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "planner.plan", "field", query.Field.Name) + } + plan, err := s.planner.Plan(ctx, provider, query.Field, vars) + trace.EndSpan(ctx) + if err != nil { + logger.Error("planner.plan.error", "field", query.Field.Name, "error", err) + return nil, err + } + + // compile (errors logged; no separate span — compile is synchronous and fast) + err = plan.Compile() + if err != nil { + logger.Error("planner.compile.error", "field", query.Field.Name, "error", err) + return nil, err + } + + logger.Debug("planner.sql", "field", query.Field.Name, + "alias", query.Field.Alias, "sql", plan.Log()) + + if ai := auth.AuthInfoFromContext(ctx); ai != nil { + logger.Debug("query.user", "user", ai.UserName, "role", ai.Role, + "field", query.Field.Name) + } + + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "db.execute", "field", query.Field.Name) + } + result, err := plan.Execute(ctx, s.db) + trace.EndSpan(ctx) + if err != nil { + logger.Error("db.execute.error", "field", query.Field.Name, "error", err) + } + return result, err + } + + // ... cache logic, @stats ... +} +``` + +**4. HttpRequest (pkg/data-sources/service.go) — HTTP error logging:** + +```go +func (s *Service) HttpRequest(ctx context.Context, source, path, method, + headers, params, body, jqq string) (any, error) { + + logger := trace.LoggerFromContext(ctx) + + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "http.request", "source", source, "path", path, "method", method) + defer trace.EndSpan(ctx) + } + + // ... resolve data source ... 
+ + res, err := httpDs.Request(ctx, path, method, headers, params, body) + if err != nil { + logger.Warn("http.request.error", "source", source, "path", path, "error", err) + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != 200 { + preview, _ := io.ReadAll(io.LimitReader(res.Body, 512)) + logger.Warn("http.response.error", + "source", source, "path", path, "method", method, + "status", res.StatusCode, "body_preview", string(preview)) + return nil, fmt.Errorf("request failed with status code %d: %s", res.StatusCode, res.Status) + } + + var data any + if err := json.NewDecoder(res.Body).Decode(&data); err != nil { + logger.Warn("http.decode.error", "source", source, "path", path, "error", err) + return nil, err + } + return data, nil +} +``` + +**5. Server log — sample output:** + +``` +level=WARN msg="http.response.error" trace_id=550e8400-... source=ext_api path=/users method=GET status=502 body_preview="Bad Gateway" +level=DEBUG msg="planner.sql" trace_id=550e8400-... field=users sql="SELECT u.id, u.name FROM catalog1.users u" +level=ERROR msg="db.execute.error" trace_id=550e8400-... field=orders error="connection refused" +``` + +All lines carry `trace_id` — filter one request: `grep 550e8400`. + +### Log vs extensions split + +To avoid leaking data to the client (SQL, response bodies, internal structure): + +- **Server log (stderr):** full detail — SQL, body preview, headers, parameters +- **GraphQL extensions (to client):** span names and timings only, no sensitive fields + +```json +{ + "extensions": { + "trace": { + "trace_id": "550e8400-...", + "total_time": "42ms", + "spans": [ + { "name": "query.parse", "duration": "1ms" }, + { "name": "planner.plan", "duration": "3ms" }, + { "name": "db.execute", "duration": "37ms" }, + { "name": "http.request", "duration": "28ms" } + ] + } + } +} +``` + +### Operating in a cluster + +User requests run entirely on one node — tracing uses in-process `context.Context`. 
For cluster operations (`Broadcast`, `ForwardToManagement`), `X-Trace-Id` is added to outbound HTTP headers via `client.WithTraceID(traceID)`; middleware on the receiving node picks it up instead of generating a new one. One operation — one `trace_id` across nodes. + +### Airport (hugr-app) + +DuckDB calls Airport catalogs over gRPC via its C++ extension — Go does not control those calls. Passing `trace_id` in gRPC metadata would require changes to the DuckDB extension. From the engine side, the SQL against the Airport catalog and any DuckDB error are logged. + +## Implemented changes + +| File / package | What was implemented | +|---|---| +| `pkg/trace/` (new) | `TraceInfo`, `Span`, context helpers (`ContextWithLogger`, `ContextWithTrace`, `StartSpan`, `EndSpan`, `FromContext`, `LoggerFromContext`, `TraceIDFromContext`), tests (21 tests) | +| `query_directives.graphql` | `@trace` directive + `LogLevel` enum | +| `constants.go`, `directives.go` | Register `trace` as a query-side directive | +| `middlewares.go` | `traceMiddleware`: TraceID, `X-Trace-Id`, base `TraceInfo` | +| `engine.go` | `Service.logLevel *slog.LevelVar`, `GET/POST /admin/log-level`; in `ProcessQuery` — `@trace` handling, merge trace into extensions, `parseLogLevel` | +| `query.go` | Replace `log.Printf` (debug) with `trace.StartSpan` + `trace.LoggerFromContext`: spans `planner.plan`, `db.execute`; structured logging for SQL, errors, user info | +| `stream.go` | Replace `log.Printf` with `trace.LoggerFromContext` for stream requests | +| `subscription.go` | Replace `log.Printf` with `trace.LoggerFromContext` for subscription requests | +| `pkg/data-sources/service.go` | `HttpRequest`: span `http.request`, structured logging for non-200 with body preview (up to 512 bytes), request errors, JSON decode errors | +| `pkg/data-sources/sources/http/client.go` | OAuth: structured logging `oauth.token.request.error`, `oauth.token.unauthorized`, `oauth.token.error` with body preview | +| `client/client.go` 
| `WithTraceID` option + `traceIDTransport` to propagate `X-Trace-Id` in HTTP headers | +| `pkg/cluster/coordinator.go` | `Broadcast` propagates `trace_id` via `client.WithTraceID` | +| `pkg/cluster/worker.go` | `ForwardToManagement` propagates `trace_id` via `client.WithTraceID` | + +### Not implemented in this iteration + +- Replacing `log.Printf` in `graphql-ws.go` and `ipc-stream.go` — these are IPC/WebSocket lifecycle logs (connect, ping/pong, close), not tied to a specific GraphQL request; they need a separate approach to pass trace context over the WS session +- A dedicated span for `planner.compile` — compile is synchronous and fast; SQL logging via `logger.Debug("planner.sql")` already covers it + +## Out of scope for v1 + +- Restricting `@trace` by role/permissions +- Sampling under global debug at high load +- OpenTelemetry integration (can be added later on top of this design) +- Passing trace context into Airport/gRPC (DuckDB limitation) +- Trace-aware logging for IPC/WebSocket lifecycle + +## Separate bug (found during audit) + +In `pkg/planner/node_select_vector.go` (line 153), `queries.CreateEmbedding` uses `context.Background()` instead of the request context — request cancellation does not reach the embedding sub-query. Fix: pass `ctx` from the closure. 
diff --git a/middlewares.go b/middlewares.go index 1c4cb36f..7425a683 100644 --- a/middlewares.go +++ b/middlewares.go @@ -8,14 +8,17 @@ import ( "fmt" "io" "net" + "log/slog" "net/http" "strings" "time" "github.com/andybalholm/brotli" + "github.com/google/uuid" "github.com/hugr-lab/query-engine/pkg/auth" "github.com/hugr-lab/query-engine/pkg/db" "github.com/hugr-lab/query-engine/pkg/perm" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" ) @@ -56,6 +59,8 @@ func (s *Service) middlewares() func(next http.Handler) http.Handler { if s.perm != nil { mm = append(mm, s.checkEndpointPermissionsMW) } + // tracing + mm = append(mm, traceMiddleware(s.logLevel)) // timezone mm = append(mm, timezoneMW) // compress @@ -134,6 +139,28 @@ func timezoneMW(next http.Handler) http.Handler { }) } +func traceMiddleware(level *slog.LevelVar) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + traceID := r.Header.Get("X-Trace-Id") + if traceID == "" { + traceID = uuid.NewString() + } + w.Header().Set("X-Trace-Id", traceID) + + logger := slog.Default().With("trace_id", traceID) + ctx := trace.ContextWithLogger(r.Context(), logger, traceID) + + if level.Level() <= slog.LevelDebug { + info := trace.NewTraceInfo(traceID, level.Level()) + ctx = trace.ContextWithTrace(ctx, info) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + func compressMW(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip compression for WebSocket upgrades — the connection will be diff --git a/pkg/catalog/compiler/base/constants.go b/pkg/catalog/compiler/base/constants.go index 364ce69c..a51d4910 100644 --- a/pkg/catalog/compiler/base/constants.go +++ b/pkg/catalog/compiler/base/constants.go @@ -28,6 +28,7 @@ const ( UnnestDirectiveName = "unnest" NoPushdownDirectiveName = "no_pushdown" AtDirectiveName = "at" 
+ TraceDirectiveName = "trace" ) // Query/mutation directive names. diff --git a/pkg/catalog/compiler/base/directives.go b/pkg/catalog/compiler/base/directives.go index 1095cdb1..2518c090 100644 --- a/pkg/catalog/compiler/base/directives.go +++ b/pkg/catalog/compiler/base/directives.go @@ -75,6 +75,7 @@ func QuerySideDirectives() []string { GisFeatureDirectiveName, NoPushdownDirectiveName, AtDirectiveName, + TraceDirectiveName, } } diff --git a/pkg/catalog/compiler/base/query_directives.graphql b/pkg/catalog/compiler/base/query_directives.graphql index 001c350c..cf51dda7 100644 --- a/pkg/catalog/compiler/base/query_directives.graphql +++ b/pkg/catalog/compiler/base/query_directives.graphql @@ -32,6 +32,18 @@ directive @add_h3( simplify: Boolean = true ) on FIELD +""" +Enable request-level tracing. Spans are collected and returned in extensions.trace. +""" +directive @trace(level: LogLevel = DEBUG) on QUERY | MUTATION + +enum LogLevel @system { + ERROR + WARN + INFO + DEBUG +} + enum GeometrySpatialQueryType @system { """ checks if the geometries intersect diff --git a/pkg/cluster/coordinator.go b/pkg/cluster/coordinator.go index 6fcf1e45..1125f1df 100644 --- a/pkg/cluster/coordinator.go +++ b/pkg/cluster/coordinator.go @@ -8,6 +8,7 @@ import ( "github.com/hugr-lab/query-engine/client" "github.com/hugr-lab/query-engine/pkg/auth" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" ) @@ -27,11 +28,13 @@ func NewCoordinator(config ClusterConfig, qe types.Querier) *Coordinator { } } -func (c *Coordinator) newClient(url string) *client.Client { - return client.NewClient(url, +func (c *Coordinator) newClient(url string, opts ...client.Option) *client.Client { + baseOpts := []client.Option{ client.WithApiKeyCustomHeader(c.config.Secret, "x-hugr-secret"), client.WithTimeout(c.config.Heartbeat), - ) + } + baseOpts = append(baseOpts, opts...) + return client.NewClient(url, baseOpts...) 
} // ActiveWorkers reads active worker nodes via GraphQL. @@ -76,12 +79,13 @@ func (c *Coordinator) Broadcast(ctx context.Context, query string, vars map[stri results := make([]NodeResult, len(workers)) var wg sync.WaitGroup + traceID := trace.TraceIDFromContext(ctx) for i, w := range workers { wg.Add(1) go func(idx int, node NodeInfo) { defer wg.Done() - client := c.newClient(node.URL) - res, err := client.Query(ctx, query, vars) + cl := c.newClient(node.URL, client.WithTraceID(traceID)) + res, err := cl.Query(ctx, query, vars) if err != nil { results[idx] = NodeResult{Node: node.Name, Error: err.Error()} return diff --git a/pkg/cluster/worker.go b/pkg/cluster/worker.go index b90a3e4a..30c1994a 100644 --- a/pkg/cluster/worker.go +++ b/pkg/cluster/worker.go @@ -11,6 +11,7 @@ import ( "github.com/hugr-lab/query-engine/client" "github.com/hugr-lab/query-engine/pkg/auth" "github.com/hugr-lab/query-engine/pkg/db" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" ) @@ -35,11 +36,13 @@ func NewWorkerClient(config ClusterConfig, qe types.Querier, pool *db.Pool) *Wor } } -func (w *WorkerClient) newClient(url string) *client.Client { - return client.NewClient(url, +func (w *WorkerClient) newClient(url string, opts ...client.Option) *client.Client { + baseOpts := []client.Option{ client.WithApiKeyCustomHeader(w.config.Secret, "x-hugr-secret"), client.WithTimeout(w.config.Heartbeat), - ) + } + baseOpts = append(baseOpts, opts...) + return client.NewClient(url, baseOpts...) } // ManagementURL discovers and caches the management node URL via GraphQL. 
@@ -83,7 +86,8 @@ func (w *WorkerClient) ForwardToManagement(ctx context.Context, query string, va return types.ErrResult(fmt.Errorf("cannot find management node: %w", err)), nil } - c := w.newClient(mgmtURL) + traceID := trace.TraceIDFromContext(ctx) + c := w.newClient(mgmtURL, client.WithTraceID(traceID)) res, err := c.Query(ctx, query, vars) if err != nil { return types.ErrResult(fmt.Errorf("forward to management: %w", err)), nil diff --git a/pkg/data-sources/service.go b/pkg/data-sources/service.go index 8c8fb36e..6000e617 100644 --- a/pkg/data-sources/service.go +++ b/pkg/data-sources/service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" "sync" @@ -25,6 +26,7 @@ import ( "github.com/hugr-lab/query-engine/pkg/db" "github.com/hugr-lab/query-engine/pkg/engines" "github.com/hugr-lab/query-engine/pkg/jq" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" //lint:ignore ST1001 "github.com/hugr-lab/query-engine/pkg/data-sources/sources" is a valid package name @@ -364,6 +366,13 @@ func NewDataSource(ctx context.Context, ds types.DataSource, attached bool) (Sou } func (s *Service) HttpRequest(ctx context.Context, source, path, method, headers, params, body, jqq string) (any, error) { + logger := trace.LoggerFromContext(ctx) + + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "http.request", "source", source, "path", path, "method", method) + defer trace.EndSpan(ctx) + } + s.mu.RLock() defer s.mu.RUnlock() ds, ok := s.dataSources[source] @@ -379,6 +388,7 @@ func (s *Service) HttpRequest(ctx context.Context, source, path, method, headers } res, err := httpDs.Request(ctx, path, method, headers, params, body) if err != nil { + logger.Warn("http.request.error", "source", source, "path", path, "error", err) return nil, err } defer res.Body.Close() @@ -386,7 +396,11 @@ func (s *Service) HttpRequest(ctx context.Context, source, path, method, headers return nil, nil } if res.StatusCode 
!= 200 { - return nil, fmt.Errorf("request failed with status code %d:%s", res.StatusCode, res.Status) + preview, _ := io.ReadAll(io.LimitReader(res.Body, 512)) + logger.Warn("http.response.error", + "source", source, "path", path, "method", method, + "status", res.StatusCode, "body_preview", string(preview)) + return nil, fmt.Errorf("request failed with status code %d: %s", res.StatusCode, res.Status) } if res.Body == nil { return nil, errors.New("response body is nil") @@ -394,6 +408,7 @@ func (s *Service) HttpRequest(ctx context.Context, source, path, method, headers var data any err = json.NewDecoder(res.Body).Decode(&data) if err != nil { + logger.Warn("http.decode.error", "source", source, "path", path, "error", err) return nil, err } if jqq != "" { diff --git a/pkg/data-sources/sources/http/client.go b/pkg/data-sources/sources/http/client.go index f9cf7faa..e2717a4f 100644 --- a/pkg/data-sources/sources/http/client.go +++ b/pkg/data-sources/sources/http/client.go @@ -15,6 +15,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/hugr-lab/query-engine/pkg/jq" + "github.com/hugr-lab/query-engine/pkg/trace" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) @@ -636,18 +637,23 @@ func customTokenRequest(ctx context.Context, tokenUrl string, data any, param *t req.Header.Set("Content-Type", "application/json") } + logger := trace.LoggerFromContext(ctx) res, err := http.DefaultClient.Do(req) if err != nil { + logger.Warn("oauth.token.request.error", "url", tokenUrl, "error", err) return nil, err } defer res.Body.Close() if res.StatusCode == http.StatusUnauthorized { + logger.Warn("oauth.token.unauthorized", "url", tokenUrl, "status", res.Status) e := ErrUnauthorizedTokenRequest(res.Status) return nil, &e } if res.StatusCode != http.StatusOK { msg, err := io.ReadAll(res.Body) if err == nil { + logger.Warn("oauth.token.error", "url", tokenUrl, "status", res.StatusCode, + "body_preview", string(msg)) return nil, fmt.Errorf("unexpected status 
code %d: %s", res.StatusCode, msg) } return nil, fmt.Errorf("unexpected status code %d", res.StatusCode) diff --git a/pkg/trace/middleware_test.go b/pkg/trace/middleware_test.go new file mode 100644 index 00000000..2e3b4b94 --- /dev/null +++ b/pkg/trace/middleware_test.go @@ -0,0 +1,96 @@ +package trace_test + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/hugr-lab/query-engine/pkg/trace" +) + +func TestTraceMiddlewareIntegration(t *testing.T) { + var capturedLogger *slog.Logger + var capturedTraceID string + var capturedTraceInfo *trace.TraceInfo + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + capturedLogger = trace.LoggerFromContext(ctx) + capturedTraceID = trace.TraceIDFromContext(ctx) + capturedTraceInfo = trace.FromContext(ctx) + w.WriteHeader(http.StatusOK) + }) + + t.Run("generates trace_id when not provided", func(t *testing.T) { + req := httptest.NewRequest("GET", "/query", nil) + w := httptest.NewRecorder() + + ctx := trace.ContextWithLogger(req.Context(), slog.Default(), "") + req = req.WithContext(ctx) + + handler.ServeHTTP(w, req) + + if capturedLogger == nil { + t.Error("logger should not be nil even without middleware") + } + }) + + t.Run("propagates X-Trace-Id from request", func(t *testing.T) { + req := httptest.NewRequest("GET", "/query", nil) + req.Header.Set("X-Trace-Id", "my-trace-123") + + logger := slog.Default().With("trace_id", "my-trace-123") + ctx := trace.ContextWithLogger(req.Context(), logger, "my-trace-123") + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if capturedTraceID != "my-trace-123" { + t.Errorf("trace ID = %q, want %q", capturedTraceID, "my-trace-123") + } + }) + + t.Run("TraceInfo created when debug level", func(t *testing.T) { + req := httptest.NewRequest("GET", "/query", nil) + logger := slog.Default().With("trace_id", "debug-test") + ctx := trace.ContextWithLogger(req.Context(), 
logger, "debug-test") + ti := trace.NewTraceInfo("debug-test", slog.LevelDebug) + ctx = trace.ContextWithTrace(ctx, ti) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if capturedTraceInfo == nil { + t.Error("TraceInfo should be present when debug level is set") + } + }) + + t.Run("TraceInfo nil when info level", func(t *testing.T) { + capturedTraceInfo = nil + req := httptest.NewRequest("GET", "/query", nil) + logger := slog.Default().With("trace_id", "info-test") + ctx := trace.ContextWithLogger(req.Context(), logger, "info-test") + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if capturedTraceInfo != nil { + t.Error("TraceInfo should be nil in production (info level)") + } + }) +} + +func TestResultContainsTraceID(t *testing.T) { + ti := trace.NewTraceInfo("tree-test", slog.LevelDebug) + result := ti.Result() + if result["trace_id"] != "tree-test" { + t.Errorf("trace_id = %v, want %q", result["trace_id"], "tree-test") + } + if result["total_time"] == nil { + t.Error("total_time should be set") + } +} diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go new file mode 100644 index 00000000..da6e6f01 --- /dev/null +++ b/pkg/trace/trace.go @@ -0,0 +1,180 @@ +package trace + +import ( + "context" + "log/slog" + "sync" + "time" +) + +// TraceInfo holds per-request tracing state: trace ID, log level, and span tree. +// Created when tracing is active (global debug or @trace directive). +type TraceInfo struct { + TraceID string + Level slog.Level + Root *Span + mu sync.Mutex + StartTime time.Time +} + +// Span represents a timed segment of work within a request. +// Spans form a tree via Children. +type Span struct { + Name string `json:"name"` + Duration string `json:"duration,omitempty"` + Children []*Span `json:"children,omitempty"` + attrs map[string]any + start time.Time + parent *Span +} + +// NewTraceInfo creates a new TraceInfo with a root span. 
+func NewTraceInfo(traceID string, level slog.Level) *TraceInfo { + return &TraceInfo{ + TraceID: traceID, + Level: level, + Root: &Span{Name: "root", start: time.Now()}, + StartTime: time.Now(), + } +} + +// SetLevel updates the log level for this request. +func (t *TraceInfo) SetLevel(level slog.Level) { + t.mu.Lock() + defer t.mu.Unlock() + t.Level = level +} + +// Result returns a safe-for-client representation of the trace (names and durations only). +func (t *TraceInfo) Result() map[string]any { + t.mu.Lock() + defer t.mu.Unlock() + return map[string]any{ + "trace_id": t.TraceID, + "total_time": time.Since(t.StartTime).String(), + "spans": spansResult(t.Root.Children), + } +} + +func spansResult(spans []*Span) []map[string]any { + if len(spans) == 0 { + return nil + } + result := make([]map[string]any, len(spans)) + for i, s := range spans { + m := map[string]any{ + "name": s.Name, + "duration": s.Duration, + } + if children := spansResult(s.Children); children != nil { + m["children"] = children + } + result[i] = m + } + return result +} + +type traceInfoKey struct{} +type loggerKey struct{} +type traceIDKey struct{} +type spanKey struct{} + +// ContextWithLogger stores a trace-aware logger and trace ID in the context. +func ContextWithLogger(ctx context.Context, logger *slog.Logger, traceID string) context.Context { + ctx = context.WithValue(ctx, loggerKey{}, logger) + ctx = context.WithValue(ctx, traceIDKey{}, traceID) + return ctx +} + +// ContextWithTrace stores TraceInfo in the context. +func ContextWithTrace(ctx context.Context, info *TraceInfo) context.Context { + ctx = context.WithValue(ctx, traceInfoKey{}, info) + return context.WithValue(ctx, spanKey{}, info.Root) +} + +// FromContext returns TraceInfo from context, or nil if tracing is not active. 
+func FromContext(ctx context.Context) *TraceInfo { + if v := ctx.Value(traceInfoKey{}); v != nil { + if ti, ok := v.(*TraceInfo); ok { + return ti + } + } + return nil +} + +// TraceIDFromContext returns the trace ID string from context, or "". +func TraceIDFromContext(ctx context.Context) string { + if v := ctx.Value(traceIDKey{}); v != nil { + if id, ok := v.(string); ok { + return id + } + } + return "" +} + +// LoggerFromContext returns the trace-aware logger, or slog.Default(). +func LoggerFromContext(ctx context.Context) *slog.Logger { + if v := ctx.Value(loggerKey{}); v != nil { + if l, ok := v.(*slog.Logger); ok { + return l + } + } + return slog.Default() +} + +func currentSpan(ctx context.Context) *Span { + if v := ctx.Value(spanKey{}); v != nil { + if s, ok := v.(*Span); ok { + return s + } + } + return nil +} + +// StartSpan creates a child span under the current span (from ctx). +// Returns a new context with the child span as current. +// If tracing is not active, returns ctx unchanged. +func StartSpan(ctx context.Context, name string, attrs ...any) context.Context { + ti := FromContext(ctx) + if ti == nil { + return ctx + } + parent := currentSpan(ctx) + if parent == nil { + parent = ti.Root + } + span := &Span{ + Name: name, + start: time.Now(), + parent: parent, + attrs: attrsToMap(attrs), + } + ti.mu.Lock() + parent.Children = append(parent.Children, span) + ti.mu.Unlock() + + return context.WithValue(ctx, spanKey{}, span) +} + +// EndSpan ends the current span (from ctx), recording its duration. +// Safe to call when tracing is not active (no-op). 
+func EndSpan(ctx context.Context) { + span := currentSpan(ctx) + if span == nil || span.start.IsZero() { + return + } + span.Duration = time.Since(span.start).String() +} + +func attrsToMap(attrs []any) map[string]any { + if len(attrs) == 0 { + return nil + } + m := make(map[string]any, len(attrs)/2) + for i := 0; i+1 < len(attrs); i += 2 { + if key, ok := attrs[i].(string); ok { + m[key] = attrs[i+1] + } + } + return m +} diff --git a/pkg/trace/trace_test.go b/pkg/trace/trace_test.go new file mode 100644 index 00000000..c5b8075d --- /dev/null +++ b/pkg/trace/trace_test.go @@ -0,0 +1,237 @@ +package trace + +import ( + "context" + "log/slog" + "sync" + "testing" + "time" +) + +func TestContextWithLogger(t *testing.T) { + ctx := context.Background() + logger := slog.Default().With("trace_id", "test-123") + ctx = ContextWithLogger(ctx, logger, "test-123") + + got := LoggerFromContext(ctx) + if got != logger { + t.Error("LoggerFromContext should return the stored logger") + } + + id := TraceIDFromContext(ctx) + if id != "test-123" { + t.Errorf("TraceIDFromContext = %q, want %q", id, "test-123") + } +} + +func TestLoggerFromContext_Default(t *testing.T) { + ctx := context.Background() + got := LoggerFromContext(ctx) + if got != slog.Default() { + t.Error("LoggerFromContext on empty context should return slog.Default()") + } +} + +func TestTraceIDFromContext_Empty(t *testing.T) { + ctx := context.Background() + if id := TraceIDFromContext(ctx); id != "" { + t.Errorf("TraceIDFromContext on empty context = %q, want empty", id) + } +} + +func TestFromContext_Nil(t *testing.T) { + ctx := context.Background() + if ti := FromContext(ctx); ti != nil { + t.Error("FromContext on empty context should return nil") + } +} + +func TestNewTraceInfo(t *testing.T) { + ti := NewTraceInfo("abc-123", slog.LevelDebug) + if ti.TraceID != "abc-123" { + t.Errorf("TraceID = %q, want %q", ti.TraceID, "abc-123") + } + if ti.Level != slog.LevelDebug { + t.Errorf("Level = %v, want %v", ti.Level, 
slog.LevelDebug) + } + if ti.Root == nil { + t.Fatal("Root span should not be nil") + } + if len(ti.Root.Children) != 0 { + t.Error("Root should have no children initially") + } +} + +func TestStartSpan_EndSpan(t *testing.T) { + ti := NewTraceInfo("test", slog.LevelDebug) + ctx := context.Background() + ctx = ContextWithTrace(ctx, ti) + + ctx2 := StartSpan(ctx, "planner.plan", "field", "users") + time.Sleep(time.Millisecond) + EndSpan(ctx2) + + if len(ti.Root.Children) != 1 { + t.Fatalf("expected 1 child span, got %d", len(ti.Root.Children)) + } + span := ti.Root.Children[0] + if span.Name != "planner.plan" { + t.Errorf("span name = %q, want %q", span.Name, "planner.plan") + } + if span.Duration == "" { + t.Error("span duration should be set after EndSpan") + } + if span.attrs["field"] != "users" { + t.Errorf("span attrs[field] = %v, want %q", span.attrs["field"], "users") + } +} + +func TestStartSpan_NoTrace(t *testing.T) { + ctx := context.Background() + ctx2 := StartSpan(ctx, "should.noop") + if ctx2 != ctx { + t.Error("StartSpan without TraceInfo should return the same context") + } +} + +func TestEndSpan_NoTrace(t *testing.T) { + EndSpan(context.Background()) +} + +func TestNestedSpans(t *testing.T) { + ti := NewTraceInfo("test", slog.LevelDebug) + ctx := ContextWithTrace(context.Background(), ti) + + ctx1 := StartSpan(ctx, "parent") + ctx2 := StartSpan(ctx1, "child") + EndSpan(ctx2) + EndSpan(ctx1) + + if len(ti.Root.Children) != 1 { + t.Fatalf("expected 1 root child, got %d", len(ti.Root.Children)) + } + parent := ti.Root.Children[0] + if parent.Name != "parent" { + t.Errorf("parent span name = %q, want %q", parent.Name, "parent") + } + if len(parent.Children) != 1 { + t.Fatalf("expected 1 child under parent, got %d", len(parent.Children)) + } + child := parent.Children[0] + if child.Name != "child" { + t.Errorf("child span name = %q, want %q", child.Name, "child") + } +} + +func TestParallelSpans(t *testing.T) { + ti := NewTraceInfo("test", slog.LevelDebug) 
+ ctx := ContextWithTrace(context.Background(), ti) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + spanCtx := StartSpan(ctx, "parallel", "idx", idx) + time.Sleep(time.Millisecond) + EndSpan(spanCtx) + }(i) + } + wg.Wait() + + if len(ti.Root.Children) != 10 { + t.Errorf("expected 10 children, got %d", len(ti.Root.Children)) + } +} + +func TestParallelSpans_CorrectParenting(t *testing.T) { + ti := NewTraceInfo("test", slog.LevelDebug) + ctx := ContextWithTrace(context.Background(), ti) + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + fieldCtx := StartSpan(ctx, "field") + childCtx := StartSpan(fieldCtx, "planner") + EndSpan(childCtx) + EndSpan(fieldCtx) + }(i) + } + wg.Wait() + + if len(ti.Root.Children) != 5 { + t.Fatalf("expected 5 root children, got %d", len(ti.Root.Children)) + } + for _, span := range ti.Root.Children { + if span.Name != "field" { + t.Errorf("root child name = %q, want %q", span.Name, "field") + } + if len(span.Children) != 1 { + t.Errorf("expected 1 child under field, got %d", len(span.Children)) + continue + } + if span.Children[0].Name != "planner" { + t.Errorf("child name = %q, want %q", span.Children[0].Name, "planner") + } + } +} + +func TestResult(t *testing.T) { + ti := NewTraceInfo("abc", slog.LevelDebug) + ctx := ContextWithTrace(context.Background(), ti) + + ctx1 := StartSpan(ctx, "parse") + EndSpan(ctx1) + + ctx2 := StartSpan(ctx, "execute") + EndSpan(ctx2) + + result := ti.Result() + if result["trace_id"] != "abc" { + t.Errorf("trace_id = %v, want %q", result["trace_id"], "abc") + } + spans, ok := result["spans"].([]map[string]any) + if !ok { + t.Fatalf("spans type = %T, want []map[string]any", result["spans"]) + } + if len(spans) != 2 { + t.Errorf("expected 2 spans, got %d", len(spans)) + } +} + +func TestSetLevel(t *testing.T) { + ti := NewTraceInfo("test", slog.LevelInfo) + if ti.Level != slog.LevelInfo { + 
t.Errorf("initial level = %v, want %v", ti.Level, slog.LevelInfo) + } + ti.SetLevel(slog.LevelDebug) + if ti.Level != slog.LevelDebug { + t.Errorf("after SetLevel = %v, want %v", ti.Level, slog.LevelDebug) + } +} + +func TestAttrsToMap(t *testing.T) { + m := attrsToMap([]any{"key1", "val1", "key2", 42}) + if m["key1"] != "val1" { + t.Errorf("key1 = %v, want %q", m["key1"], "val1") + } + if m["key2"] != 42 { + t.Errorf("key2 = %v, want %d", m["key2"], 42) + } +} + +func TestAttrsToMap_Empty(t *testing.T) { + m := attrsToMap(nil) + if m != nil { + t.Errorf("attrsToMap(nil) = %v, want nil", m) + } +} + +func TestAttrsToMap_OddCount(t *testing.T) { + m := attrsToMap([]any{"key1", "val1", "orphan"}) + if len(m) != 1 { + t.Errorf("expected 1 entry, got %d", len(m)) + } +} diff --git a/query.go b/query.go index defc73e8..e46a7e17 100644 --- a/query.go +++ b/query.go @@ -18,6 +18,7 @@ import ( "github.com/hugr-lab/query-engine/pkg/db" "github.com/hugr-lab/query-engine/pkg/jq" "github.com/hugr-lab/query-engine/pkg/metadata" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" "github.com/vektah/gqlparser/v2/ast" "golang.org/x/sync/errgroup" @@ -311,43 +312,45 @@ func (s *Service) processQueryParallel( func (s *Service) processDataQuery(ctx context.Context, provider catalog.Provider, query base.QueryRequest, vars map[string]any) (data any, ext map[string]any, err error) { defer recoverPanic(&err) start := time.Now() + logger := trace.LoggerFromContext(ctx) var plannerTime, compileTime time.Duration dataFunc := func() (any, error) { + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "planner.plan", "field", query.Field.Name) + } plan, err := s.planner.Plan(ctx, provider, query.Field, vars) + trace.EndSpan(ctx) if err != nil { + logger.Error("planner.plan.error", "field", query.Field.Name, "error", err) return nil, err } plannerTime = time.Since(start) err = plan.Compile() if err != nil { + 
logger.Error("planner.compile.error", "field", query.Field.Name, "error", err) return nil, err } compileTime = time.Since(start) - if s.config.Debug { - ai := auth.AuthInfoFromContext(ctx) - if ai != nil { - log.Printf("User: %s, Role: %s, Query: %s (%s), SQL: %s", - ai.UserName, - ai.Role, - query.Field.Alias, - query.Field.Name, - plan.Log(), - ) - } - if auth.IsFullAccess(ctx) { - log.Printf("Internal query: %s (%s), SQL: %s", - query.Field.Alias, - query.Field.Name, - plan.Log(), - ) - } + logger.Debug("planner.sql", "field", query.Field.Name, + "alias", query.Field.Alias, "sql", plan.Log()) + + if ai := auth.AuthInfoFromContext(ctx); ai != nil { + logger.Debug("query.user", "user", ai.UserName, "role", ai.Role, + "field", query.Field.Name) } if types.IsValidateOnlyContext(ctx) { return nil, nil } - // execute query - return plan.Execute(ctx, s.db) + if ti := trace.FromContext(ctx); ti != nil { + ctx = trace.StartSpan(ctx, "db.execute", "field", query.Field.Name) + } + result, err := plan.Execute(ctx, s.db) + trace.EndSpan(ctx) + if err != nil { + logger.Error("db.execute.error", "field", query.Field.Name, "error", err) + } + return result, err } ci := cache.QueryInfo(query.Field, vars) diff --git a/stream.go b/stream.go index 72712404..7a299b00 100644 --- a/stream.go +++ b/stream.go @@ -11,6 +11,7 @@ import ( "github.com/hugr-lab/query-engine/pkg/auth" "github.com/hugr-lab/query-engine/pkg/catalog/compiler/base" "github.com/hugr-lab/query-engine/pkg/catalog/sdl" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" "github.com/vektah/gqlparser/v2/ast" ) @@ -69,24 +70,12 @@ func (s *Service) ProcessStreamQuery(ctx context.Context, query string, vars map return nil, nil, fmt.Errorf("failed to compile query: %w", err) } - if s.config.Debug { - ai := auth.AuthInfoFromContext(ctx) - if ai != nil { - log.Printf("Stream: User: %s, Role: %s, Query: %s (%s), SQL: %s", - ai.UserName, - ai.Role, - q.Field.Alias, - q.Field.Name, - 
plan.Log(), - ) - } - if auth.IsFullAccess(ctx) { - log.Printf("Stream: Internal query: %s (%s), SQL: %s", - q.Field.Alias, - q.Field.Name, - plan.Log(), - ) - } + logger := trace.LoggerFromContext(ctx) + logger.Debug("stream.sql", "field", q.Field.Name, "alias", q.Field.Alias, "sql", plan.Log()) + if ai := auth.AuthInfoFromContext(ctx); ai != nil { + logger.Debug("stream.user", "user", ai.UserName, "role", ai.Role, "field", q.Field.Name) + } else if auth.IsFullAccess(ctx) { + logger.Debug("stream.user", "internal", true, "field", q.Field.Name, "alias", q.Field.Alias) } return plan.ExecuteStream(ctx, s.db) diff --git a/subscription.go b/subscription.go index 7fcc410c..b0b641e7 100644 --- a/subscription.go +++ b/subscription.go @@ -14,6 +14,7 @@ import ( "github.com/hugr-lab/query-engine/pkg/catalog/compiler/base" "github.com/hugr-lab/query-engine/pkg/catalog/sdl" "github.com/hugr-lab/query-engine/pkg/data-sources/sources" + "github.com/hugr-lab/query-engine/pkg/trace" "github.com/hugr-lab/query-engine/types" "github.com/vektah/gqlparser/v2/ast" ) @@ -105,9 +106,7 @@ func (s *Service) executeQueryTick(ctx context.Context, queries []base.QueryRequ reader, err := s.executeStreamPath(ctx, provider, q, vars) if err != nil { - if s.config.Debug { - log.Printf("subscription path %s error: %v", path, err) - } + trace.LoggerFromContext(ctx).Warn("subscription.path.error", "path", path, "error", err) return } select { @@ -130,14 +129,12 @@ func (s *Service) executeStreamPath(ctx context.Context, provider catalog.Provid return nil, fmt.Errorf("compile: %w", err) } - if s.config.Debug { - if ai := auth.AuthInfoFromContext(ctx); ai != nil { - log.Printf("Subscription stream: User: %s, Role: %s, Query: %s (%s), SQL: %s", - ai.UserName, ai.Role, q.Field.Alias, q.Field.Name, plan.Log()) - } else if auth.IsFullAccess(ctx) { - log.Printf("Subscription stream: Internal: %s (%s), SQL: %s", - q.Field.Alias, q.Field.Name, plan.Log()) - } + logger := trace.LoggerFromContext(ctx) + 
logger.Debug("subscription.sql", "field", q.Field.Name, "alias", q.Field.Alias, "sql", plan.Log()) + if ai := auth.AuthInfoFromContext(ctx); ai != nil { + logger.Debug("subscription.user", "user", ai.UserName, "role", ai.Role, "field", q.Field.Name) + } else if auth.IsFullAccess(ctx) { + logger.Debug("subscription.user", "internal", true, "field", q.Field.Name, "alias", q.Field.Alias) } table, finalize, err := plan.ExecuteStream(ctx, s.db) @@ -315,9 +312,9 @@ func (r *metadataReader) Next() bool { } func (r *metadataReader) Record() arrow.RecordBatch { return r.current } -func (r *metadataReader) RecordBatch() arrow.RecordBatch { return r.current } -func (r *metadataReader) Err() error { return r.reader.Err() } -func (r *metadataReader) Retain() { r.reader.Retain() } +func (r *metadataReader) RecordBatch() arrow.RecordBatch { return r.current } +func (r *metadataReader) Err() error { return r.reader.Err() } +func (r *metadataReader) Retain() { r.reader.Retain() } func (r *metadataReader) Release() { r.once.Do(func() { if r.current != nil {