Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.local
.claude/
.specify/
.idea/
specs/
design/
internal/airport_test_service/
Expand Down
26 changes: 26 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ func WithTimeout(timeout time.Duration) Option {
}
}

func WithTraceID(traceID string) Option {
return func(c *ClientConfig) {
if traceID != "" {
c.Transport = &traceIDTransport{
traceID: traceID,
transport: c.Transport,
}
}
}
}

func WithTransport(transport http.RoundTripper) Option {
return func(c *ClientConfig) {
c.Transport = transport
Expand Down Expand Up @@ -759,6 +770,19 @@ func (t *noTimezoneTransport) RoundTrip(req *http.Request) (*http.Response, erro
return t.transport.RoundTrip(req)
}

type traceIDTransport struct {
traceID string
transport http.RoundTripper
}

func (t *traceIDTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Trace-Id", t.traceID)
if t.transport == nil {
return http.DefaultTransport.RoundTrip(req)
}
return t.transport.RoundTrip(req)
}

// hasTimezoneTransport walks the transport chain and returns true if a
// timezoneTransport or noTimezoneTransport is already present.
func hasTimezoneTransport(rt http.RoundTripper) bool {
Expand All @@ -776,6 +800,8 @@ func hasTimezoneTransport(rt http.RoundTripper) bool {
rt = t.transport
case *tokenTransport:
rt = t.transport
case *traceIDTransport:
rt = t.transport
default:
return false
}
Expand Down
131 changes: 117 additions & 14 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"net/http/pprof"
"strings"
"time"

adminui "github.com/hugr-lab/query-engine/pkg/admin-ui"
Expand All @@ -27,14 +28,16 @@ import (
mcpserver "github.com/hugr-lab/query-engine/pkg/mcp"
permissions "github.com/hugr-lab/query-engine/pkg/perm"
"github.com/hugr-lab/query-engine/pkg/planner"
"github.com/hugr-lab/query-engine/pkg/trace"
"github.com/hugr-lab/query-engine/types"

"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)

type Service struct {
config Config
config Config
logLevel *slog.LevelVar

router *http.ServeMux
adminUI http.HandlerFunc
Expand Down Expand Up @@ -100,11 +103,18 @@ type Info struct {
}

func New(config Config) (*Service, error) {
level := &slog.LevelVar{}
if config.Debug {
level.Set(slog.LevelDebug)
} else {
level.Set(slog.LevelInfo)
}
return &Service{
config: config,
router: http.NewServeMux(),
cache: cache.New(config.Cache),
s3: storage.New(),
config: config,
logLevel: level,
router: http.NewServeMux(),
cache: cache.New(config.Cache),
s3: storage.New(),
}, nil
}

Expand Down Expand Up @@ -415,6 +425,8 @@ func (s *Service) endpoints() {
s.router.Handle("/gis/", mw(http.StripPrefix("/gis", s.gis)))
}

s.router.Handle("/admin/log-level", mw(http.HandlerFunc(s.logLevelHandler)))

if s.config.MCPEnabled {
mcpSrv := mcpserver.New(s, nil, s.config.Debug)
s.router.Handle("/mcp", mw(mcpSrv.Handler()))
Expand Down Expand Up @@ -473,12 +485,30 @@ func (s *Service) parseRequest(r *http.Request) (req types.Request, err error) {

func (s *Service) ProcessQuery(ctx context.Context, req types.Request) types.Response {
start := time.Now()
logger := trace.LoggerFromContext(ctx)

if ti := trace.FromContext(ctx); ti != nil {
ctx = trace.StartSpan(ctx, "query.parse")
}
op, err := s.schema.ParseQuery(ctx, req.Query, req.Variables, req.OperationName)
trace.EndSpan(ctx)
if err != nil {
return types.ErrResponse(err)
}
parseDuration := time.Since(start)

hasTraceDirective := op.Definition.Directives.ForName(base.TraceDirectiveName) != nil
if hasTraceDirective {
d := op.Definition.Directives.ForName(base.TraceDirectiveName)
ti := trace.FromContext(ctx)
if ti == nil {
ti = trace.NewTraceInfo(trace.TraceIDFromContext(ctx), parseLogLevel(d))
ctx = trace.ContextWithTrace(ctx, ti)
} else {
ti.SetLevel(parseLogLevel(d))
}
}

var hints []types.QueryHint
if req.ValidateOnly {
hints = append(hints, types.ValidateOnlyHint())
Expand All @@ -498,16 +528,28 @@ func (s *Service) ProcessQuery(ctx context.Context, req types.Request) types.Res
if data != nil {
res.Data = data
}
if ext != nil {
if op.Definition.Directives.ForName(base.StatsDirectiveName) != nil {
opStats, ok := ext["stats"].(map[string]any)
if !ok {
opStats = make(map[string]any)
}
opStats["parse_time"] = parseDuration.String()
opStats["total_time"] = time.Since(start).String()
ext["stats"] = opStats

if ext == nil {
ext = make(map[string]any)
}
if op.Definition.Directives.ForName(base.StatsDirectiveName) != nil {
opStats, ok := ext["stats"].(map[string]any)
if !ok {
opStats = make(map[string]any)
}
opStats["parse_time"] = parseDuration.String()
opStats["total_time"] = time.Since(start).String()
ext["stats"] = opStats
}

if ti := trace.FromContext(ctx); ti != nil {
logger.Debug("trace.spans", "spans", ti.Result())
if hasTraceDirective {
ext["trace"] = ti.Result()
}
}

if len(ext) > 0 {
res.Extensions = ext
}
return res
Expand Down Expand Up @@ -568,3 +610,64 @@ func (s *Service) Commit(ctx context.Context) error {
func (s *Service) Rollback(ctx context.Context) error {
return s.db.Rollback(ctx)
}

func (s *Service) logLevelHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodGet:
json.NewEncoder(w).Encode(map[string]string{
"level": s.logLevel.Level().String(),
})
case http.MethodPost:
var body struct {
Level string `json:"level"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request body"})
return
}
var lvl slog.Level
switch strings.ToUpper(body.Level) {
case "DEBUG":
lvl = slog.LevelDebug
case "INFO":
lvl = slog.LevelInfo
case "WARN":
lvl = slog.LevelWarn
case "ERROR":
lvl = slog.LevelError
default:
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "invalid level, use: DEBUG, INFO, WARN, ERROR"})
return
}
s.logLevel.Set(lvl)
slog.Info("log level changed", "level", lvl.String())
json.NewEncoder(w).Encode(map[string]string{
"level": lvl.String(),
})
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}

func parseLogLevel(d *ast.Directive) slog.Level {
if d == nil {
return slog.LevelDebug
}
arg := d.Arguments.ForName("level")
if arg == nil || arg.Value == nil {
return slog.LevelDebug
}
switch strings.ToUpper(arg.Value.Raw) {
case "ERROR":
return slog.LevelError
case "WARN":
return slog.LevelWarn
case "INFO":
return slog.LevelInfo
default:
return slog.LevelDebug
}
}
Loading