From cf8fdd4f22e27fadf4e2edea92ab321722ce778c Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 27 Mar 2026 12:31:45 -0700 Subject: [PATCH 01/16] Initial pass of moving from gorm to golang-migrate + sqlc Signed-off-by: Jeremy Alvis --- .github/workflows/migration-immutability.yaml | 35 + .github/workflows/tag.yaml | 1 + Makefile | 9 +- go/api/database/models.go | 273 +++--- go/core/cmd/migrate/main.go | 311 ++++++ go/core/cmd/migrate/main_test.go | 247 +++++ .../reconciler/mcp_server_reconciler_test.go | 17 +- go/core/internal/database/client.go | 690 ------------- go/core/internal/database/client_postgres.go | 908 ++++++++++++++++++ go/core/internal/database/client_test.go | 68 +- go/core/internal/database/connect.go | 87 ++ .../{manager_test.go => connect_test.go} | 4 +- go/core/internal/database/fake/client.go | 19 +- go/core/internal/database/gen/agents.sql.go | 98 ++ go/core/internal/database/gen/crewai.sql.go | 191 ++++ go/core/internal/database/gen/db.go | 31 + go/core/internal/database/gen/events.sql.go | 340 +++++++ go/core/internal/database/gen/feedback.sql.go | 89 ++ .../internal/database/gen/langgraph.sql.go | 322 +++++++ go/core/internal/database/gen/memory.sql.go | 203 ++++ go/core/internal/database/gen/models.go | 155 +++ .../database/gen/push_notifications.sql.go | 100 ++ go/core/internal/database/gen/querier.go | 74 ++ go/core/internal/database/gen/sessions.sql.go | 164 ++++ go/core/internal/database/gen/tasks.sql.go | 109 +++ go/core/internal/database/gen/tools.sql.go | 263 +++++ go/core/internal/database/manager.go | 207 ---- go/core/internal/database/queries/agents.sql | 21 + go/core/internal/database/queries/crewai.sql | 38 + go/core/internal/database/queries/events.sql | 49 + .../internal/database/queries/feedback.sql | 9 + .../internal/database/queries/langgraph.sql | 58 ++ go/core/internal/database/queries/memory.sql | 32 + .../database/queries/push_notifications.sql | 20 + .../internal/database/queries/sessions.sql | 28 + 
go/core/internal/database/queries/tasks.sql | 25 + go/core/internal/database/queries/tools.sql | 50 + go/core/internal/database/service.go | 88 -- go/core/internal/database/sqlc.yaml | 27 + go/core/internal/database/testhelpers_test.go | 22 +- go/core/internal/database/upgrade_test.go | 358 +++++++ go/core/internal/dbtest/dbtest.go | 113 +++ .../httpserver/handlers/checkpoints.go | 6 +- .../internal/httpserver/handlers/memory.go | 2 +- .../internal/httpserver/middleware_error.go | 4 +- go/core/internal/httpserver/server.go | 8 - go/core/pkg/app/app.go | 22 +- go/core/pkg/env/database.go | 11 - .../migrations/core/000001_initial.down.sql | 12 + .../pkg/migrations/core/000001_initial.up.sql | 161 ++++ .../core/000002_add_session_source.down.sql | 2 + .../core/000002_add_session_source.up.sql | 6 + go/core/pkg/migrations/migrations.go | 9 + .../vector/000001_vector_support.down.sql | 2 + .../vector/000001_vector_support.up.sql | 17 + .../000002_add_memory_hnsw_index.down.sql | 1 + .../000002_add_memory_hnsw_index.up.sql | 4 + .../000003_memory_uuid_default.down.sql | 1 + .../vector/000003_memory_uuid_default.up.sql | 1 + go/go.mod | 10 +- go/go.sum | 17 +- .../templates/controller-deployment.yaml | 30 + helm/kagent/values.yaml | 6 + .../langgraph/currency/currency/agent.py | 1 + 64 files changed, 5012 insertions(+), 1274 deletions(-) create mode 100644 .github/workflows/migration-immutability.yaml create mode 100644 go/core/cmd/migrate/main.go create mode 100644 go/core/cmd/migrate/main_test.go delete mode 100644 go/core/internal/database/client.go create mode 100644 go/core/internal/database/client_postgres.go create mode 100644 go/core/internal/database/connect.go rename go/core/internal/database/{manager_test.go => connect_test.go} (92%) create mode 100644 go/core/internal/database/gen/agents.sql.go create mode 100644 go/core/internal/database/gen/crewai.sql.go create mode 100644 go/core/internal/database/gen/db.go create mode 100644 
go/core/internal/database/gen/events.sql.go create mode 100644 go/core/internal/database/gen/feedback.sql.go create mode 100644 go/core/internal/database/gen/langgraph.sql.go create mode 100644 go/core/internal/database/gen/memory.sql.go create mode 100644 go/core/internal/database/gen/models.go create mode 100644 go/core/internal/database/gen/push_notifications.sql.go create mode 100644 go/core/internal/database/gen/querier.go create mode 100644 go/core/internal/database/gen/sessions.sql.go create mode 100644 go/core/internal/database/gen/tasks.sql.go create mode 100644 go/core/internal/database/gen/tools.sql.go delete mode 100644 go/core/internal/database/manager.go create mode 100644 go/core/internal/database/queries/agents.sql create mode 100644 go/core/internal/database/queries/crewai.sql create mode 100644 go/core/internal/database/queries/events.sql create mode 100644 go/core/internal/database/queries/feedback.sql create mode 100644 go/core/internal/database/queries/langgraph.sql create mode 100644 go/core/internal/database/queries/memory.sql create mode 100644 go/core/internal/database/queries/push_notifications.sql create mode 100644 go/core/internal/database/queries/sessions.sql create mode 100644 go/core/internal/database/queries/tasks.sql create mode 100644 go/core/internal/database/queries/tools.sql delete mode 100644 go/core/internal/database/service.go create mode 100644 go/core/internal/database/sqlc.yaml create mode 100644 go/core/internal/database/upgrade_test.go delete mode 100644 go/core/pkg/env/database.go create mode 100644 go/core/pkg/migrations/core/000001_initial.down.sql create mode 100644 go/core/pkg/migrations/core/000001_initial.up.sql create mode 100644 go/core/pkg/migrations/core/000002_add_session_source.down.sql create mode 100644 go/core/pkg/migrations/core/000002_add_session_source.up.sql create mode 100644 go/core/pkg/migrations/migrations.go create mode 100644 go/core/pkg/migrations/vector/000001_vector_support.down.sql create 
mode 100644 go/core/pkg/migrations/vector/000001_vector_support.up.sql create mode 100644 go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.down.sql create mode 100644 go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.up.sql create mode 100644 go/core/pkg/migrations/vector/000003_memory_uuid_default.down.sql create mode 100644 go/core/pkg/migrations/vector/000003_memory_uuid_default.up.sql diff --git a/.github/workflows/migration-immutability.yaml b/.github/workflows/migration-immutability.yaml new file mode 100644 index 000000000..d22063906 --- /dev/null +++ b/.github/workflows/migration-immutability.yaml @@ -0,0 +1,35 @@ +name: Migration Immutability + +on: + pull_request: + branches: [main, release/v0.7.x] + paths: + - "go/core/pkg/migrations/**" + +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Fail if any existing migration file was modified + run: | + # List files under go/core/pkg/migrations/ that were changed relative + # to the merge base of this PR. We only care about modifications (M) + # and renames (R); additions (A) are fine. + BASE=$(git merge-base HEAD origin/${{ github.base_ref }}) + MODIFIED=$(git diff --name-only --diff-filter=MR "$BASE" HEAD \ + -- 'go/core/pkg/migrations/**/*.sql') + + if [ -n "$MODIFIED" ]; then + echo "ERROR: The following migration files were modified." + echo "Migration files are immutable once merged." + echo "Fix bugs with a new migration instead." + echo "" + echo "$MODIFIED" + exit 1 + fi + + echo "OK: no existing migration files were modified." 
diff --git a/.github/workflows/tag.yaml b/.github/workflows/tag.yaml index 25cc8f47a..10ec1f4fc 100644 --- a/.github/workflows/tag.yaml +++ b/.github/workflows/tag.yaml @@ -18,6 +18,7 @@ jobs: matrix: image: - controller + - migrate - ui - app - golang-adk diff --git a/Makefile b/Makefile index 964456843..d6066b583 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,7 @@ APP_IMAGE_NAME ?= app KAGENT_ADK_IMAGE_NAME ?= kagent-adk GOLANG_ADK_IMAGE_NAME ?= golang-adk SKILLS_INIT_IMAGE_NAME ?= skills-init +MIGRATE_IMAGE_NAME ?= migrate CONTROLLER_IMAGE_TAG ?= $(VERSION) UI_IMAGE_TAG ?= $(VERSION) @@ -47,6 +48,7 @@ APP_IMAGE_TAG ?= $(VERSION) KAGENT_ADK_IMAGE_TAG ?= $(VERSION) GOLANG_ADK_IMAGE_TAG ?= $(VERSION) SKILLS_INIT_IMAGE_TAG ?= $(VERSION) +MIGRATE_IMAGE_TAG ?= $(VERSION) CONTROLLER_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(CONTROLLER_IMAGE_NAME):$(CONTROLLER_IMAGE_TAG) UI_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(UI_IMAGE_NAME):$(UI_IMAGE_TAG) @@ -54,6 +56,7 @@ APP_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(APP_IMAGE_NAME):$(APP_IMAGE_TAG) KAGENT_ADK_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(KAGENT_ADK_IMAGE_NAME):$(KAGENT_ADK_IMAGE_TAG) GOLANG_ADK_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(GOLANG_ADK_IMAGE_NAME):$(GOLANG_ADK_IMAGE_TAG) SKILLS_INIT_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(SKILLS_INIT_IMAGE_NAME):$(SKILLS_INIT_IMAGE_TAG) +MIGRATE_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(MIGRATE_IMAGE_NAME):$(MIGRATE_IMAGE_TAG) #take from go/go.mod AWK ?= $(shell command -v gawk || command -v awk) @@ -219,7 +222,7 @@ prune-docker-images: docker images --filter dangling=true -q | xargs -r docker rmi || : .PHONY: build -build: buildx-create build-controller build-ui build-app build-golang-adk build-skills-init +build: buildx-create build-controller build-migrate build-ui build-app build-golang-adk build-skills-init @echo "Build completed successfully." 
@echo "Controller Image: $(CONTROLLER_IMG)" @echo "UI Image: $(UI_IMG)" @@ -267,6 +270,10 @@ controller-manifests: build-controller: buildx-create controller-manifests $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) --build-arg BUILD_PACKAGE=core/cmd/controller/main.go -t $(CONTROLLER_IMG) -f go/Dockerfile ./go +.PHONY: build-migrate +build-migrate: buildx-create + $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) --build-arg BUILD_PACKAGE=core/cmd/migrate/main.go -t $(MIGRATE_IMG) -f go/Dockerfile ./go + .PHONY: build-ui build-ui: buildx-create $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -t $(UI_IMG) -f ui/Dockerfile ./ui diff --git a/go/api/database/models.go b/go/api/database/models.go index 391c70e47..5520b0039 100644 --- a/go/api/database/models.go +++ b/go/api/database/models.go @@ -4,42 +4,35 @@ import ( "encoding/json" "time" - "github.com/google/uuid" "github.com/kagent-dev/kagent/go/api/adk" "github.com/pgvector/pgvector-go" - "gorm.io/gorm" "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// Agent represents an agent configuration type Agent struct { - ID string `gorm:"primaryKey" json:"id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` - Type string `gorm:"not null" json:"type"` - // Config is optional and may be nil for some agent types. - // For agent types that require configuration, this field should be populated. - // For agent types that do not require configuration, this field should be nil. 
- Config *adk.AgentConfig `gorm:"type:json" json:"config"` + Type string `json:"type"` + Config *adk.AgentConfig `json:"config"` } type Event struct { - ID string `gorm:"primaryKey;not null" json:"id"` - SessionID string `gorm:"index" json:"session_id"` - UserID string `gorm:"primaryKey;not null" json:"user_id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + ID string `json:"id"` + SessionID string `json:"session_id"` + UserID string `json:"user_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` - Data string `gorm:"type:text;not null" json:"data"` // JSON serialized protocol.Message + Data string `json:"data"` // JSON-serialized protocol.Message } func (m *Event) Parse() (protocol.Message, error) { var data protocol.Message - err := json.Unmarshal([]byte(m.Data), &data) - if err != nil { + if err := json.Unmarshal([]byte(m.Data), &data); err != nil { return protocol.Message{}, err } return data, nil @@ -48,11 +41,11 @@ func (m *Event) Parse() (protocol.Message, error) { func ParseMessages(messages []Event) ([]*protocol.Message, error) { result := make([]*protocol.Message, 0, len(messages)) for _, message := range messages { - parsedMessage, err := message.Parse() + parsed, err := message.Parse() if err != nil { return nil, err } - result = append(result, &parsedMessage) + result = append(result, &parsed) } return result, nil } @@ -68,32 +61,31 @@ const ( ) type Session struct { - ID string `gorm:"primaryKey;not null" json:"id"` - Name *string `gorm:"index" json:"name,omitempty"` - UserID string `gorm:"primaryKey" json:"user_id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - - AgentID *string 
`gorm:"index" json:"agent_id"` + ID string `json:"id"` + Name *string `json:"name,omitempty"` + UserID string `json:"user_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + + AgentID *string `json:"agent_id,omitempty"` // Source indicates how this session was created. // SessionSourceUser = user-initiated, SessionSourceAgent = created by a parent agent's A2A call. - Source *SessionSource `gorm:"index" json:"source,omitempty"` + Source *SessionSource `json:"source,omitempty"` } type Task struct { - ID string `gorm:"primaryKey;not null" json:"id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - Data string `gorm:"type:text;not null" json:"data"` // JSON serialized task data - SessionID string `gorm:"index" json:"session_id"` + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Data string `json:"data"` // JSON-serialized task data + SessionID string `json:"session_id"` } func (t *Task) Parse() (protocol.Task, error) { var data protocol.Task - err := json.Unmarshal([]byte(t.Data), &data) - if err != nil { + if err := json.Unmarshal([]byte(t.Data), &data); err != nil { return protocol.Task{}, err } return data, nil @@ -102,162 +94,129 @@ func (t *Task) Parse() (protocol.Task, error) { func ParseTasks(tasks []Task) ([]*protocol.Task, error) { result := make([]*protocol.Task, 0, len(tasks)) for _, task := range tasks { - parsedTask, err := task.Parse() + parsed, err := task.Parse() if err != nil { return nil, err } - result = append(result, &parsedTask) + result = append(result, &parsed) } return result, nil } type PushNotification struct { - ID string `gorm:"primaryKey;not null" json:"id"` - TaskID string `gorm:"not null;index" 
json:"task_id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - Data string `gorm:"type:text;not null" json:"data"` // JSON serialized push notification config + ID string `json:"id"` + TaskID string `json:"task_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Data string `json:"data"` // JSON-serialized push notification config } // FeedbackIssueType represents the category of feedback issue type FeedbackIssueType string const ( - FeedbackIssueTypeInstructions FeedbackIssueType = "instructions" // Did not follow instructions - FeedbackIssueTypeFactual FeedbackIssueType = "factual" // Not factually correct - FeedbackIssueTypeIncomplete FeedbackIssueType = "incomplete" // Incomplete response - FeedbackIssueTypeTool FeedbackIssueType = "tool" // Should have run the tool + FeedbackIssueTypeInstructions FeedbackIssueType = "instructions" + FeedbackIssueTypeFactual FeedbackIssueType = "factual" + FeedbackIssueTypeIncomplete FeedbackIssueType = "incomplete" + FeedbackIssueTypeTool FeedbackIssueType = "tool" ) -// Feedback represents user feedback on agent responses type Feedback struct { - gorm.Model - UserID string `gorm:"not null;index" json:"user_id"` - MessageID uint `gorm:"index;constraint:OnDelete:CASCADE" json:"message_id"` - IsPositive bool `gorm:"default:false" json:"is_positive"` - FeedbackText string `gorm:"not null" json:"feedback_text"` + ID int64 `json:"id"` + CreatedAt *time.Time `json:"created_at,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + UserID string `json:"user_id"` + MessageID *int64 `json:"message_id,omitempty"` + IsPositive bool `json:"is_positive"` + FeedbackText string `json:"feedback_text"` IssueType *FeedbackIssueType 
`json:"issue_type,omitempty"` } -// Tool represents a single tool that can be used by an agent type Tool struct { - ID string `gorm:"primaryKey;not null" json:"id"` - ServerName string `gorm:"primaryKey;not null" json:"server_name"` - GroupKind string `gorm:"primaryKey;not null" json:"group_kind"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - Description string `json:"description"` + ID string `json:"id"` + ServerName string `json:"server_name"` + GroupKind string `json:"group_kind"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Description string `json:"description"` } -// ToolServer represents a tool server that provides tools type ToolServer struct { - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - Name string `gorm:"primaryKey;not null" json:"name"` - GroupKind string `gorm:"primaryKey;not null" json:"group_kind"` - Description string `json:"description"` - LastConnected *time.Time `json:"last_connected,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Name string `json:"name"` + GroupKind string `json:"group_kind"` + Description string `json:"description"` + LastConnected *time.Time `json:"last_connected,omitempty"` } -// LangGraphCheckpoint represents a LangGraph checkpoint type LangGraphCheckpoint struct { - UserID string `gorm:"primaryKey;not null" json:"user_id"` - ThreadID string `gorm:"primaryKey;not null" json:"thread_id"` - CheckpointNS string `gorm:"primaryKey;not null;default:''" json:"checkpoint_ns"` - CheckpointID string `gorm:"primaryKey;not null" json:"checkpoint_id"` - 
ParentCheckpointID *string `gorm:"index" json:"parent_checkpoint_id,omitempty"` - CreatedAt time.Time `gorm:"autoCreateTime;index:idx_lgcp_list" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - Metadata string `gorm:"type:text;not null" json:"metadata"` // JSON serialized metadata - Checkpoint string `gorm:"type:text;not null" json:"checkpoint"` // JSON serialized checkpoint - CheckpointType string `gorm:"not null" json:"checkpoint_type"` - Version int `gorm:"default:1" json:"version"` + UserID string `json:"user_id"` + ThreadID string `json:"thread_id"` + CheckpointNS string `json:"checkpoint_ns"` + CheckpointID string `json:"checkpoint_id"` + ParentCheckpointID *string `json:"parent_checkpoint_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Metadata string `json:"metadata"` + Checkpoint string `json:"checkpoint"` + CheckpointType string `json:"checkpoint_type"` + Version int32 `json:"version"` } -// LangGraphCheckpointWrite represents a write operation for a checkpoint type LangGraphCheckpointWrite struct { - UserID string `gorm:"primaryKey;not null" json:"user_id"` - ThreadID string `gorm:"primaryKey;not null" json:"thread_id"` - CheckpointNS string `gorm:"primaryKey;not null;default:''" json:"checkpoint_ns"` - CheckpointID string `gorm:"primaryKey;not null" json:"checkpoint_id"` - WriteIdx int `gorm:"primaryKey;not null" json:"write_idx"` - Value string `gorm:"type:text;not null" json:"value"` // JSON serialized value - ValueType string `gorm:"not null" json:"value_type"` - Channel string `gorm:"not null" json:"channel"` - TaskID string `gorm:"not null" json:"task_id"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + UserID 
string `json:"user_id"` + ThreadID string `json:"thread_id"` + CheckpointNS string `json:"checkpoint_ns"` + CheckpointID string `json:"checkpoint_id"` + WriteIdx int32 `json:"write_idx"` + Value string `json:"value"` + ValueType string `json:"value_type"` + Channel string `json:"channel"` + TaskID string `json:"task_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` } -// CrewAIAgentMemory represents long-term memory for CrewAI agents type CrewAIAgentMemory struct { - UserID string `gorm:"primaryKey;not null" json:"user_id"` - ThreadID string `gorm:"primaryKey;not null" json:"thread_id"` - CreatedAt time.Time `gorm:"autoCreateTime;index:idx_crewai_memory_list" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - // MemoryData contains JSON serialized memory data including task_description, score, metadata, datetime - MemoryData string `gorm:"type:text;not null" json:"memory_data"` + UserID string `json:"user_id"` + ThreadID string `json:"thread_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + MemoryData string `json:"memory_data"` } -// CrewAIFlowState represents flow state for CrewAI flows type CrewAIFlowState struct { - UserID string `gorm:"primaryKey;not null" json:"user_id"` - ThreadID string `gorm:"primaryKey;not null" json:"thread_id"` - MethodName string `gorm:"primaryKey;not null" json:"method_name"` - CreatedAt time.Time `gorm:"autoCreateTime;index:idx_crewai_flow_state_list" json:"created_at"` - UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` - // StateData contains JSON serialized flow state data - StateData string `gorm:"type:text;not null" json:"state_data"` + UserID string `json:"user_id"` + ThreadID 
string `json:"thread_id"` + MethodName string `json:"method_name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + StateData string `json:"state_data"` } -// Memory represents a memory/session embedding with TTL support type Memory struct { - // ID is a UUID generated by the application layer before insert, ensuring - // compatibility with both Postgres and SQLite/libSQL backends. - ID string `gorm:"primaryKey" json:"id"` - AgentName string `gorm:"index:idx_memory_agent_user,composite:agent_name" json:"agent_name"` - UserID string `gorm:"index:idx_memory_agent_user,composite:user_id" json:"user_id"` - Content string `gorm:"type:text" json:"content"` - Embedding pgvector.Vector `gorm:"type:vector(768)" json:"embedding"` - Metadata string `gorm:"type:text" json:"metadata"` - CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` - ExpiresAt *time.Time `gorm:"index" json:"expires_at,omitempty"` - AccessCount int `gorm:"default:0" json:"access_count"` -} - -// BeforeCreate generates a UUID for the Memory ID if one has not been set, -// making ID generation database-agnostic (works for both Postgres and SQLite). -func (m *Memory) BeforeCreate(tx *gorm.DB) error { - if m.ID == "" { - m.ID = uuid.New().String() - } - return nil -} - -// AgentMemorySearchResult is the result of a vector similarity search over Memory (e.g. SELECT *, 1 - (embedding <=> ?) as score). + ID string `json:"id"` + AgentName string `json:"agent_name"` + UserID string `json:"user_id"` + Content string `json:"content"` + Embedding pgvector.Vector `json:"embedding"` + Metadata string `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + AccessCount int32 `json:"access_count"` +} + +// AgentMemorySearchResult is the result of a vector similarity search over Memory. 
type AgentMemorySearchResult struct { Memory - Score float64 `gorm:"column:score" json:"score"` + Score float64 `json:"score"` } - -// TableName methods to match Python table names -func (Agent) TableName() string { return "agent" } -func (Event) TableName() string { return "event" } -func (Session) TableName() string { return "session" } -func (Task) TableName() string { return "task" } -func (PushNotification) TableName() string { return "push_notification" } -func (Feedback) TableName() string { return "feedback" } -func (Tool) TableName() string { return "tool" } -func (ToolServer) TableName() string { return "toolserver" } -func (LangGraphCheckpoint) TableName() string { return "lg_checkpoint" } -func (LangGraphCheckpointWrite) TableName() string { return "lg_checkpoint_write" } -func (CrewAIAgentMemory) TableName() string { return "crewai_agent_memory" } -func (CrewAIFlowState) TableName() string { return "crewai_flow_state" } -func (Memory) TableName() string { return "memory" } diff --git a/go/core/cmd/migrate/main.go b/go/core/cmd/migrate/main.go new file mode 100644 index 000000000..f0f25e373 --- /dev/null +++ b/go/core/cmd/migrate/main.go @@ -0,0 +1,311 @@ +// kagent-migrate runs Postgres schema migrations and exits. +// It is intended to run as a Kubernetes init container before the kagent +// controller starts, ensuring the schema is up to date before the app connects. 
+// +// Usage: +// +// kagent-migrate [command] +// +// Commands: +// +// up Apply all pending migrations (default when no command is given) +// down Roll back N migrations on a single track +// version Print the current applied version and dirty flag for each track +// +// Required environment variable: +// +// POSTGRES_DATABASE_URL — Postgres connection URL +// +// Optional environment variables: +// +// POSTGRES_DATABASE_URL_FILE — path to a file containing the URL (takes precedence) +// KAGENT_DATABASE_VECTOR_ENABLED — set to "true" to also run vector migrations +// +// Enterprise extension: replace this binary with enterprise-migrate, which imports +// go/core/pkg/migrations.FS directly via the OSS Go module dependency and adds its +// own migration passes alongside it at compile time. +package main + +import ( + "database/sql" + "errors" + "flag" + "fmt" + "io/fs" + "log" + "os" + "strings" + + "github.com/golang-migrate/migrate/v4" + migratepg "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/kagent-dev/kagent/go/core/pkg/migrations" + _ "github.com/lib/pq" +) + +func main() { + flag.Parse() + + url, err := resolveURL() + if err != nil { + log.Fatalf("kagent-migrate: %v", err) + } + + vectorEnabled := strings.EqualFold(os.Getenv("KAGENT_DATABASE_VECTOR_ENABLED"), "true") + + cmd := "up" + args := flag.Args() + if len(args) > 0 { + cmd = args[0] + args = args[1:] + } + + switch cmd { + case "up": + runUpCommand(url, migrations.FS, vectorEnabled) + case "down": + runDownCommand(url, migrations.FS, vectorEnabled, args) + case "version": + runVersionCommand(url, migrations.FS, vectorEnabled) + default: + log.Fatalf("kagent-migrate: unknown command %q (valid: up, down, version)", cmd) + } +} + +func runUpCommand(url string, migrationsFS fs.FS, vectorEnabled bool) { + corePrev, err := applyDir(url, migrationsFS, "core", "schema_migrations") + if err != nil { + log.Fatalf("kagent-migrate: 
core migrations: %v", err) + } + log.Println("kagent-migrate: core migrations applied") + + if vectorEnabled { + if _, err := applyDir(url, migrationsFS, "vector", "vector_schema_migrations"); err != nil { + // Vector failed (and already rolled itself back). Roll back core too + // since both tracks are treated as one unit. + log.Printf("kagent-migrate: rolling back core to version %d", corePrev) + rollbackDir(url, migrationsFS, "core", "schema_migrations", corePrev) + log.Fatalf("kagent-migrate: vector migrations: %v", err) + } + log.Println("kagent-migrate: vector migrations applied") + } + + log.Println("kagent-migrate: done") +} + +// applyDir runs Up for dir and rolls back on failure. It returns the pre-run +// version so the caller can roll back this track if a later track fails. +func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prevVersion uint, err error) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return 0, err + } + defer closeMigrate(dir, mg) + + prevVersion, _, err = mg.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + return 0, fmt.Errorf("get pre-migration version for %s: %w", dir, err) + } + // prevVersion == 0 when ErrNilVersion (no migrations applied yet). + + if upErr := mg.Up(); upErr != nil { + if errors.Is(upErr, migrate.ErrNoChange) { + return prevVersion, nil + } + log.Printf("kagent-migrate: migration failed for %s, attempting rollback to version %d", dir, prevVersion) + if rbErr := rollbackToVersion(mg, dir, prevVersion); rbErr != nil { + log.Printf("kagent-migrate: rollback failed for %s: %v", dir, rbErr) + } else { + log.Printf("kagent-migrate: rolled back %s to version %d", dir, prevVersion) + } + return prevVersion, fmt.Errorf("run migrations for %s: %w", dir, upErr) + } + return prevVersion, nil +} + +// rollbackDir opens a fresh migrate instance and rolls dir back to targetVersion. 
+// Used to roll back a previously-succeeded track when a later track fails. +func rollbackDir(url string, migrationsFS fs.FS, dir, migrationsTable string, targetVersion uint) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + log.Printf("kagent-migrate: rollback of %s failed (open): %v", dir, err) + return + } + defer closeMigrate(dir, mg) + if err := rollbackToVersion(mg, dir, targetVersion); err != nil { + log.Printf("kagent-migrate: rollback of %s failed: %v", dir, err) + } else { + log.Printf("kagent-migrate: rolled back %s to version %d", dir, targetVersion) + } +} + +// rollbackToVersion rolls the migration state back to targetVersion. +// It handles the dirty-state cleanup golang-migrate requires after a failed +// Up run before down steps can be applied. +func rollbackToVersion(mg *migrate.Migrate, dir string, targetVersion uint) error { + currentVersion, dirty, err := mg.Version() + if err != nil { + if errors.Is(err, migrate.ErrNilVersion) { + return nil // nothing was applied; nothing to roll back + } + return fmt.Errorf("get version after failure for %s: %w", dir, err) + } + + if dirty { + // The failed migration is recorded as dirty at currentVersion. + // Force to the last clean version so Steps can run. 
+ cleanVersion := int(currentVersion) - 1 + forceTarget := cleanVersion + if forceTarget < 1 { + forceTarget = -1 // negative tells golang-migrate to remove the version record entirely + } + if err := mg.Force(forceTarget); err != nil { + return fmt.Errorf("clear dirty state for %s: %w", dir, err) + } + if forceTarget < 0 { + return nil // first migration failed and was cleared; nothing left to roll back + } + currentVersion = uint(cleanVersion) + } + + steps := int(currentVersion) - int(targetVersion) + if steps <= 0 { + return nil + } + if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("roll back %d step(s) for %s: %w", steps, dir, err) + } + return nil +} + +func runDownCommand(url string, migrationsFS fs.FS, vectorEnabled bool, args []string) { + downFlags := flag.NewFlagSet("down", flag.ExitOnError) + steps := downFlags.Int("steps", 0, "number of down migrations to run (required, must be > 0)") + track := downFlags.String("track", "core", "migration track to roll back: core or vector") + if err := downFlags.Parse(args); err != nil { + log.Fatalf("kagent-migrate: down: %v", err) + } + + if *steps <= 0 { + log.Fatalf("kagent-migrate: down: --steps must be a positive integer") + } + + var dir, table string + switch *track { + case "core": + dir, table = "core", "schema_migrations" + case "vector": + if !vectorEnabled { + log.Fatalf("kagent-migrate: down: track %q requested but KAGENT_DATABASE_VECTOR_ENABLED is not true", *track) + } + dir, table = "vector", "vector_schema_migrations" + default: + log.Fatalf("kagent-migrate: down: unknown track %q (valid: core, vector)", *track) + } + + if err := downDir(url, migrationsFS, dir, table, *steps); err != nil { + log.Fatalf("kagent-migrate: down %s (%d steps): %v", *track, *steps, err) + } + log.Printf("kagent-migrate: rolled back %d migration(s) on %s track", *steps, *track) +} + +func runVersionCommand(url string, migrationsFS fs.FS, vectorEnabled bool) { + tracks := 
[]struct{ dir, table string }{ + {"core", "schema_migrations"}, + } + if vectorEnabled { + tracks = append(tracks, struct{ dir, table string }{"vector", "vector_schema_migrations"}) + } + + for _, t := range tracks { + version, dirty, err := versionDir(url, migrationsFS, t.dir, t.table) + if err != nil { + log.Fatalf("kagent-migrate: version %s: %v", t.dir, err) + } + log.Printf("kagent-migrate: track=%-6s table=%-30s version=%d dirty=%v", t.dir, t.table, version, dirty) + } +} + +func resolveURL() (string, error) { + if file := os.Getenv("POSTGRES_DATABASE_URL_FILE"); file != "" { + content, err := os.ReadFile(file) + if err != nil { + return "", fmt.Errorf("reading URL file %s: %w", file, err) + } + url := strings.TrimSpace(string(content)) + if url == "" { + return "", fmt.Errorf("URL file %s is empty", file) + } + return url, nil + } + url := os.Getenv("POSTGRES_DATABASE_URL") + if url == "" { + return "", fmt.Errorf("POSTGRES_DATABASE_URL must be set") + } + return url, nil +} + +// newMigrate opens a database connection and constructs a migrate.Migrate for the given dir/table. +// The caller is responsible for calling closeMigrate on the returned instance. 
+func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*migrate.Migrate, error) { + db, err := sql.Open("postgres", url) + if err != nil { + return nil, fmt.Errorf("open database for %s: %w", dir, err) + } + + src, err := iofs.New(migrationsFS, dir) + if err != nil { + return nil, fmt.Errorf("load migration files from %s: %w", dir, err) + } + + driver, err := migratepg.WithInstance(db, &migratepg.Config{ + MigrationsTable: migrationsTable, + }) + if err != nil { + return nil, fmt.Errorf("create migration driver for %s: %w", dir, err) + } + + mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) + if err != nil { + return nil, fmt.Errorf("create migrator for %s: %w", dir, err) + } + return mg, nil +} + +func downDir(url string, migrationsFS fs.FS, dir, migrationsTable string, steps int) error { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return err + } + defer closeMigrate(dir, mg) + + if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("roll back %d migration(s) for %s: %w", steps, dir, err) + } + return nil +} + +func versionDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (version uint, dirty bool, err error) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return 0, false, err + } + defer closeMigrate(dir, mg) + + version, dirty, err = mg.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + return 0, false, fmt.Errorf("get version for %s: %w", dir, err) + } + return version, dirty, nil +} + +// closeMigrate closes mg, logging source and database close errors separately. 
+func closeMigrate(dir string, mg *migrate.Migrate) { + srcErr, dbErr := mg.Close() + if srcErr != nil { + log.Printf("warning: closing migration source for %s: %v", dir, srcErr) + } + if dbErr != nil { + log.Printf("warning: closing migration database for %s: %v", dir, dbErr) + } +} diff --git a/go/core/cmd/migrate/main_test.go b/go/core/cmd/migrate/main_test.go new file mode 100644 index 000000000..0e17321a4 --- /dev/null +++ b/go/core/cmd/migrate/main_test.go @@ -0,0 +1,247 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "maps" + "testing" + "testing/fstest" + + "github.com/kagent-dev/kagent/go/core/internal/dbtest" + _ "github.com/lib/pq" +) + +// --- migration fixtures --- + +// goodCoreFS has two valid core migrations. +var goodCoreFS = fstest.MapFS{ + "core/000001_create.up.sql": {Data: []byte(`CREATE TABLE mig_test (id SERIAL PRIMARY KEY);`)}, + "core/000001_create.down.sql": {Data: []byte(`DROP TABLE IF EXISTS mig_test;`)}, + "core/000002_alter.up.sql": {Data: []byte(`ALTER TABLE mig_test ADD COLUMN name TEXT;`)}, + "core/000002_alter.down.sql": {Data: []byte(`ALTER TABLE mig_test DROP COLUMN IF EXISTS name;`)}, +} + +// oneCoreFS is just the first migration from goodCoreFS. +var oneCoreFS = fstest.MapFS{ + "core/000001_create.up.sql": {Data: []byte(`CREATE TABLE mig_test (id SERIAL PRIMARY KEY);`)}, + "core/000001_create.down.sql": {Data: []byte(`DROP TABLE IF EXISTS mig_test;`)}, +} + +// failOnFirstCoreFS fails immediately on the first migration. +var failOnFirstCoreFS = fstest.MapFS{ + "core/000001_bad.up.sql": {Data: []byte(`ALTER TABLE no_such_table ADD COLUMN x TEXT;`)}, + "core/000001_bad.down.sql": {Data: []byte(`SELECT 1;`)}, +} + +// failOnSecondCoreFS succeeds on migration 1 then fails on migration 2. 
+var failOnSecondCoreFS = fstest.MapFS{ + "core/000001_create.up.sql": {Data: []byte(`CREATE TABLE mig_test (id SERIAL PRIMARY KEY);`)}, + "core/000001_create.down.sql": {Data: []byte(`DROP TABLE IF EXISTS mig_test;`)}, + "core/000002_bad.up.sql": {Data: []byte(`ALTER TABLE no_such_table ADD COLUMN x TEXT;`)}, + "core/000002_bad.down.sql": {Data: []byte(`SELECT 1;`)}, +} + +// failVectorFS has a vector migration that fails. +var failVectorFS = fstest.MapFS{ + "vector/000001_bad.up.sql": {Data: []byte(`ALTER TABLE no_such_table ADD COLUMN y TEXT;`)}, + "vector/000001_bad.down.sql": {Data: []byte(`SELECT 1;`)}, +} + +// mergeFS combines multiple MapFS values into one. +func mergeFS(fsMaps ...fstest.MapFS) fstest.MapFS { + out := fstest.MapFS{} + for _, m := range fsMaps { + maps.Copy(out, m) + } + return out +} + +// trackVersion reads the current version from a golang-migrate tracking table. +// Returns 0 if the table is empty or does not exist (fully rolled back). +func trackVersion(t *testing.T, connStr, table string) uint { + t.Helper() + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatalf("trackVersion: open db: %v", err) + } + defer db.Close() + var v uint + err = db.QueryRowContext(context.Background(), + fmt.Sprintf(`SELECT version FROM %s LIMIT 1`, table)).Scan(&v) + if err != nil { + return 0 // sql.ErrNoRows or table doesn't exist + } + return v +} + +// --- applyDir tests --- + +func TestApplyDir_HappyPath(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + prev, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations") + if err != nil { + t.Fatalf("applyDir: %v", err) + } + if prev != 0 { + t.Errorf("prevVersion = %d, want 0", prev) + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 2 { + t.Errorf("version = %d, want 2", got) + } +} + +func TestApplyDir_NoOpWhenAlreadyAtLatest(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + if _, err := applyDir(connStr, goodCoreFS, 
"core", "schema_migrations"); err != nil { + t.Fatalf("first apply: %v", err) + } + prev, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations") + if err != nil { + t.Fatalf("second apply: %v", err) + } + if prev != 2 { + t.Errorf("prevVersion on no-op = %d, want 2", prev) + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 2 { + t.Errorf("version = %d, want 2", got) + } +} + +func TestApplyDir_RollsBackWhenFirstMigrationFails(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + if _, err := applyDir(connStr, failOnFirstCoreFS, "core", "schema_migrations"); err == nil { + t.Fatal("expected error, got nil") + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 0 { + t.Errorf("version after rollback = %d, want 0", got) + } +} + +func TestApplyDir_RollsBackWhenLaterMigrationFails(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + if _, err := applyDir(connStr, failOnSecondCoreFS, "core", "schema_migrations"); err == nil { + t.Fatal("expected error, got nil") + } + // Migration 1 succeeded then was rolled back along with the failed migration 2. + if got := trackVersion(t, connStr, "schema_migrations"); got != 0 { + t.Errorf("version after rollback = %d, want 0", got) + } +} + +func TestApplyDir_RollsBackToExistingVersion(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + // Establish a baseline at version 1. + if _, err := applyDir(connStr, oneCoreFS, "core", "schema_migrations"); err != nil { + t.Fatalf("setup: %v", err) + } + + // Advance to version 2 — should fail and roll back to version 1, not 0. 
+ if _, err := applyDir(connStr, failOnSecondCoreFS, "core", "schema_migrations"); err == nil { + t.Fatal("expected error, got nil") + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 1 { + t.Errorf("version after rollback = %d, want 1 (pre-run baseline)", got) + } +} + +// --- rollbackDir tests --- + +func TestRollbackDir_RollsBackToTarget(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + if _, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations"); err != nil { + t.Fatalf("setup: %v", err) + } + + rollbackDir(connStr, goodCoreFS, "core", "schema_migrations", 0) + + if got := trackVersion(t, connStr, "schema_migrations"); got != 0 { + t.Errorf("version after rollback = %d, want 0", got) + } +} + +func TestRollbackDir_PartialRollback(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + if _, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations"); err != nil { + t.Fatalf("setup: %v", err) + } + + // Roll back only one step (2 → 1). + rollbackDir(connStr, goodCoreFS, "core", "schema_migrations", 1) + + if got := trackVersion(t, connStr, "schema_migrations"); got != 1 { + t.Errorf("version after partial rollback = %d, want 1", got) + } +} + +// --- cross-track rollback --- + +// TestCrossTrackRollback_CoreUnchangedWhenVectorFails covers the case where +// core has no new migrations (ErrNoChange) and vector fails. Core should not +// be downgraded by the cross-track rollback. +func TestCrossTrackRollback_CoreUnchangedWhenVectorFails(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + combined := mergeFS(goodCoreFS, failVectorFS) + + // Establish core at its latest version before the run. + if _, err := applyDir(connStr, combined, "core", "schema_migrations"); err != nil { + t.Fatalf("setup core: %v", err) + } + + // Core has no new migrations — applyDir returns ErrNoChange. 
+ corePrev, err := applyDir(connStr, combined, "core", "schema_migrations") + if err != nil { + t.Fatalf("core apply (no-op): %v", err) + } + if corePrev != 2 { + t.Fatalf("corePrev = %d, want 2", corePrev) + } + + // Vector fails and self-rolls-back. + if _, err := applyDir(connStr, combined, "vector", "vector_schema_migrations"); err == nil { + t.Fatal("expected vector error, got nil") + } + + // Cross-track rollback: core should be untouched since corePrev == current version. + rollbackDir(connStr, combined, "core", "schema_migrations", corePrev) + if got := trackVersion(t, connStr, "schema_migrations"); got != 2 { + t.Errorf("core version = %d, want 2 (should not have been downgraded)", got) + } +} + +func TestCrossTrackRollback_CoreRolledBackWhenVectorFails(t *testing.T) { + connStr := dbtest.StartT(context.Background(), t) + + combined := mergeFS(goodCoreFS, failVectorFS) + + // Core succeeds. + corePrev, err := applyDir(connStr, combined, "core", "schema_migrations") + if err != nil { + t.Fatalf("core apply: %v", err) + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 2 { + t.Fatalf("core version = %d, want 2", got) + } + + // Vector fails and rolls itself back. + if _, err := applyDir(connStr, combined, "vector", "vector_schema_migrations"); err == nil { + t.Fatal("expected vector error, got nil") + } + if got := trackVersion(t, connStr, "vector_schema_migrations"); got != 0 { + t.Errorf("vector version after self-rollback = %d, want 0", got) + } + + // Cross-track rollback: core should be rolled back to its pre-run version. 
+ rollbackDir(connStr, combined, "core", "schema_migrations", corePrev) + if got := trackVersion(t, connStr, "schema_migrations"); got != corePrev { + t.Errorf("core version after cross-track rollback = %d, want %d", got, corePrev) + } +} diff --git a/go/core/internal/controller/reconciler/mcp_server_reconciler_test.go b/go/core/internal/controller/reconciler/mcp_server_reconciler_test.go index 1914946c4..15ba5bd62 100644 --- a/go/core/internal/controller/reconciler/mcp_server_reconciler_test.go +++ b/go/core/internal/controller/reconciler/mcp_server_reconciler_test.go @@ -85,19 +85,16 @@ func TestReconcileKagentMCPServer_ErrorPropagation(t *testing.T) { WithObjects(tc.mcpServer). Build() - dbManager, err := database.NewManager(context.Background(), &database.Config{ - PostgresConfig: &database.PostgresConfig{ - URL: connStr, - VectorEnabled: true, - }, - }) - require.NoError(t, err) - defer dbManager.Close() + dbtest.MigrateT(t, connStr, true) - err = dbManager.Initialize() + db, err := database.Connect(context.Background(), &database.PostgresConfig{ + URL: connStr, + VectorEnabled: true, + }) require.NoError(t, err) + defer db.Close() - dbClient := database.NewClient(dbManager) + dbClient := database.NewClient(db) // Create reconciler translator := agenttranslator.NewAdkApiTranslator( diff --git a/go/core/internal/database/client.go b/go/core/internal/database/client.go deleted file mode 100644 index 8a38366a4..000000000 --- a/go/core/internal/database/client.go +++ /dev/null @@ -1,690 +0,0 @@ -package database - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - dbpkg "github.com/kagent-dev/kagent/go/api/database" - "github.com/kagent-dev/kagent/go/api/v1alpha2" - "github.com/pgvector/pgvector-go" - "gorm.io/gorm" - ctrllog "sigs.k8s.io/controller-runtime/pkg/log" - "trpc.group/trpc-go/trpc-a2a-go/protocol" -) - -type clientImpl struct { - db *gorm.DB -} - -func NewClient(dbManager *Manager) dbpkg.Client { - return &clientImpl{ - 
db: dbManager.db, - } -} - -// CreateFeedback creates a new feedback record -func (c *clientImpl) StoreFeedback(ctx context.Context, feedback *dbpkg.Feedback) error { - if err := c.db.WithContext(ctx).Create(feedback).Error; err != nil { - return fmt.Errorf("failed to create feedback: %w", err) - } - return nil -} - -// CreateSession creates a new session record -func (c *clientImpl) StoreSession(ctx context.Context, session *dbpkg.Session) error { - return save(c.db.WithContext(ctx), session) -} - -// CreateAgent creates a new agent record -func (c *clientImpl) StoreAgent(ctx context.Context, agent *dbpkg.Agent) error { - return save(c.db.WithContext(ctx), agent) -} - -// CreateToolServer creates a new tool server record -func (c *clientImpl) StoreToolServer(ctx context.Context, toolServer *dbpkg.ToolServer) (*dbpkg.ToolServer, error) { - err := save(c.db.WithContext(ctx), toolServer) - if err != nil { - return nil, err - } - return toolServer, nil -} - -// CreateTool creates a new tool record -func (c *clientImpl) StoreTool(ctx context.Context, tool *dbpkg.Tool) error { - return save(c.db.WithContext(ctx), tool) -} - -// DeleteTask deletes a task by ID -func (c *clientImpl) DeleteTask(ctx context.Context, taskID string) error { - return delete[dbpkg.Task](c.db.WithContext(ctx), Clause{Key: "id", Value: taskID}) -} - -// DeleteSession deletes a session by id and user ID -func (c *clientImpl) DeleteSession(ctx context.Context, sessionID string, userID string) error { - return delete[dbpkg.Session](c.db.WithContext(ctx), - Clause{Key: "id", Value: sessionID}, - Clause{Key: "user_id", Value: userID}) -} - -// DeleteAgent deletes an agent by name and user ID -func (c *clientImpl) DeleteAgent(ctx context.Context, agentID string) error { - return delete[dbpkg.Agent](c.db.WithContext(ctx), Clause{Key: "id", Value: agentID}) -} - -// DeleteToolServer deletes a tool server by name and user ID -func (c *clientImpl) DeleteToolServer(ctx context.Context, serverName string, 
groupKind string) error { - return delete[dbpkg.ToolServer](c.db.WithContext(ctx), - Clause{Key: "name", Value: serverName}, - Clause{Key: "group_kind", Value: groupKind}) -} - -func (c *clientImpl) DeleteToolsForServer(ctx context.Context, serverName string, groupKind string) error { - return delete[dbpkg.Tool](c.db.WithContext(ctx), - Clause{Key: "server_name", Value: serverName}, - Clause{Key: "group_kind", Value: groupKind}) -} - -// GetTaskMessages retrieves messages for a specific task -func (c *clientImpl) GetTaskMessages(ctx context.Context, taskID int) ([]*protocol.Message, error) { - messages, err := list[dbpkg.Event](c.db.WithContext(ctx), Clause{Key: "task_id", Value: taskID}) - if err != nil { - return nil, err - } - - protocolMessages := make([]*protocol.Message, 0, len(messages)) - for _, message := range messages { - var protocolMessage protocol.Message - if err := json.Unmarshal([]byte(message.Data), &protocolMessage); err != nil { - return nil, fmt.Errorf("failed to deserialize message: %w", err) - } - protocolMessages = append(protocolMessages, &protocolMessage) - } - - return protocolMessages, nil -} - -// GetSession retrieves a session by id and user ID -func (c *clientImpl) GetSession(ctx context.Context, sessionID string, userID string) (*dbpkg.Session, error) { - return get[dbpkg.Session](c.db.WithContext(ctx), - Clause{Key: "id", Value: sessionID}, - Clause{Key: "user_id", Value: userID}) -} - -// GetAgent retrieves an agent by name and user ID -func (c *clientImpl) GetAgent(ctx context.Context, agentID string) (*dbpkg.Agent, error) { - return get[dbpkg.Agent](c.db.WithContext(ctx), Clause{Key: "id", Value: agentID}) -} - -// GetTool retrieves a tool by provider (name) and user ID -func (c *clientImpl) GetTool(ctx context.Context, provider string) (*dbpkg.Tool, error) { - return get[dbpkg.Tool](c.db.WithContext(ctx), Clause{Key: "name", Value: provider}) -} - -// GetToolServer retrieves a tool server by name and user ID -func (c 
*clientImpl) GetToolServer(ctx context.Context, serverName string) (*dbpkg.ToolServer, error) { - return get[dbpkg.ToolServer](c.db.WithContext(ctx), Clause{Key: "name", Value: serverName}) -} - -// ListFeedback lists all feedback for a user -func (c *clientImpl) ListFeedback(ctx context.Context, userID string) ([]dbpkg.Feedback, error) { - feedback, err := list[dbpkg.Feedback](c.db.WithContext(ctx), Clause{Key: "user_id", Value: userID}) - if err != nil { - return nil, err - } - - return feedback, nil -} - -func (c *clientImpl) StoreEvents(ctx context.Context, events ...*dbpkg.Event) error { - for _, event := range events { - err := save(c.db.WithContext(ctx), event) - if err != nil { - return fmt.Errorf("failed to create event: %w", err) - } - } - return nil -} - -// ListSessionRuns lists all runs for a specific session -func (c *clientImpl) ListTasksForSession(ctx context.Context, sessionID string) ([]*protocol.Task, error) { - tasks, err := list[dbpkg.Task](c.db.WithContext(ctx), - Clause{Key: "session_id", Value: sessionID}, - ) - if err != nil { - return nil, err - } - - return dbpkg.ParseTasks(tasks) -} - -func (c *clientImpl) ListSessionsForAgent(ctx context.Context, agentID string, userID string) ([]dbpkg.Session, error) { - var sessions []dbpkg.Session - err := c.db.WithContext(ctx). - Where("agent_id = ? AND user_id = ?", agentID, userID). - Where("source IS NULL OR source != ?", dbpkg.SessionSourceAgent). - Order("created_at ASC"). 
- Find(&sessions).Error - if err != nil { - return nil, fmt.Errorf("failed to list sessions for agent: %w", err) - } - return sessions, nil -} - -// ListSessions lists all sessions for a user -func (c *clientImpl) ListSessions(ctx context.Context, userID string) ([]dbpkg.Session, error) { - return list[dbpkg.Session](c.db.WithContext(ctx), Clause{Key: "user_id", Value: userID}) -} - -// ListAgents lists all agents -func (c *clientImpl) ListAgents(ctx context.Context) ([]dbpkg.Agent, error) { - return list[dbpkg.Agent](c.db.WithContext(ctx)) -} - -// ListToolServers lists all tool servers for a user -func (c *clientImpl) ListToolServers(ctx context.Context) ([]dbpkg.ToolServer, error) { - return list[dbpkg.ToolServer](c.db.WithContext(ctx)) -} - -// ListTools lists all tools for a user -func (c *clientImpl) ListTools(ctx context.Context) ([]dbpkg.Tool, error) { - return list[dbpkg.Tool](c.db.WithContext(ctx)) -} - -// ListToolsForServer lists all tools for a specific server and group kind -func (c *clientImpl) ListToolsForServer(ctx context.Context, serverName string, groupKind string) ([]dbpkg.Tool, error) { - return list[dbpkg.Tool](c.db.WithContext(ctx), - Clause{Key: "server_name", Value: serverName}, - Clause{Key: "group_kind", Value: groupKind}) -} - -// RefreshToolsForServer atomically replaces all tools for a server. -// Uses a database transaction to ensure consistency under concurrent access. -// -// IMPORTANT: This function should only contain fast database operations. -// Network I/O (e.g., fetching tools from remote MCP servers) must happen -// BEFORE calling this function, not inside it. Holding a database transaction -// during slow operations can cause contention and degrade performance. 
-func (c *clientImpl) RefreshToolsForServer(ctx context.Context, serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error { - return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - // Delete all existing tools for this server in the transaction - if err := delete[dbpkg.Tool](tx, - Clause{Key: "server_name", Value: serverName}, - Clause{Key: "group_kind", Value: groupKind}); err != nil { - return fmt.Errorf("failed to delete existing tools: %w", err) - } - - // Insert all new tools - for _, tool := range tools { - if err := save(tx, &dbpkg.Tool{ - ID: tool.Name, - ServerName: serverName, - GroupKind: groupKind, - Description: tool.Description, - }); err != nil { - return fmt.Errorf("failed to create tool %s: %w", tool.Name, err) - } - } - - return nil - }) -} - -// ListMessagesForRun retrieves messages for a specific run (helper method) -func (c *clientImpl) ListMessagesForTask(ctx context.Context, taskID, userID string) ([]*protocol.Message, error) { - messages, err := list[dbpkg.Event](c.db.WithContext(ctx), - Clause{Key: "task_id", Value: taskID}, - Clause{Key: "user_id", Value: userID}) - if err != nil { - return nil, err - } - - return dbpkg.ParseMessages(messages) -} - -// ListEventsForSession retrieves events for a specific session -// Use Limit with DESC for getting latest events, ASC with no limit for chronological order -func (c *clientImpl) ListEventsForSession(ctx context.Context, sessionID, userID string, options dbpkg.QueryOptions) ([]*dbpkg.Event, error) { - var events []dbpkg.Event - order := "created_at DESC" - if options.OrderAsc { - order = "created_at ASC" - } - query := c.db.WithContext(ctx). - Where("session_id = ?", sessionID). - Where("user_id = ?", userID). 
- Order(order) - - if !options.After.IsZero() { - query = query.Where("created_at > ?", options.After) - } - - if options.Limit > 0 { - query = query.Limit(options.Limit) - } - - err := query.Find(&events).Error - if err != nil { - return nil, err - } - - protocolEvents := make([]*dbpkg.Event, 0, len(events)) - for _, event := range events { - protocolEvents = append(protocolEvents, &event) - } - - return protocolEvents, nil -} - -// GetMessage retrieves a protocol message from the database -func (c *clientImpl) GetMessage(ctx context.Context, messageID string) (*protocol.Message, error) { - dbMessage, err := get[dbpkg.Event](c.db.WithContext(ctx), Clause{Key: "id", Value: messageID}) - if err != nil { - return nil, fmt.Errorf("failed to get message: %w", err) - } - - var message protocol.Message - if err := json.Unmarshal([]byte(dbMessage.Data), &message); err != nil { - return nil, fmt.Errorf("failed to deserialize message: %w", err) - } - - return &message, nil -} - -// DeleteMessage deletes a protocol message from the database -func (c *clientImpl) DeleteMessage(ctx context.Context, messageID string) error { - return delete[dbpkg.Event](c.db.WithContext(ctx), Clause{Key: "id", Value: messageID}) -} - -// ListMessagesByContextID retrieves messages by context ID with optional limit -func (c *clientImpl) ListMessagesByContextID(ctx context.Context, contextID string, limit int) ([]protocol.Message, error) { - var dbMessages []dbpkg.Event - query := c.db.WithContext(ctx).Where("session_id = ?", contextID).Order("created_at DESC") - - if limit > 0 { - query = query.Limit(limit) - } - - err := query.Find(&dbMessages).Error - if err != nil { - return nil, fmt.Errorf("failed to get messages: %w", err) - } - - protocolMessages := make([]protocol.Message, 0, len(dbMessages)) - for _, dbMessage := range dbMessages { - var protocolMessage protocol.Message - if err := json.Unmarshal([]byte(dbMessage.Data), &protocolMessage); err != nil { - return nil, fmt.Errorf("failed to 
deserialize message: %w", err) - } - protocolMessages = append(protocolMessages, protocolMessage) - } - - return protocolMessages, nil -} - -// StoreTask stores a MemoryCancellableTask in the database -func (c *clientImpl) StoreTask(ctx context.Context, task *protocol.Task) error { - data, err := json.Marshal(task) - if err != nil { - return fmt.Errorf("failed to serialize task: %w", err) - } - - dbTask := dbpkg.Task{ - ID: task.ID, - Data: string(data), - SessionID: task.ContextID, - } - - return save(c.db.WithContext(ctx), &dbTask) -} - -// GetTask retrieves a MemoryCancellableTask from the database -func (c *clientImpl) GetTask(ctx context.Context, taskID string) (*protocol.Task, error) { - dbTask, err := get[dbpkg.Task](c.db.WithContext(ctx), Clause{Key: "id", Value: taskID}) - if err != nil { - return nil, fmt.Errorf("failed to get task: %w", err) - } - - var task protocol.Task - if err := json.Unmarshal([]byte(dbTask.Data), &task); err != nil { - return nil, fmt.Errorf("failed to deserialize task: %w", err) - } - - return &task, nil -} - -// TaskExists checks if a task exists in the database -func (c *clientImpl) TaskExists(ctx context.Context, taskID string) bool { - var count int64 - c.db.WithContext(ctx).Model(&dbpkg.Task{}).Where("id = ?", taskID).Count(&count) - return count > 0 -} - -// StorePushNotification stores a push notification configuration in the database -func (c *clientImpl) StorePushNotification(ctx context.Context, config *protocol.TaskPushNotificationConfig) error { - data, err := json.Marshal(config) - if err != nil { - return fmt.Errorf("failed to serialize push notification config: %w", err) - } - - dbPushNotification := dbpkg.PushNotification{ - ID: config.PushNotificationConfig.ID, - TaskID: config.TaskID, - Data: string(data), - } - - return save(c.db.WithContext(ctx), &dbPushNotification) -} - -// GetPushNotification retrieves a push notification configuration from the database -func (c *clientImpl) GetPushNotification(ctx 
context.Context, taskID string, configID string) (*protocol.TaskPushNotificationConfig, error) { - dbPushNotification, err := get[dbpkg.PushNotification](c.db.WithContext(ctx), - Clause{Key: "task_id", Value: taskID}, - Clause{Key: "id", Value: configID}) - if err != nil { - return nil, fmt.Errorf("failed to get push notification config: %w", err) - } - - var config protocol.TaskPushNotificationConfig - if err := json.Unmarshal([]byte(dbPushNotification.Data), &config); err != nil { - return nil, fmt.Errorf("failed to deserialize push notification config: %w", err) - } - - return &config, nil -} - -func (c *clientImpl) ListPushNotifications(ctx context.Context, taskID string) ([]*protocol.TaskPushNotificationConfig, error) { - pushNotifications, err := list[dbpkg.PushNotification](c.db.WithContext(ctx), Clause{Key: "task_id", Value: taskID}) - if err != nil { - return nil, err - } - - protocolPushNotifications := make([]*protocol.TaskPushNotificationConfig, 0, len(pushNotifications)) - for _, pushNotification := range pushNotifications { - var protocolPushNotification protocol.TaskPushNotificationConfig - if err := json.Unmarshal([]byte(pushNotification.Data), &protocolPushNotification); err != nil { - return nil, fmt.Errorf("failed to deserialize push notification config: %w", err) - } - protocolPushNotifications = append(protocolPushNotifications, &protocolPushNotification) - } - - return protocolPushNotifications, nil -} - -// DeletePushNotification deletes a push notification configuration from the database -func (c *clientImpl) DeletePushNotification(ctx context.Context, taskID string) error { - return delete[dbpkg.PushNotification](c.db.WithContext(ctx), Clause{Key: "task_id", Value: taskID}) -} - -// StoreCheckpoint stores a LangGraph checkpoint and its writes atomically -func (c *clientImpl) StoreCheckpoint(ctx context.Context, checkpoint *dbpkg.LangGraphCheckpoint) error { - err := save(c.db.WithContext(ctx), checkpoint) - if err != nil { - return 
fmt.Errorf("failed to store checkpoint: %w", err) - } - - return nil -} - -func (c *clientImpl) StoreCheckpointWrites(ctx context.Context, writes []*dbpkg.LangGraphCheckpointWrite) error { - return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - for _, write := range writes { - if err := save(tx, write); err != nil { - return fmt.Errorf("failed to store checkpoint write: %w", err) - } - } - return nil - }) -} - -// ListCheckpoints lists checkpoints for a thread, optionally filtered by beforeCheckpointID -func (c *clientImpl) ListCheckpoints(ctx context.Context, userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*dbpkg.LangGraphCheckpointTuple, error) { - var checkpointTuples []*dbpkg.LangGraphCheckpointTuple - db := c.db.WithContext(ctx) - if err := db.Transaction(func(tx *gorm.DB) error { - query := db.Where( - "user_id = ? AND thread_id = ? AND checkpoint_ns = ?", - userID, threadID, checkpointNS, - ) - - if checkpointID != nil { - query = query.Where("checkpoint_id = ?", *checkpointID) - } else { - query = query.Order("checkpoint_id DESC") - } - - // Apply limit - if limit > 0 { - query = query.Limit(limit) - } - - var checkpoints []dbpkg.LangGraphCheckpoint - err := query.Find(&checkpoints).Error - if err != nil { - return fmt.Errorf("failed to list checkpoints: %w", err) - } - - for _, checkpoint := range checkpoints { - var writes []*dbpkg.LangGraphCheckpointWrite - if err := tx.Where( - "user_id = ? AND thread_id = ? AND checkpoint_ns = ? 
AND checkpoint_id = ?", - userID, threadID, checkpointNS, checkpoint.CheckpointID, - ).Order("task_id, write_idx").Find(&writes).Error; err != nil { - return fmt.Errorf("failed to get checkpoint writes: %w", err) - } - checkpointTuples = append(checkpointTuples, &dbpkg.LangGraphCheckpointTuple{ - Checkpoint: &checkpoint, - Writes: writes, - }) - } - return nil - }); err != nil { - return nil, fmt.Errorf("failed to list checkpoints: %w", err) - } - return checkpointTuples, nil -} - -// DeleteCheckpoint deletes a checkpoint and its writes atomically -func (c *clientImpl) DeleteCheckpoint(ctx context.Context, userID, threadID string) error { - clauses := []Clause{ - {Key: "user_id", Value: userID}, - {Key: "thread_id", Value: threadID}, - } - return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - if err := delete[dbpkg.LangGraphCheckpoint](tx, clauses...); err != nil { - return fmt.Errorf("failed to delete checkpoint: %w", err) - } - if err := delete[dbpkg.LangGraphCheckpointWrite](tx, clauses...); err != nil { - return fmt.Errorf("failed to delete checkpoint writes: %w", err) - } - return nil - }) -} - -// CrewAI methods - -// StoreCrewAIMemory stores CrewAI agent memory -func (c *clientImpl) StoreCrewAIMemory(ctx context.Context, memory *dbpkg.CrewAIAgentMemory) error { - err := save(c.db.WithContext(ctx), memory) - if err != nil { - return fmt.Errorf("failed to store CrewAI agent memory: %w", err) - } - return nil -} - -// SearchCrewAIMemoryByTask searches CrewAI agent memory by task description across all agents for a session -func (c *clientImpl) SearchCrewAIMemoryByTask(ctx context.Context, userID, threadID, taskDescription string, limit int) ([]*dbpkg.CrewAIAgentMemory, error) { - var memories []*dbpkg.CrewAIAgentMemory - - query := c.db.WithContext(ctx).Where( - "user_id = ? AND thread_id = ? AND (memory_data LIKE ? 
OR memory_data->>'task_description' LIKE ?)", - userID, threadID, "%"+taskDescription+"%", "%"+taskDescription+"%", - ).Order("created_at DESC, memory_data->>'score' ASC") - - // Apply limit - if limit > 0 { - query = query.Limit(limit) - } - - err := query.Find(&memories).Error - if err != nil { - return nil, fmt.Errorf("failed to search CrewAI agent memory by task: %w", err) - } - - return memories, nil -} - -// ResetCrewAIMemory deletes all CrewAI agent memory for a session -func (c *clientImpl) ResetCrewAIMemory(ctx context.Context, userID, threadID string) error { - result := c.db.WithContext(ctx).Where( - "user_id = ? AND thread_id = ?", - userID, threadID, - ).Delete(&dbpkg.CrewAIAgentMemory{}) - - if result.Error != nil { - return fmt.Errorf("failed to reset CrewAI agent memory: %w", result.Error) - } - - return nil -} - -// StoreCrewAIFlowState stores CrewAI flow state -func (c *clientImpl) StoreCrewAIFlowState(ctx context.Context, state *dbpkg.CrewAIFlowState) error { - err := save(c.db.WithContext(ctx), state) - if err != nil { - return fmt.Errorf("failed to store CrewAI flow state: %w", err) - } - return nil -} - -// GetCrewAIFlowState retrieves the most recent CrewAI flow state -func (c *clientImpl) GetCrewAIFlowState(ctx context.Context, userID, threadID string) (*dbpkg.CrewAIFlowState, error) { - var state dbpkg.CrewAIFlowState - - // Get the most recent state by ordering by created_at DESC - // Thread_id is equivalent to flow_uuid used by CrewAI because in each session there is only one flow - err := c.db.WithContext(ctx).Where( - "user_id = ? 
AND thread_id = ?", - userID, threadID, - ).Order("created_at DESC").First(&state).Error - - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil // Return nil for not found, as expected by the Python client - } - return nil, fmt.Errorf("failed to get CrewAI flow state: %w", err) - } - - return &state, nil -} - -// AgentMemory methods - -func (c *clientImpl) StoreAgentMemory(ctx context.Context, memory *dbpkg.Memory) error { - return save(c.db.WithContext(ctx), memory) -} - -func (c *clientImpl) StoreAgentMemories(ctx context.Context, memories []*dbpkg.Memory) error { - return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - for _, memory := range memories { - if err := save(tx, memory); err != nil { - return err - } - } - return nil - }) -} - -func (c *clientImpl) SearchAgentMemory(ctx context.Context, agentName, userID string, embedding pgvector.Vector, limit int) ([]dbpkg.AgentMemorySearchResult, error) { - var results []dbpkg.AgentMemorySearchResult - - // pgvector <=> operator for cosine distance. - // COALESCE guards against NaN when either vector has zero magnitude. - query := ` - SELECT *, COALESCE(1 - (embedding <=> ?), 0) as score - FROM memory - WHERE agent_name = ? AND user_id = ? - ORDER BY embedding <=> ? ASC - LIMIT ? - ` - if err := c.db.WithContext(ctx).Raw(query, embedding, agentName, userID, embedding, limit).Scan(&results).Error; err != nil { - return nil, fmt.Errorf("failed to search agent memory: %w", err) - } - - // Increment access count for found memories synchronously. 
- if len(results) > 0 { - ids := make([]string, len(results)) - for i, m := range results { - ids[i] = m.ID - } - if err := c.db.WithContext(ctx).Model(&dbpkg.Memory{}).Where("id IN ?", ids).UpdateColumn("access_count", gorm.Expr("access_count + ?", 1)).Error; err != nil { - return nil, fmt.Errorf("failed to increment access count: %w", err) - } - } - - return results, nil -} - -// PruneExpiredMemories deletes expired memories if they haven't been accessed enough, -// otherwise extends their TTL. -func (c *clientImpl) PruneExpiredMemories(ctx context.Context) error { - return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - now := time.Now() - - // 1. Extend TTL for popular memories (AccessCount >= 10) - if err := tx.Model(&dbpkg.Memory{}). - Where("expires_at < ? AND access_count >= ?", now, 10). - Updates(map[string]any{ - "expires_at": now.Add(15 * 24 * time.Hour), - "access_count": 0, // Reset count to ensure it's still relevant next time - }).Error; err != nil { - return fmt.Errorf("failed to extend TTL for popular memories: %w", err) - } - - // 2. Delete unpopular expired memories (AccessCount < 10) - result := tx.Where("expires_at < ? AND access_count < ?", now, 10).Delete(&dbpkg.Memory{}) - if result.Error != nil { - return fmt.Errorf("failed to delete expired memories: %w", result.Error) - } - if result.RowsAffected > 0 { - log := ctrllog.Log.WithName("database") - log.Info("Pruned expired memories", "count", result.RowsAffected) - } - - return nil - }) -} - -func (c *clientImpl) ListAgentMemories(ctx context.Context, agentName, userID string) ([]dbpkg.Memory, error) { - normalizedName := strings.ReplaceAll(agentName, "-", "_") - - var memories []dbpkg.Memory - query := c.db.WithContext(ctx).Where("(agent_name = ? OR agent_name = ?) AND user_id = ?", agentName, normalizedName, userID). 
- Order("access_count DESC") - - if err := query.Find(&memories).Error; err != nil { - return nil, fmt.Errorf("failed to list agent memories: %w", err) - } - return memories, nil -} - -func (c *clientImpl) DeleteAgentMemory(ctx context.Context, agentName, userID string) error { - if err := c.deleteAgentMemoryByQuery(ctx, agentName, userID); err != nil { - return err - } - normalizedName := strings.ReplaceAll(agentName, "-", "_") - if normalizedName != agentName { - return c.deleteAgentMemoryByQuery(ctx, normalizedName, userID) - } - return nil -} - -func (c *clientImpl) deleteAgentMemoryByQuery(ctx context.Context, agentName, userID string) error { - if err := c.db.WithContext(ctx).Where("agent_name = ? AND user_id = ?", agentName, userID).Delete(&dbpkg.Memory{}).Error; err != nil { - return fmt.Errorf("failed to delete agent memory: %w", err) - } - return nil -} diff --git a/go/core/internal/database/client_postgres.go b/go/core/internal/database/client_postgres.go new file mode 100644 index 000000000..cd8e424e2 --- /dev/null +++ b/go/core/internal/database/client_postgres.go @@ -0,0 +1,908 @@ +package database + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + dbpkg "github.com/kagent-dev/kagent/go/api/database" + "github.com/kagent-dev/kagent/go/api/v1alpha2" + dbgen "github.com/kagent-dev/kagent/go/core/internal/database/gen" + "github.com/pgvector/pgvector-go" + "trpc.group/trpc-go/trpc-a2a-go/protocol" +) + +type postgresClient struct { + q *dbgen.Queries + db *sql.DB +} + +func NewClient(db *sql.DB) dbpkg.Client { + return &postgresClient{ + q: dbgen.New(db), + db: db, + } +} + +// ── Agents ──────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreAgent(ctx context.Context, agent *dbpkg.Agent) error { + return c.q.UpsertAgent(ctx, dbgen.UpsertAgentParams{ + ID: agent.ID, + Type: agent.Type, + Config: agent.Config, + }) +} + +func (c *postgresClient) GetAgent(ctx 
context.Context, id string) (*dbpkg.Agent, error) { + row, err := c.q.GetAgent(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get agent %s: %w", id, err) + } + return toAgent(row), nil +} + +func (c *postgresClient) ListAgents(ctx context.Context) ([]dbpkg.Agent, error) { + rows, err := c.q.ListAgents(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list agents: %w", err) + } + agents := make([]dbpkg.Agent, len(rows)) + for i, r := range rows { + agents[i] = *toAgent(r) + } + return agents, nil +} + +func (c *postgresClient) DeleteAgent(ctx context.Context, agentID string) error { + return c.q.SoftDeleteAgent(ctx, agentID) +} + +// ── Sessions ────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreSession(ctx context.Context, session *dbpkg.Session) error { + params := dbgen.UpsertSessionParams{ + ID: session.ID, + UserID: session.UserID, + Name: ptrToNullString(session.Name), + AgentID: ptrToNullString(session.AgentID), + } + if session.Source != nil { + params.Source = sql.NullString{String: string(*session.Source), Valid: true} + } + return c.q.UpsertSession(ctx, params) +} + +func (c *postgresClient) GetSession(ctx context.Context, sessionID, userID string) (*dbpkg.Session, error) { + row, err := c.q.GetSession(ctx, dbgen.GetSessionParams{ID: sessionID, UserID: userID}) + if err != nil { + return nil, fmt.Errorf("failed to get session %s: %w", sessionID, err) + } + return toSession(row), nil +} + +func (c *postgresClient) ListSessions(ctx context.Context, userID string) ([]dbpkg.Session, error) { + rows, err := c.q.ListSessions(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to list sessions: %w", err) + } + sessions := make([]dbpkg.Session, len(rows)) + for i, r := range rows { + sessions[i] = *toSession(r) + } + return sessions, nil +} + +func (c *postgresClient) ListSessionsForAgent(ctx context.Context, agentID, userID string) ([]dbpkg.Session, error) { + rows, err := 
c.q.ListSessionsForAgent(ctx, dbgen.ListSessionsForAgentParams{ + AgentID: ptrToNullString(&agentID), + UserID: userID, + }) + if err != nil { + return nil, fmt.Errorf("failed to list sessions for agent: %w", err) + } + sessions := make([]dbpkg.Session, len(rows)) + for i, r := range rows { + sessions[i] = *toSession(r) + } + return sessions, nil +} + +func (c *postgresClient) DeleteSession(ctx context.Context, sessionID, userID string) error { + return c.q.SoftDeleteSession(ctx, dbgen.SoftDeleteSessionParams{ID: sessionID, UserID: userID}) +} + +// ── Events ──────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreEvents(ctx context.Context, events ...*dbpkg.Event) error { + for _, e := range events { + if err := c.q.InsertEvent(ctx, dbgen.InsertEventParams{ + ID: e.ID, + UserID: e.UserID, + SessionID: sql.NullString{String: e.SessionID, Valid: e.SessionID != ""}, + Data: e.Data, + }); err != nil { + return fmt.Errorf("failed to store event %s: %w", e.ID, err) + } + } + return nil +} + +func (c *postgresClient) ListEventsForSession(ctx context.Context, sessionID, userID string, opts dbpkg.QueryOptions) ([]*dbpkg.Event, error) { + var rows []dbgen.Event + var err error + nullSessionID := sql.NullString{String: sessionID, Valid: sessionID != ""} + + switch { + case opts.OrderAsc && opts.Limit > 0: + rows, err = c.q.ListEventsForSessionAscLimit(ctx, dbgen.ListEventsForSessionAscLimitParams{ + SessionID: nullSessionID, UserID: userID, Column3: opts.After, Limit: int32(opts.Limit), + }) + case opts.OrderAsc: + rows, err = c.q.ListEventsForSessionAsc(ctx, dbgen.ListEventsForSessionAscParams{ + SessionID: nullSessionID, UserID: userID, Column3: opts.After, + }) + case opts.Limit > 0: + rows, err = c.q.ListEventsForSessionDescLimit(ctx, dbgen.ListEventsForSessionDescLimitParams{ + SessionID: nullSessionID, UserID: userID, Column3: opts.After, Limit: int32(opts.Limit), + }) + default: + rows, err = 
c.q.ListEventsForSessionDesc(ctx, dbgen.ListEventsForSessionDescParams{ + SessionID: nullSessionID, UserID: userID, Column3: opts.After, + }) + } + if err != nil { + return nil, fmt.Errorf("failed to list events for session: %w", err) + } + + events := make([]*dbpkg.Event, len(rows)) + for i, r := range rows { + events[i] = toEvent(r) + } + return events, nil +} + +// ── Tasks ───────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreTask(ctx context.Context, task *protocol.Task) error { + data, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("failed to serialize task: %w", err) + } + return c.q.UpsertTask(ctx, dbgen.UpsertTaskParams{ + ID: task.ID, + Data: string(data), + SessionID: sql.NullString{String: task.ContextID, Valid: task.ContextID != ""}, + }) +} + +func (c *postgresClient) GetTask(ctx context.Context, taskID string) (*protocol.Task, error) { + row, err := c.q.GetTask(ctx, taskID) + if err != nil { + return nil, fmt.Errorf("failed to get task %s: %w", taskID, err) + } + var task protocol.Task + if err := json.Unmarshal([]byte(row.Data), &task); err != nil { + return nil, fmt.Errorf("failed to deserialize task: %w", err) + } + return &task, nil +} + +func (c *postgresClient) ListTasksForSession(ctx context.Context, sessionID string) ([]*protocol.Task, error) { + rows, err := c.q.ListTasksForSession(ctx, sql.NullString{String: sessionID, Valid: true}) + if err != nil { + return nil, fmt.Errorf("failed to list tasks for session: %w", err) + } + tasks := make([]dbpkg.Task, len(rows)) + for i, r := range rows { + tasks[i] = *toTask(r) + } + return dbpkg.ParseTasks(tasks) +} + +func (c *postgresClient) DeleteTask(ctx context.Context, taskID string) error { + return c.q.SoftDeleteTask(ctx, taskID) +} + +// ── Push Notifications ──────────────────────────────────────────────────────── + +func (c *postgresClient) StorePushNotification(ctx context.Context, config *protocol.TaskPushNotificationConfig) 
error { + data, err := json.Marshal(config) + if err != nil { + return fmt.Errorf("failed to serialize push notification: %w", err) + } + return c.q.UpsertPushNotification(ctx, dbgen.UpsertPushNotificationParams{ + ID: config.PushNotificationConfig.ID, + TaskID: config.TaskID, + Data: string(data), + }) +} + +func (c *postgresClient) GetPushNotification(ctx context.Context, taskID, configID string) (*protocol.TaskPushNotificationConfig, error) { + row, err := c.q.GetPushNotification(ctx, dbgen.GetPushNotificationParams{TaskID: taskID, ID: configID}) + if err != nil { + return nil, fmt.Errorf("failed to get push notification: %w", err) + } + var cfg protocol.TaskPushNotificationConfig + if err := json.Unmarshal([]byte(row.Data), &cfg); err != nil { + return nil, fmt.Errorf("failed to deserialize push notification: %w", err) + } + return &cfg, nil +} + +func (c *postgresClient) ListPushNotifications(ctx context.Context, taskID string) ([]*protocol.TaskPushNotificationConfig, error) { + rows, err := c.q.ListPushNotifications(ctx, taskID) + if err != nil { + return nil, fmt.Errorf("failed to list push notifications: %w", err) + } + result := make([]*protocol.TaskPushNotificationConfig, 0, len(rows)) + for _, row := range rows { + var cfg protocol.TaskPushNotificationConfig + if err := json.Unmarshal([]byte(row.Data), &cfg); err != nil { + return nil, fmt.Errorf("failed to deserialize push notification: %w", err) + } + result = append(result, &cfg) + } + return result, nil +} + +func (c *postgresClient) DeletePushNotification(ctx context.Context, taskID string) error { + return c.q.SoftDeletePushNotification(ctx, taskID) +} + +// ── Feedback ────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreFeedback(ctx context.Context, feedback *dbpkg.Feedback) error { + _, err := c.q.InsertFeedback(ctx, dbgen.InsertFeedbackParams{ + UserID: feedback.UserID, + MessageID: ptrToNullInt64(feedback.MessageID), + IsPositive: 
sql.NullBool{Bool: feedback.IsPositive, Valid: true}, + FeedbackText: feedback.FeedbackText, + IssueType: feedback.IssueType, + }) + return err +} + +func (c *postgresClient) ListFeedback(ctx context.Context, userID string) ([]dbpkg.Feedback, error) { + rows, err := c.q.ListFeedback(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to list feedback: %w", err) + } + result := make([]dbpkg.Feedback, len(rows)) + for i, r := range rows { + result[i] = *toFeedback(r) + } + return result, nil +} + +// ── Tools ───────────────────────────────────────────────────────────────────── + +func (c *postgresClient) GetTool(ctx context.Context, name string) (*dbpkg.Tool, error) { + row, err := c.q.GetTool(ctx, name) + if err != nil { + return nil, fmt.Errorf("failed to get tool %s: %w", name, err) + } + return toTool(row), nil +} + +func (c *postgresClient) ListTools(ctx context.Context) ([]dbpkg.Tool, error) { + rows, err := c.q.ListTools(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %w", err) + } + tools := make([]dbpkg.Tool, len(rows)) + for i, r := range rows { + tools[i] = *toTool(r) + } + return tools, nil +} + +func (c *postgresClient) ListToolsForServer(ctx context.Context, serverName, groupKind string) ([]dbpkg.Tool, error) { + rows, err := c.q.ListToolsForServer(ctx, dbgen.ListToolsForServerParams{ServerName: serverName, GroupKind: groupKind}) + if err != nil { + return nil, fmt.Errorf("failed to list tools for server: %w", err) + } + tools := make([]dbpkg.Tool, len(rows)) + for i, r := range rows { + tools[i] = *toTool(r) + } + return tools, nil +} + +func (c *postgresClient) DeleteToolsForServer(ctx context.Context, serverName, groupKind string) error { + return c.q.SoftDeleteToolsForServer(ctx, dbgen.SoftDeleteToolsForServerParams{ServerName: serverName, GroupKind: groupKind}) +} + +func (c *postgresClient) RefreshToolsForServer(ctx context.Context, serverName, groupKind string, tools ...*v1alpha2.MCPTool) error { + tx, err := 
c.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + + if err := q.SoftDeleteToolsForServer(ctx, dbgen.SoftDeleteToolsForServerParams{ + ServerName: serverName, GroupKind: groupKind, + }); err != nil { + return fmt.Errorf("failed to delete existing tools: %w", err) + } + + for _, tool := range tools { + if err := q.UpsertTool(ctx, dbgen.UpsertToolParams{ + ID: tool.Name, + ServerName: serverName, + GroupKind: groupKind, + Description: sql.NullString{String: tool.Description, Valid: true}, + }); err != nil { + return fmt.Errorf("failed to upsert tool %s: %w", tool.Name, err) + } + } + + return tx.Commit() +} + +func (c *postgresClient) GetToolServer(ctx context.Context, name string) (*dbpkg.ToolServer, error) { + row, err := c.q.GetToolServer(ctx, name) + if err != nil { + return nil, fmt.Errorf("failed to get tool server %s: %w", name, err) + } + return toToolServer(row), nil +} + +func (c *postgresClient) ListToolServers(ctx context.Context) ([]dbpkg.ToolServer, error) { + rows, err := c.q.ListToolServers(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list tool servers: %w", err) + } + servers := make([]dbpkg.ToolServer, len(rows)) + for i, r := range rows { + servers[i] = *toToolServer(r) + } + return servers, nil +} + +func (c *postgresClient) StoreToolServer(ctx context.Context, ts *dbpkg.ToolServer) (*dbpkg.ToolServer, error) { + row, err := c.q.UpsertToolServer(ctx, dbgen.UpsertToolServerParams{ + Name: ts.Name, + GroupKind: ts.GroupKind, + Description: sql.NullString{String: ts.Description, Valid: true}, + LastConnected: ptrToNullTime(ts.LastConnected), + }) + if err != nil { + return nil, fmt.Errorf("failed to store tool server: %w", err) + } + return toToolServer(row), nil +} + +func (c *postgresClient) DeleteToolServer(ctx context.Context, serverName, groupKind string) error { + return c.q.SoftDeleteToolServer(ctx, 
dbgen.SoftDeleteToolServerParams{Name: serverName, GroupKind: groupKind}) +} + +// ── LangGraph Checkpoints ───────────────────────────────────────────────────── + +func (c *postgresClient) StoreCheckpoint(ctx context.Context, cp *dbpkg.LangGraphCheckpoint) error { + return c.q.UpsertCheckpoint(ctx, dbgen.UpsertCheckpointParams{ + UserID: cp.UserID, + ThreadID: cp.ThreadID, + CheckpointNs: cp.CheckpointNS, + CheckpointID: cp.CheckpointID, + ParentCheckpointID: ptrToNullString(cp.ParentCheckpointID), + Metadata: cp.Metadata, + Checkpoint: cp.Checkpoint, + CheckpointType: cp.CheckpointType, + Version: sql.NullInt32{Int32: cp.Version, Valid: true}, + }) +} + +func (c *postgresClient) StoreCheckpointWrites(ctx context.Context, writes []*dbpkg.LangGraphCheckpointWrite) error { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + for _, w := range writes { + if err := q.UpsertCheckpointWrite(ctx, dbgen.UpsertCheckpointWriteParams{ + UserID: w.UserID, + ThreadID: w.ThreadID, + CheckpointNs: w.CheckpointNS, + CheckpointID: w.CheckpointID, + WriteIdx: w.WriteIdx, + Value: w.Value, + ValueType: w.ValueType, + Channel: w.Channel, + TaskID: w.TaskID, + }); err != nil { + return fmt.Errorf("failed to store checkpoint write: %w", err) + } + } + return tx.Commit() +} + +func (c *postgresClient) ListCheckpoints(ctx context.Context, userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*dbpkg.LangGraphCheckpointTuple, error) { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + + var checkpoints []dbgen.LgCheckpoint + if checkpointID != nil { + cp, err := q.GetCheckpoint(ctx, dbgen.GetCheckpointParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: *checkpointID, + 
}) + if err != nil { + return nil, fmt.Errorf("failed to get checkpoint: %w", err) + } + checkpoints = []dbgen.LgCheckpoint{cp} + } else if limit > 0 { + checkpoints, err = q.ListCheckpointsLimit(ctx, dbgen.ListCheckpointsLimitParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, Limit: int32(limit), + }) + if err != nil { + return nil, fmt.Errorf("failed to list checkpoints: %w", err) + } + } else { + checkpoints, err = q.ListCheckpoints(ctx, dbgen.ListCheckpointsParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, + }) + if err != nil { + return nil, fmt.Errorf("failed to list checkpoints: %w", err) + } + } + + tuples := make([]*dbpkg.LangGraphCheckpointTuple, 0, len(checkpoints)) + for _, cp := range checkpoints { + writes, err := q.ListCheckpointWrites(ctx, dbgen.ListCheckpointWritesParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: cp.CheckpointID, + }) + if err != nil { + return nil, fmt.Errorf("failed to get checkpoint writes: %w", err) + } + dbWrites := make([]*dbpkg.LangGraphCheckpointWrite, len(writes)) + for i, w := range writes { + dbWrites[i] = toCheckpointWrite(w) + } + tuples = append(tuples, &dbpkg.LangGraphCheckpointTuple{ + Checkpoint: toCheckpoint(cp), + Writes: dbWrites, + }) + } + + return tuples, tx.Commit() +} + +func (c *postgresClient) DeleteCheckpoint(ctx context.Context, userID, threadID string) error { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + if err := q.SoftDeleteCheckpoints(ctx, dbgen.SoftDeleteCheckpointsParams{UserID: userID, ThreadID: threadID}); err != nil { + return fmt.Errorf("failed to delete checkpoints: %w", err) + } + if err := q.SoftDeleteCheckpointWrites(ctx, dbgen.SoftDeleteCheckpointWritesParams{UserID: userID, ThreadID: threadID}); err != nil { + return fmt.Errorf("failed to delete checkpoint writes: 
%w", err) + } + return tx.Commit() +} + +// ── CrewAI ──────────────────────────────────────────────────────────────────── + +func (c *postgresClient) StoreCrewAIMemory(ctx context.Context, memory *dbpkg.CrewAIAgentMemory) error { + return c.q.UpsertCrewAIMemory(ctx, dbgen.UpsertCrewAIMemoryParams{ + UserID: memory.UserID, + ThreadID: memory.ThreadID, + MemoryData: memory.MemoryData, + }) +} + +func (c *postgresClient) SearchCrewAIMemoryByTask(ctx context.Context, userID, threadID, taskDescription string, limit int) ([]*dbpkg.CrewAIAgentMemory, error) { + pattern := "%" + taskDescription + "%" + var rows []dbgen.CrewaiAgentMemory + var err error + + if limit > 0 { + rows, err = c.q.SearchCrewAIMemoryByTaskLimit(ctx, dbgen.SearchCrewAIMemoryByTaskLimitParams{ + UserID: userID, ThreadID: threadID, MemoryData: pattern, Limit: int32(limit), + }) + } else { + rows, err = c.q.SearchCrewAIMemoryByTask(ctx, dbgen.SearchCrewAIMemoryByTaskParams{ + UserID: userID, ThreadID: threadID, MemoryData: pattern, + }) + } + if err != nil { + return nil, fmt.Errorf("failed to search CrewAI memory: %w", err) + } + + result := make([]*dbpkg.CrewAIAgentMemory, len(rows)) + for i, r := range rows { + result[i] = toCrewAIMemory(r) + } + return result, nil +} + +func (c *postgresClient) ResetCrewAIMemory(ctx context.Context, userID, threadID string) error { + return c.q.HardDeleteCrewAIMemory(ctx, dbgen.HardDeleteCrewAIMemoryParams{UserID: userID, ThreadID: threadID}) +} + +func (c *postgresClient) StoreCrewAIFlowState(ctx context.Context, state *dbpkg.CrewAIFlowState) error { + return c.q.UpsertCrewAIFlowState(ctx, dbgen.UpsertCrewAIFlowStateParams{ + UserID: state.UserID, + ThreadID: state.ThreadID, + MethodName: state.MethodName, + StateData: state.StateData, + }) +} + +func (c *postgresClient) GetCrewAIFlowState(ctx context.Context, userID, threadID string) (*dbpkg.CrewAIFlowState, error) { + row, err := c.q.GetLatestCrewAIFlowState(ctx, dbgen.GetLatestCrewAIFlowStateParams{UserID: 
userID, ThreadID: threadID}) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to get CrewAI flow state: %w", err) + } + return toCrewAIFlowState(row), nil +} + +// ── Agent Memory (vector search) ────────────────────────────────────────────── + +func (c *postgresClient) StoreAgentMemory(ctx context.Context, memory *dbpkg.Memory) error { + id, err := c.q.InsertMemory(ctx, dbgen.InsertMemoryParams{ + AgentName: sql.NullString{String: memory.AgentName, Valid: true}, + UserID: sql.NullString{String: memory.UserID, Valid: true}, + Content: sql.NullString{String: memory.Content, Valid: true}, + Embedding: memory.Embedding, + Metadata: sql.NullString{String: memory.Metadata, Valid: true}, + ExpiresAt: ptrToNullTime(memory.ExpiresAt), + AccessCount: sql.NullInt32{Int32: memory.AccessCount, Valid: true}, + }) + if err != nil { + return err + } + memory.ID = id + return nil +} + +func (c *postgresClient) StoreAgentMemories(ctx context.Context, memories []*dbpkg.Memory) error { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + for _, m := range memories { + id, err := q.InsertMemory(ctx, dbgen.InsertMemoryParams{ + AgentName: sql.NullString{String: m.AgentName, Valid: true}, + UserID: sql.NullString{String: m.UserID, Valid: true}, + Content: sql.NullString{String: m.Content, Valid: true}, + Embedding: m.Embedding, + Metadata: sql.NullString{String: m.Metadata, Valid: true}, + ExpiresAt: ptrToNullTime(m.ExpiresAt), + AccessCount: sql.NullInt32{Int32: m.AccessCount, Valid: true}, + }) + if err != nil { + return fmt.Errorf("failed to store memory: %w", err) + } + m.ID = id + } + return tx.Commit() +} + +func (c *postgresClient) SearchAgentMemory(ctx context.Context, agentName, userID string, embedding pgvector.Vector, limit int) ([]dbpkg.AgentMemorySearchResult, error) { + rows, err := 
c.q.SearchAgentMemory(ctx, dbgen.SearchAgentMemoryParams{ + Embedding: embedding, + AgentName: sql.NullString{String: agentName, Valid: true}, + UserID: sql.NullString{String: userID, Valid: true}, + Limit: int32(limit), + }) + if err != nil { + return nil, fmt.Errorf("failed to search agent memory: %w", err) + } + + results := make([]dbpkg.AgentMemorySearchResult, len(rows)) + for i, r := range rows { + score, _ := r.Score.(float64) + results[i] = dbpkg.AgentMemorySearchResult{ + Memory: dbpkg.Memory{ + ID: r.ID, + AgentName: r.AgentName.String, + UserID: r.UserID.String, + Content: r.Content.String, + Embedding: r.Embedding, + Metadata: r.Metadata.String, + CreatedAt: r.CreatedAt, + ExpiresAt: nullTimeToPtr(r.ExpiresAt), + AccessCount: nullInt32ToVal(r.AccessCount), + }, + Score: score, + } + } + + if len(results) > 0 { + ids := make([]string, len(results)) + for i, r := range results { + ids[i] = r.ID + } + if err := c.q.IncrementMemoryAccessCount(ctx, ids); err != nil { + return nil, fmt.Errorf("failed to increment access count: %w", err) + } + } + + return results, nil +} + +func (c *postgresClient) ListAgentMemories(ctx context.Context, agentName, userID string) ([]dbpkg.Memory, error) { + normalized := strings.ReplaceAll(agentName, "-", "_") + rows, err := c.q.ListAgentMemories(ctx, dbgen.ListAgentMemoriesParams{ + AgentName: sql.NullString{String: agentName, Valid: true}, + AgentName_2: sql.NullString{String: normalized, Valid: true}, + UserID: sql.NullString{String: userID, Valid: true}, + }) + if err != nil { + return nil, fmt.Errorf("failed to list agent memories: %w", err) + } + memories := make([]dbpkg.Memory, len(rows)) + for i, r := range rows { + memories[i] = *toMemory(r) + } + return memories, nil +} + +func (c *postgresClient) DeleteAgentMemory(ctx context.Context, agentName, userID string) error { + if err := c.q.DeleteAgentMemory(ctx, dbgen.DeleteAgentMemoryParams{ + AgentName: sql.NullString{String: agentName, Valid: true}, + UserID: 
sql.NullString{String: userID, Valid: true}, + }); err != nil { + return fmt.Errorf("failed to delete agent memory: %w", err) + } + normalized := strings.ReplaceAll(agentName, "-", "_") + if normalized != agentName { + if err := c.q.DeleteAgentMemory(ctx, dbgen.DeleteAgentMemoryParams{ + AgentName: sql.NullString{String: normalized, Valid: true}, + UserID: sql.NullString{String: userID, Valid: true}, + }); err != nil { + return fmt.Errorf("failed to delete normalized agent memory: %w", err) + } + } + return nil +} + +func (c *postgresClient) PruneExpiredMemories(ctx context.Context) error { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + q := c.q.WithTx(tx) + if err := q.ExtendMemoryTTL(ctx); err != nil { + return fmt.Errorf("failed to extend TTL for popular memories: %w", err) + } + if err := q.DeleteExpiredMemories(ctx); err != nil { + return fmt.Errorf("failed to delete expired memories: %w", err) + } + return tx.Commit() +} + +// ── Conversion helpers ──────────────────────────────────────────────────────── + +func toAgent(r dbgen.Agent) *dbpkg.Agent { + return &dbpkg.Agent{ + ID: r.ID, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Type: r.Type, + Config: r.Config, + } +} + +func toSession(r dbgen.Session) *dbpkg.Session { + s := &dbpkg.Session{ + ID: r.ID, + UserID: r.UserID, + Name: nullStringToPtr(r.Name), + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + AgentID: nullStringToPtr(r.AgentID), + } + if r.Source.Valid { + src := dbpkg.SessionSource(r.Source.String) + s.Source = &src + } + return s +} + +func toEvent(r dbgen.Event) *dbpkg.Event { + return &dbpkg.Event{ + ID: r.ID, + UserID: r.UserID, + SessionID: r.SessionID.String, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Data: r.Data, + } +} + +func 
toTask(r dbgen.Task) *dbpkg.Task { + return &dbpkg.Task{ + ID: r.ID, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Data: r.Data, + SessionID: r.SessionID.String, + } +} + +func toFeedback(r dbgen.Feedback) *dbpkg.Feedback { + return &dbpkg.Feedback{ + ID: r.ID, + CreatedAt: nullTimeToPtr(r.CreatedAt), + UpdatedAt: nullTimeToPtr(r.UpdatedAt), + DeletedAt: nullTimeToPtr(r.DeletedAt), + UserID: r.UserID, + MessageID: nullInt64ToPtr(r.MessageID), + IsPositive: r.IsPositive.Bool, + FeedbackText: r.FeedbackText, + IssueType: r.IssueType, + } +} + +func toTool(r dbgen.Tool) *dbpkg.Tool { + return &dbpkg.Tool{ + ID: r.ID, + ServerName: r.ServerName, + GroupKind: r.GroupKind, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Description: r.Description.String, + } +} + +func toToolServer(r dbgen.Toolserver) *dbpkg.ToolServer { + return &dbpkg.ToolServer{ + Name: r.Name, + GroupKind: r.GroupKind, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Description: r.Description.String, + LastConnected: nullTimeToPtr(r.LastConnected), + } +} + +func toCheckpoint(r dbgen.LgCheckpoint) *dbpkg.LangGraphCheckpoint { + return &dbpkg.LangGraphCheckpoint{ + UserID: r.UserID, + ThreadID: r.ThreadID, + CheckpointNS: r.CheckpointNs, + CheckpointID: r.CheckpointID, + ParentCheckpointID: nullStringToPtr(r.ParentCheckpointID), + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + Metadata: r.Metadata, + Checkpoint: r.Checkpoint, + CheckpointType: r.CheckpointType, + Version: r.Version.Int32, + } +} + +func toCheckpointWrite(r dbgen.LgCheckpointWrite) *dbpkg.LangGraphCheckpointWrite { + return &dbpkg.LangGraphCheckpointWrite{ + UserID: r.UserID, + ThreadID: r.ThreadID, + CheckpointNS: r.CheckpointNs, + CheckpointID: r.CheckpointID, + WriteIdx: r.WriteIdx, + Value: r.Value, + ValueType: r.ValueType, + Channel: r.Channel, 
+ TaskID: r.TaskID, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + } +} + +func toCrewAIMemory(r dbgen.CrewaiAgentMemory) *dbpkg.CrewAIAgentMemory { + return &dbpkg.CrewAIAgentMemory{ + UserID: r.UserID, + ThreadID: r.ThreadID, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + MemoryData: r.MemoryData, + } +} + +func toCrewAIFlowState(r dbgen.CrewaiFlowState) *dbpkg.CrewAIFlowState { + return &dbpkg.CrewAIFlowState{ + UserID: r.UserID, + ThreadID: r.ThreadID, + MethodName: r.MethodName, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: nullTimeToPtr(r.DeletedAt), + StateData: r.StateData, + } +} + +func toMemory(r dbgen.Memory) *dbpkg.Memory { + return &dbpkg.Memory{ + ID: r.ID, + AgentName: r.AgentName.String, + UserID: r.UserID.String, + Content: r.Content.String, + Embedding: r.Embedding, + Metadata: r.Metadata.String, + CreatedAt: r.CreatedAt, + ExpiresAt: nullTimeToPtr(r.ExpiresAt), + AccessCount: nullInt32ToVal(r.AccessCount), + } +} + +// ── sql.Null* helpers ───────────────────────────────────────────────────────── + +func nullStringToPtr(s sql.NullString) *string { + if s.Valid { + return &s.String + } + return nil +} + +func nullTimeToPtr(t sql.NullTime) *time.Time { + if t.Valid { + return &t.Time + } + return nil +} + +func nullInt64ToPtr(n sql.NullInt64) *int64 { + if n.Valid { + return &n.Int64 + } + return nil +} + +func nullInt32ToVal(n sql.NullInt32) int32 { + return n.Int32 +} + +func ptrToNullString(s *string) sql.NullString { + if s != nil { + return sql.NullString{String: *s, Valid: true} + } + return sql.NullString{} +} + +func ptrToNullTime(t *time.Time) sql.NullTime { + if t != nil { + return sql.NullTime{Time: *t, Valid: true} + } + return sql.NullTime{} +} + +func ptrToNullInt64(n *int64) sql.NullInt64 { + if n != nil { + return sql.NullInt64{Int64: *n, Valid: true} + } + return sql.NullInt64{} +} diff --git 
a/go/core/internal/database/client_test.go b/go/core/internal/database/client_test.go index 12561ba97..e24639983 100644 --- a/go/core/internal/database/client_test.go +++ b/go/core/internal/database/client_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "fmt" "sync" "testing" @@ -9,6 +10,7 @@ import ( dbpkg "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/core/internal/dbtest" "github.com/pgvector/pgvector-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -207,16 +209,17 @@ func TestStoreToolServerIdempotence(t *testing.T) { assert.Equal(t, "Updated description", retrieved.Description) } -// setupTestDB resets the shared Postgres manager's tables for test isolation. -func setupTestDB(t *testing.T) *Manager { +// setupTestDB resets the shared Postgres database's tables for test isolation. +func setupTestDB(t *testing.T) *sql.DB { t.Helper() if testing.Short() { t.Skip("skipping database test in short mode") } - require.NoError(t, sharedManager.Reset(true), "Failed to reset test database") + require.NoError(t, dbtest.MigrateDown(sharedConnStr, true), "Failed to reset test database (down)") + require.NoError(t, dbtest.Migrate(sharedConnStr, true), "Failed to reset test database (up)") - return sharedManager + return sharedDB } func TestListEventsForSession(t *testing.T) { db := setupTestDB(t) @@ -439,23 +442,15 @@ func TestSearchAgentMemoryIsolation(t *testing.T) { client := NewClient(db) ctx := context.Background() - require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "iso-1", AgentName: "agent-a", UserID: "user-1", - Content: "agent-a user-1 memory", Embedding: makeEmbedding(0.5), - })) - require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "iso-2", AgentName: "agent-b", UserID: "user-1", - Content: "agent-b user-1 memory", Embedding: makeEmbedding(0.5), - })) - require.NoError(t, 
client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "iso-3", AgentName: "agent-a", UserID: "user-2", - Content: "agent-a user-2 memory", Embedding: makeEmbedding(0.5), - })) + mem1 := &dbpkg.Memory{AgentName: "agent-a", UserID: "user-1", Content: "agent-a user-1 memory", Embedding: makeEmbedding(0.5)} + require.NoError(t, client.StoreAgentMemory(ctx, mem1)) + require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{AgentName: "agent-b", UserID: "user-1", Content: "agent-b user-1 memory", Embedding: makeEmbedding(0.5)})) + require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{AgentName: "agent-a", UserID: "user-2", Content: "agent-a user-2 memory", Embedding: makeEmbedding(0.5)})) results, err := client.SearchAgentMemory(ctx, "agent-a", "user-1", makeEmbedding(0.5), 10) require.NoError(t, err) require.Len(t, results, 1, "Should only return memories for agent-a / user-1") - assert.Equal(t, "iso-1", results[0].ID) + assert.Equal(t, mem1.ID, results[0].ID) } // TestDeleteAgentMemory verifies that DeleteAgentMemory removes all memories for the @@ -505,38 +500,17 @@ func TestPruneExpiredMemories(t *testing.T) { past := time.Now().Add(-1 * time.Hour) // Memory that is expired and unpopular — should be deleted - require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "prune-cold", - AgentName: agentName, - UserID: userID, - Content: "cold expired memory", - Embedding: makeEmbedding(0.1), - ExpiresAt: &past, - AccessCount: 2, - })) + coldMem := &dbpkg.Memory{AgentName: agentName, UserID: userID, Content: "cold expired memory", Embedding: makeEmbedding(0.1), ExpiresAt: &past, AccessCount: 2} + require.NoError(t, client.StoreAgentMemory(ctx, coldMem)) // Memory that is expired but popular (AccessCount >= 10) — TTL should be extended - require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "prune-hot", - AgentName: agentName, - UserID: userID, - Content: "hot expired memory", - Embedding: makeEmbedding(0.9), - ExpiresAt: &past, - 
AccessCount: 15, - })) + hotMem := &dbpkg.Memory{AgentName: agentName, UserID: userID, Content: "hot expired memory", Embedding: makeEmbedding(0.9), ExpiresAt: &past, AccessCount: 15} + require.NoError(t, client.StoreAgentMemory(ctx, hotMem)) // Memory that has not expired — should be untouched future := time.Now().Add(24 * time.Hour) - require.NoError(t, client.StoreAgentMemory(ctx, &dbpkg.Memory{ - ID: "prune-live", - AgentName: agentName, - UserID: userID, - Content: "non-expired memory", - Embedding: makeEmbedding(0.5), - ExpiresAt: &future, - AccessCount: 0, - })) + liveMem := &dbpkg.Memory{AgentName: agentName, UserID: userID, Content: "non-expired memory", Embedding: makeEmbedding(0.5), ExpiresAt: &future, AccessCount: 0} + require.NoError(t, client.StoreAgentMemory(ctx, liveMem)) err := client.PruneExpiredMemories(ctx) require.NoError(t, err) @@ -549,7 +523,7 @@ func TestPruneExpiredMemories(t *testing.T) { ids = append(ids, r.ID) } - assert.NotContains(t, ids, "prune-cold", "Expired unpopular memory should be pruned") - assert.Contains(t, ids, "prune-hot", "Expired popular memory should have TTL extended and be retained") - assert.Contains(t, ids, "prune-live", "Non-expired memory should be retained") + assert.NotContains(t, ids, coldMem.ID, "Expired unpopular memory should be pruned") + assert.Contains(t, ids, hotMem.ID, "Expired popular memory should have TTL extended and be retained") + assert.Contains(t, ids, liveMem.ID, "Non-expired memory should be retained") } diff --git a/go/core/internal/database/connect.go b/go/core/internal/database/connect.go new file mode 100644 index 000000000..169207073 --- /dev/null +++ b/go/core/internal/database/connect.go @@ -0,0 +1,87 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "strings" + "time" + + _ "github.com/lib/pq" +) + +// PostgresConfig holds the connection parameters for a Postgres database. 
+type PostgresConfig struct { + URL string + URLFile string + VectorEnabled bool +} + +const ( + defaultMaxTimeout = 120 * time.Second + defaultInitialDelay = 500 * time.Millisecond + defaultMaxDelay = 5 * time.Second +) + +// Connect opens a Postgres connection using cfg, resolving the URL from a file +// if URLFile is set, and retries PingContext with exponential backoff until the +// connection succeeds or defaultMaxTimeout elapses. +func Connect(ctx context.Context, cfg *PostgresConfig) (*sql.DB, error) { + url := cfg.URL + if cfg.URLFile != "" { + resolved, err := resolveURLFile(cfg.URLFile) + if err != nil { + return nil, fmt.Errorf("failed to resolve postgres URL from file: %w", err) + } + url = resolved + } + return retryDBConnection(ctx, url) +} + +// retryDBConnection opens a database connection and retries PingContext with +// exponential backoff until the connection succeeds or defaultMaxTimeout elapses. +func retryDBConnection(ctx context.Context, url string) (*sql.DB, error) { + ctx, cancel := context.WithTimeout(ctx, defaultMaxTimeout) + defer cancel() + + db, err := sql.Open("postgres", url) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + start := time.Now() + delay := defaultInitialDelay + for attempt := 1; ; attempt++ { + if err := db.PingContext(ctx); err == nil { + return db, nil + } else { + log.Printf("database not ready (attempt %d, elapsed %s): %v", attempt, time.Since(start).Round(time.Second), err) + } + select { + case <-ctx.Done(): + _ = db.Close() + return nil, fmt.Errorf("database not ready after %s: %w", time.Since(start).Round(time.Second), ctx.Err()) + case <-time.After(delay): + } + delay *= 2 + if delay > defaultMaxDelay { + delay = defaultMaxDelay + } + } +} + +// resolveURLFile reads a database connection URL from a file and returns the +// trimmed contents. Returns an error if the file cannot be read or is empty. 
+func resolveURLFile(path string) (string, error) { + content, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("reading URL file: %w", err) + } + url := strings.TrimSpace(string(content)) + if url == "" { + return "", fmt.Errorf("URL file %s is empty or contains only whitespace", path) + } + return url, nil +} diff --git a/go/core/internal/database/manager_test.go b/go/core/internal/database/connect_test.go similarity index 92% rename from go/core/internal/database/manager_test.go rename to go/core/internal/database/connect_test.go index 925dcb1d4..ab60bb1ab 100644 --- a/go/core/internal/database/manager_test.go +++ b/go/core/internal/database/connect_test.go @@ -8,15 +8,13 @@ import ( "time" "github.com/stretchr/testify/assert" - "gorm.io/gorm" ) func TestRetryDBConnection_DeadlineExceeded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - // Use an unreachable address so every attempt fails immediately. - _, err := retryDBConnection(ctx, "postgres://user:pass@localhost:1/nodb?connect_timeout=1", &gorm.Config{}) + _, err := retryDBConnection(ctx, "postgres://user:pass@localhost:1/nodb?connect_timeout=1") assert.ErrorIs(t, err, context.DeadlineExceeded) } diff --git a/go/core/internal/database/fake/client.go b/go/core/internal/database/fake/client.go index 11157c478..838f59c6c 100644 --- a/go/core/internal/database/fake/client.go +++ b/go/core/internal/database/fake/client.go @@ -2,6 +2,7 @@ package fake import ( "context" + "database/sql" "encoding/json" "fmt" "math" @@ -13,7 +14,6 @@ import ( "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/api/v1alpha2" "github.com/pgvector/pgvector-go" - "gorm.io/gorm" "trpc.group/trpc-go/trpc-a2a-go/protocol" ) @@ -83,7 +83,7 @@ func (c *InMemoryFakeClient) GetTask(_ context.Context, taskID string) (*protoco task, exists := c.tasks[taskID] if !exists { - return nil, gorm.ErrRecordNotFound + return nil, sql.ErrNoRows } 
parsedTask := &protocol.Task{} err := json.Unmarshal([]byte(task.Data), parsedTask) @@ -108,10 +108,11 @@ func (c *InMemoryFakeClient) StoreFeedback(_ context.Context, feedback *database // Copy the feedback and assign an ID newFeedback := *feedback - newFeedback.MessageID = uint(c.nextFeedbackID) + id := int64(c.nextFeedbackID) + newFeedback.MessageID = &id c.nextFeedbackID++ - key := fmt.Sprintf("%d", newFeedback.MessageID) + key := fmt.Sprintf("%d", id) c.feedback[key] = &newFeedback return nil } @@ -208,7 +209,7 @@ func (c *InMemoryFakeClient) DeleteAgent(_ context.Context, agentName string) er _, exists := c.agents[agentName] if !exists { - return gorm.ErrRecordNotFound + return sql.ErrNoRows } delete(c.agents, agentName) @@ -247,7 +248,7 @@ func (c *InMemoryFakeClient) GetSession(_ context.Context, sessionID string, use key := c.sessionKey(sessionID, userID) session, exists := c.sessions[key] if !exists { - return nil, gorm.ErrRecordNotFound + return nil, sql.ErrNoRows } return session, nil } @@ -259,7 +260,7 @@ func (c *InMemoryFakeClient) GetAgent(_ context.Context, agentName string) (*dat agent, exists := c.agents[agentName] if !exists { - return nil, gorm.ErrRecordNotFound + return nil, sql.ErrNoRows } return agent, nil } @@ -271,7 +272,7 @@ func (c *InMemoryFakeClient) GetTool(_ context.Context, toolName string) (*datab tool, exists := c.tools[toolName] if !exists { - return nil, gorm.ErrRecordNotFound + return nil, sql.ErrNoRows } return tool, nil } @@ -283,7 +284,7 @@ func (c *InMemoryFakeClient) GetToolServer(_ context.Context, serverName string) server, exists := c.toolServers[serverName] if !exists { - return nil, gorm.ErrRecordNotFound + return nil, sql.ErrNoRows } return server, nil } diff --git a/go/core/internal/database/gen/agents.sql.go b/go/core/internal/database/gen/agents.sql.go new file mode 100644 index 000000000..61975dd4f --- /dev/null +++ b/go/core/internal/database/gen/agents.sql.go @@ -0,0 +1,98 @@ +// Code generated by sqlc. 
DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: agents.sql + +package dbgen + +import ( + "context" + + "github.com/kagent-dev/kagent/go/api/adk" +) + +const getAgent = `-- name: GetAgent :one +SELECT id, created_at, updated_at, deleted_at, type, config FROM agent +WHERE id = $1 AND deleted_at IS NULL +LIMIT 1 +` + +func (q *Queries) GetAgent(ctx context.Context, id string) (Agent, error) { + row := q.db.QueryRowContext(ctx, getAgent, id) + var i Agent + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Type, + &i.Config, + ) + return i, err +} + +const listAgents = `-- name: ListAgents :many +SELECT id, created_at, updated_at, deleted_at, type, config FROM agent +WHERE deleted_at IS NULL +ORDER BY created_at ASC +` + +func (q *Queries) ListAgents(ctx context.Context) ([]Agent, error) { + rows, err := q.db.QueryContext(ctx, listAgents) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Agent + for rows.Next() { + var i Agent + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Type, + &i.Config, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const softDeleteAgent = `-- name: SoftDeleteAgent :exec +UPDATE agent SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL +` + +func (q *Queries) SoftDeleteAgent(ctx context.Context, id string) error { + _, err := q.db.ExecContext(ctx, softDeleteAgent, id) + return err +} + +const upsertAgent = `-- name: UpsertAgent :exec +INSERT INTO agent (id, type, config, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (id) DO UPDATE SET + type = EXCLUDED.type, + config = EXCLUDED.config, + updated_at = NOW(), + deleted_at = NULL +` + +type UpsertAgentParams struct { + ID string + Type string + Config *adk.AgentConfig +} + +func (q *Queries) 
UpsertAgent(ctx context.Context, arg UpsertAgentParams) error { + _, err := q.db.ExecContext(ctx, upsertAgent, arg.ID, arg.Type, arg.Config) + return err +} diff --git a/go/core/internal/database/gen/crewai.sql.go b/go/core/internal/database/gen/crewai.sql.go new file mode 100644 index 000000000..a3fd02a75 --- /dev/null +++ b/go/core/internal/database/gen/crewai.sql.go @@ -0,0 +1,191 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: crewai.sql + +package dbgen + +import ( + "context" +) + +const getLatestCrewAIFlowState = `-- name: GetLatestCrewAIFlowState :one +SELECT user_id, thread_id, method_name, created_at, updated_at, deleted_at, state_data FROM crewai_flow_state +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL +ORDER BY created_at DESC +LIMIT 1 +` + +type GetLatestCrewAIFlowStateParams struct { + UserID string + ThreadID string +} + +func (q *Queries) GetLatestCrewAIFlowState(ctx context.Context, arg GetLatestCrewAIFlowStateParams) (CrewaiFlowState, error) { + row := q.db.QueryRowContext(ctx, getLatestCrewAIFlowState, arg.UserID, arg.ThreadID) + var i CrewaiFlowState + err := row.Scan( + &i.UserID, + &i.ThreadID, + &i.MethodName, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.StateData, + ) + return i, err +} + +const hardDeleteCrewAIMemory = `-- name: HardDeleteCrewAIMemory :exec +DELETE FROM crewai_agent_memory +WHERE user_id = $1 AND thread_id = $2 +` + +type HardDeleteCrewAIMemoryParams struct { + UserID string + ThreadID string +} + +func (q *Queries) HardDeleteCrewAIMemory(ctx context.Context, arg HardDeleteCrewAIMemoryParams) error { + _, err := q.db.ExecContext(ctx, hardDeleteCrewAIMemory, arg.UserID, arg.ThreadID) + return err +} + +const searchCrewAIMemoryByTask = `-- name: SearchCrewAIMemoryByTask :many +SELECT user_id, thread_id, created_at, updated_at, deleted_at, memory_data FROM crewai_agent_memory +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL + AND (memory_data ILIKE $3 
OR (memory_data::jsonb)->>'task_description' ILIKE $3) +ORDER BY created_at DESC, (memory_data::jsonb)->>'score' ASC +` + +type SearchCrewAIMemoryByTaskParams struct { + UserID string + ThreadID string + MemoryData string +} + +func (q *Queries) SearchCrewAIMemoryByTask(ctx context.Context, arg SearchCrewAIMemoryByTaskParams) ([]CrewaiAgentMemory, error) { + rows, err := q.db.QueryContext(ctx, searchCrewAIMemoryByTask, arg.UserID, arg.ThreadID, arg.MemoryData) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CrewaiAgentMemory + for rows.Next() { + var i CrewaiAgentMemory + if err := rows.Scan( + &i.UserID, + &i.ThreadID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.MemoryData, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const searchCrewAIMemoryByTaskLimit = `-- name: SearchCrewAIMemoryByTaskLimit :many +SELECT user_id, thread_id, created_at, updated_at, deleted_at, memory_data FROM crewai_agent_memory +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL + AND (memory_data ILIKE $3 OR (memory_data::jsonb)->>'task_description' ILIKE $3) +ORDER BY created_at DESC, (memory_data::jsonb)->>'score' ASC +LIMIT $4 +` + +type SearchCrewAIMemoryByTaskLimitParams struct { + UserID string + ThreadID string + MemoryData string + Limit int32 +} + +func (q *Queries) SearchCrewAIMemoryByTaskLimit(ctx context.Context, arg SearchCrewAIMemoryByTaskLimitParams) ([]CrewaiAgentMemory, error) { + rows, err := q.db.QueryContext(ctx, searchCrewAIMemoryByTaskLimit, + arg.UserID, + arg.ThreadID, + arg.MemoryData, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CrewaiAgentMemory + for rows.Next() { + var i CrewaiAgentMemory + if err := rows.Scan( + &i.UserID, + &i.ThreadID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.MemoryData, 
+ ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const upsertCrewAIFlowState = `-- name: UpsertCrewAIFlowState :exec +INSERT INTO crewai_flow_state (user_id, thread_id, method_name, state_data, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +ON CONFLICT (user_id, thread_id, method_name) DO UPDATE SET + state_data = EXCLUDED.state_data, + updated_at = NOW(), + deleted_at = NULL +` + +type UpsertCrewAIFlowStateParams struct { + UserID string + ThreadID string + MethodName string + StateData string +} + +func (q *Queries) UpsertCrewAIFlowState(ctx context.Context, arg UpsertCrewAIFlowStateParams) error { + _, err := q.db.ExecContext(ctx, upsertCrewAIFlowState, + arg.UserID, + arg.ThreadID, + arg.MethodName, + arg.StateData, + ) + return err +} + +const upsertCrewAIMemory = `-- name: UpsertCrewAIMemory :exec +INSERT INTO crewai_agent_memory (user_id, thread_id, memory_data, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (user_id, thread_id) DO UPDATE SET + memory_data = EXCLUDED.memory_data, + updated_at = NOW(), + deleted_at = NULL +` + +type UpsertCrewAIMemoryParams struct { + UserID string + ThreadID string + MemoryData string +} + +func (q *Queries) UpsertCrewAIMemory(ctx context.Context, arg UpsertCrewAIMemoryParams) error { + _, err := q.db.ExecContext(ctx, upsertCrewAIMemory, arg.UserID, arg.ThreadID, arg.MemoryData) + return err +} diff --git a/go/core/internal/database/gen/db.go b/go/core/internal/database/gen/db.go new file mode 100644 index 000000000..d0d3db9fb --- /dev/null +++ b/go/core/internal/database/gen/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.30.0 + +package dbgen + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/go/core/internal/database/gen/events.sql.go b/go/core/internal/database/gen/events.sql.go new file mode 100644 index 000000000..bc029f8a8 --- /dev/null +++ b/go/core/internal/database/gen/events.sql.go @@ -0,0 +1,340 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: events.sql + +package dbgen + +import ( + "context" + "database/sql" + "time" +) + +const getEvent = `-- name: GetEvent :one +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL +LIMIT 1 +` + +type GetEventParams struct { + ID string + UserID string +} + +func (q *Queries) GetEvent(ctx context.Context, arg GetEventParams) (Event, error) { + row := q.db.QueryRowContext(ctx, getEvent, arg.ID, arg.UserID) + var i Event + err := row.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ) + return i, err +} + +const insertEvent = `-- name: InsertEvent :exec +INSERT INTO event (id, user_id, session_id, data, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +` + +type InsertEventParams struct { + ID string + UserID string + SessionID sql.NullString + Data string +} + +func (q *Queries) InsertEvent(ctx context.Context, arg InsertEventParams) error { + _, err := q.db.ExecContext(ctx, insertEvent, + arg.ID, + arg.UserID, + arg.SessionID, + 
arg.Data, + ) + return err +} + +const listEventsByContextID = `-- name: ListEventsByContextID :many +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND deleted_at IS NULL +ORDER BY created_at DESC +` + +func (q *Queries) ListEventsByContextID(ctx context.Context, sessionID sql.NullString) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsByContextID, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listEventsByContextIDLimit = `-- name: ListEventsByContextIDLimit :many +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND deleted_at IS NULL +ORDER BY created_at DESC +LIMIT $2 +` + +type ListEventsByContextIDLimitParams struct { + SessionID sql.NullString + Limit int32 +} + +func (q *Queries) ListEventsByContextIDLimit(ctx context.Context, arg ListEventsByContextIDLimitParams) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsByContextIDLimit, arg.SessionID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listEventsForSessionAsc = `-- name: ListEventsForSessionAsc :many +SELECT id, user_id, 
session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at ASC +` + +type ListEventsForSessionAscParams struct { + SessionID sql.NullString + UserID string + Column3 time.Time +} + +func (q *Queries) ListEventsForSessionAsc(ctx context.Context, arg ListEventsForSessionAscParams) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsForSessionAsc, arg.SessionID, arg.UserID, arg.Column3) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listEventsForSessionAscLimit = `-- name: ListEventsForSessionAscLimit :many +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at ASC +LIMIT $4 +` + +type ListEventsForSessionAscLimitParams struct { + SessionID sql.NullString + UserID string + Column3 time.Time + Limit int32 +} + +func (q *Queries) ListEventsForSessionAscLimit(ctx context.Context, arg ListEventsForSessionAscLimitParams) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsForSessionAscLimit, + arg.SessionID, + arg.UserID, + arg.Column3, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = 
append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listEventsForSessionDesc = `-- name: ListEventsForSessionDesc :many +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at DESC +` + +type ListEventsForSessionDescParams struct { + SessionID sql.NullString + UserID string + Column3 time.Time +} + +func (q *Queries) ListEventsForSessionDesc(ctx context.Context, arg ListEventsForSessionDescParams) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsForSessionDesc, arg.SessionID, arg.UserID, arg.Column3) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listEventsForSessionDescLimit = `-- name: ListEventsForSessionDescLimit :many +SELECT id, user_id, session_id, created_at, updated_at, deleted_at, data FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at DESC +LIMIT $4 +` + +type ListEventsForSessionDescLimitParams struct { + SessionID sql.NullString + UserID string + Column3 time.Time + Limit int32 +} + +func (q *Queries) ListEventsForSessionDescLimit(ctx context.Context, arg ListEventsForSessionDescLimitParams) ([]Event, error) { + rows, err := q.db.QueryContext(ctx, listEventsForSessionDescLimit, + arg.SessionID, + arg.UserID, + arg.Column3, + arg.Limit, + ) + if err != nil { + 
return nil, err + } + defer rows.Close() + var items []Event + for rows.Next() { + var i Event + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.SessionID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const softDeleteEvent = `-- name: SoftDeleteEvent :exec +UPDATE event SET deleted_at = NOW() +WHERE id = $1 AND deleted_at IS NULL +` + +func (q *Queries) SoftDeleteEvent(ctx context.Context, id string) error { + _, err := q.db.ExecContext(ctx, softDeleteEvent, id) + return err +} diff --git a/go/core/internal/database/gen/feedback.sql.go b/go/core/internal/database/gen/feedback.sql.go new file mode 100644 index 000000000..548aff05b --- /dev/null +++ b/go/core/internal/database/gen/feedback.sql.go @@ -0,0 +1,89 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: feedback.sql + +package dbgen + +import ( + "context" + "database/sql" + + "github.com/kagent-dev/kagent/go/api/database" +) + +const insertFeedback = `-- name: InsertFeedback :one +INSERT INTO feedback (user_id, message_id, is_positive, feedback_text, issue_type, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) +RETURNING id, created_at, updated_at, deleted_at, user_id, message_id, is_positive, feedback_text, issue_type +` + +type InsertFeedbackParams struct { + UserID string + MessageID sql.NullInt64 + IsPositive sql.NullBool + FeedbackText string + IssueType *database.FeedbackIssueType +} + +func (q *Queries) InsertFeedback(ctx context.Context, arg InsertFeedbackParams) (Feedback, error) { + row := q.db.QueryRowContext(ctx, insertFeedback, + arg.UserID, + arg.MessageID, + arg.IsPositive, + arg.FeedbackText, + arg.IssueType, + ) + var i Feedback + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + 
&i.UserID, + &i.MessageID, + &i.IsPositive, + &i.FeedbackText, + &i.IssueType, + ) + return i, err +} + +const listFeedback = `-- name: ListFeedback :many +SELECT id, created_at, updated_at, deleted_at, user_id, message_id, is_positive, feedback_text, issue_type FROM feedback +WHERE user_id = $1 AND deleted_at IS NULL +ORDER BY created_at ASC +` + +func (q *Queries) ListFeedback(ctx context.Context, userID string) ([]Feedback, error) { + rows, err := q.db.QueryContext(ctx, listFeedback, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Feedback + for rows.Next() { + var i Feedback + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.UserID, + &i.MessageID, + &i.IsPositive, + &i.FeedbackText, + &i.IssueType, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/go/core/internal/database/gen/langgraph.sql.go b/go/core/internal/database/gen/langgraph.sql.go new file mode 100644 index 000000000..d37b73b94 --- /dev/null +++ b/go/core/internal/database/gen/langgraph.sql.go @@ -0,0 +1,322 @@ +// Code generated by sqlc. DO NOT EDIT. 
// versions:
//   sqlc v1.30.0
// source: langgraph.sql

package dbgen

import (
	"context"
	"database/sql"
)

const getCheckpoint = `-- name: GetCheckpoint :one
SELECT user_id, thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, created_at, updated_at, deleted_at, metadata, checkpoint, checkpoint_type, version FROM lg_checkpoint
WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3
  AND checkpoint_id = $4 AND deleted_at IS NULL
LIMIT 1
`

// GetCheckpointParams is the full LangGraph checkpoint key:
// (user, thread, namespace, checkpoint id).
type GetCheckpointParams struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
	CheckpointID string
}

// GetCheckpoint fetches one live checkpoint by its composite key.
func (q *Queries) GetCheckpoint(ctx context.Context, arg GetCheckpointParams) (LgCheckpoint, error) {
	row := q.db.QueryRowContext(ctx, getCheckpoint,
		arg.UserID,
		arg.ThreadID,
		arg.CheckpointNs,
		arg.CheckpointID,
	)
	var i LgCheckpoint
	err := row.Scan(
		&i.UserID,
		&i.ThreadID,
		&i.CheckpointNs,
		&i.CheckpointID,
		&i.ParentCheckpointID,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.Metadata,
		&i.Checkpoint,
		&i.CheckpointType,
		&i.Version,
	)
	return i, err
}

const listCheckpointWrites = `-- name: ListCheckpointWrites :many
SELECT user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx, value, value_type, channel, task_id, created_at, updated_at, deleted_at FROM lg_checkpoint_write
WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3
  AND checkpoint_id = $4 AND deleted_at IS NULL
ORDER BY task_id, write_idx
`

type ListCheckpointWritesParams struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
	CheckpointID string
}

// ListCheckpointWrites returns the pending writes attached to one checkpoint,
// ordered by task_id then write_idx.
func (q *Queries) ListCheckpointWrites(ctx context.Context, arg ListCheckpointWritesParams) ([]LgCheckpointWrite, error) {
	rows, err := q.db.QueryContext(ctx, listCheckpointWrites,
		arg.UserID,
		arg.ThreadID,
		arg.CheckpointNs,
		arg.CheckpointID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []LgCheckpointWrite
	for rows.Next() {
		var i LgCheckpointWrite
		if err := rows.Scan(
			&i.UserID,
			&i.ThreadID,
			&i.CheckpointNs,
			&i.CheckpointID,
			&i.WriteIdx,
			&i.Value,
			&i.ValueType,
			&i.Channel,
			&i.TaskID,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listCheckpoints = `-- name: ListCheckpoints :many
SELECT user_id, thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, created_at, updated_at, deleted_at, metadata, checkpoint, checkpoint_type, version FROM lg_checkpoint
WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3
  AND deleted_at IS NULL
ORDER BY checkpoint_id DESC
`

type ListCheckpointsParams struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
}

// ListCheckpoints returns all live checkpoints for a thread namespace,
// highest checkpoint_id first.
func (q *Queries) ListCheckpoints(ctx context.Context, arg ListCheckpointsParams) ([]LgCheckpoint, error) {
	rows, err := q.db.QueryContext(ctx, listCheckpoints, arg.UserID, arg.ThreadID, arg.CheckpointNs)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []LgCheckpoint
	for rows.Next() {
		var i LgCheckpoint
		if err := rows.Scan(
			&i.UserID,
			&i.ThreadID,
			&i.CheckpointNs,
			&i.CheckpointID,
			&i.ParentCheckpointID,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Metadata,
			&i.Checkpoint,
			&i.CheckpointType,
			&i.Version,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listCheckpointsLimit = `-- name: ListCheckpointsLimit :many
SELECT user_id, thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, created_at, updated_at, deleted_at, metadata, checkpoint, checkpoint_type, version FROM lg_checkpoint
WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3
  AND deleted_at IS NULL
ORDER BY checkpoint_id DESC
LIMIT $4
`

type ListCheckpointsLimitParams struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
	Limit        int32
}

// ListCheckpointsLimit is ListCheckpoints capped at $4 rows.
func (q *Queries) ListCheckpointsLimit(ctx context.Context, arg ListCheckpointsLimitParams) ([]LgCheckpoint, error) {
	rows, err := q.db.QueryContext(ctx, listCheckpointsLimit,
		arg.UserID,
		arg.ThreadID,
		arg.CheckpointNs,
		arg.Limit,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []LgCheckpoint
	for rows.Next() {
		var i LgCheckpoint
		if err := rows.Scan(
			&i.UserID,
			&i.ThreadID,
			&i.CheckpointNs,
			&i.CheckpointID,
			&i.ParentCheckpointID,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Metadata,
			&i.Checkpoint,
			&i.CheckpointType,
			&i.Version,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const softDeleteCheckpointWrites = `-- name: SoftDeleteCheckpointWrites :exec
UPDATE lg_checkpoint_write SET deleted_at = NOW()
WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL
`

type SoftDeleteCheckpointWritesParams struct {
	UserID   string
	ThreadID string
}

// SoftDeleteCheckpointWrites soft-deletes every live write for a thread
// (all namespaces and checkpoints).
func (q *Queries) SoftDeleteCheckpointWrites(ctx context.Context, arg SoftDeleteCheckpointWritesParams) error {
	_, err := q.db.ExecContext(ctx, softDeleteCheckpointWrites, arg.UserID, arg.ThreadID)
	return err
}

const softDeleteCheckpoints = `-- name: SoftDeleteCheckpoints :exec
UPDATE lg_checkpoint SET deleted_at = NOW()
WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL
`

type SoftDeleteCheckpointsParams struct {
	UserID   string
	ThreadID string
}

// SoftDeleteCheckpoints soft-deletes every live checkpoint for a thread.
func (q *Queries) SoftDeleteCheckpoints(ctx context.Context, arg SoftDeleteCheckpointsParams) error {
	_, err := q.db.ExecContext(ctx, softDeleteCheckpoints, arg.UserID, arg.ThreadID)
	return err
}

const upsertCheckpoint = `-- name: UpsertCheckpoint :exec
INSERT INTO lg_checkpoint (
    user_id, thread_id, checkpoint_ns, checkpoint_id,
    parent_checkpoint_id, metadata, checkpoint, checkpoint_type, version,
    created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
ON CONFLICT (user_id, thread_id, checkpoint_ns, checkpoint_id) DO UPDATE SET
    parent_checkpoint_id = EXCLUDED.parent_checkpoint_id,
    metadata = EXCLUDED.metadata,
    checkpoint = EXCLUDED.checkpoint,
    checkpoint_type = EXCLUDED.checkpoint_type,
    version = EXCLUDED.version,
    updated_at = NOW()
`

type UpsertCheckpointParams struct {
	UserID             string
	ThreadID           string
	CheckpointNs       string
	CheckpointID       string
	ParentCheckpointID sql.NullString
	Metadata           string
	Checkpoint         string
	CheckpointType     string
	Version            sql.NullInt32
}

// UpsertCheckpoint inserts or updates a checkpoint keyed on the composite
// (user_id, thread_id, checkpoint_ns, checkpoint_id); created_at is only set
// on first insert, updated_at on every write.
func (q *Queries) UpsertCheckpoint(ctx context.Context, arg UpsertCheckpointParams) error {
	_, err := q.db.ExecContext(ctx, upsertCheckpoint,
		arg.UserID,
		arg.ThreadID,
		arg.CheckpointNs,
		arg.CheckpointID,
		arg.ParentCheckpointID,
		arg.Metadata,
		arg.Checkpoint,
		arg.CheckpointType,
		arg.Version,
	)
	return err
}

const upsertCheckpointWrite = `-- name: UpsertCheckpointWrite :exec
INSERT INTO lg_checkpoint_write (
    user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx,
    value, value_type, channel, task_id, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
ON CONFLICT (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) DO UPDATE SET
    value = EXCLUDED.value,
    value_type = EXCLUDED.value_type,
    channel = EXCLUDED.channel,
    task_id = EXCLUDED.task_id,
    updated_at = NOW()
`

type UpsertCheckpointWriteParams struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
	CheckpointID string
	WriteIdx     int32
	Value        string
	ValueType    string
	Channel      string
	TaskID       string
}

// UpsertCheckpointWrite inserts or updates one pending write keyed on the
// checkpoint key plus write_idx.
func (q *Queries) UpsertCheckpointWrite(ctx context.Context, arg UpsertCheckpointWriteParams) error {
	_, err := q.db.ExecContext(ctx, upsertCheckpointWrite,
		arg.UserID,
		arg.ThreadID,
		arg.CheckpointNs,
		arg.CheckpointID,
		arg.WriteIdx,
		arg.Value,
		arg.ValueType,
		arg.Channel,
		arg.TaskID,
	)
	return err
}
diff --git a/go/core/internal/database/gen/memory.sql.go b/go/core/internal/database/gen/memory.sql.go
new file mode 100644
index 000000000..b5ee97f28
--- /dev/null
+++ b/go/core/internal/database/gen/memory.sql.go
@@ -0,0 +1,203 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: memory.sql

package dbgen

import (
	"context"
	"database/sql"
	"time"

	"github.com/lib/pq"
	pgvector_go "github.com/pgvector/pgvector-go"
)

const deleteAgentMemory = `-- name: DeleteAgentMemory :exec
DELETE FROM memory WHERE agent_name = $1 AND user_id = $2
`

type DeleteAgentMemoryParams struct {
	AgentName sql.NullString
	UserID    sql.NullString
}

// DeleteAgentMemory hard-deletes (no soft-delete column on memory) all
// memories for one agent/user pair.
func (q *Queries) DeleteAgentMemory(ctx context.Context, arg DeleteAgentMemoryParams) error {
	_, err := q.db.ExecContext(ctx, deleteAgentMemory, arg.AgentName, arg.UserID)
	return err
}

const deleteExpiredMemories = `-- name: DeleteExpiredMemories :exec
DELETE FROM memory
WHERE expires_at < NOW() AND access_count < 10
`

// DeleteExpiredMemories removes expired rows that were accessed fewer than
// 10 times; frequently accessed expired rows are handled by ExtendMemoryTTL.
func (q *Queries) DeleteExpiredMemories(ctx context.Context) error {
	_, err := q.db.ExecContext(ctx, deleteExpiredMemories)
	return err
}

const extendMemoryTTL = `-- name: ExtendMemoryTTL :exec
UPDATE memory
SET expires_at = NOW() + INTERVAL '15 days', access_count = 0
WHERE expires_at < NOW() AND access_count >= 10
`

// ExtendMemoryTTL gives expired rows with >= 10 accesses a fresh 15-day TTL
// and resets their access counter.
func (q *Queries) ExtendMemoryTTL(ctx context.Context) error {
	_, err := q.db.ExecContext(ctx, extendMemoryTTL)
	return err
}

const incrementMemoryAccessCount = `-- name: IncrementMemoryAccessCount :exec
UPDATE memory SET access_count = access_count + 1
WHERE id = ANY($1::text[])
`

// IncrementMemoryAccessCount bumps access_count for every listed id.
// NOTE(review): the parameter name dollar_1 is sqlc's fallback for an unnamed
// $1 in the SQL; naming it via sqlc.arg() in queries/memory.sql would improve
// this on regeneration. pq.Array adapts []string to the text[] bind.
func (q *Queries) IncrementMemoryAccessCount(ctx context.Context, dollar_1 []string) error {
	_, err := q.db.ExecContext(ctx, incrementMemoryAccessCount, pq.Array(dollar_1))
	return err
}

const insertMemory = `-- name: InsertMemory :one
INSERT INTO memory (agent_name, user_id, content, embedding, metadata, created_at, expires_at, access_count)
VALUES ($1, $2, $3, $4, $5, NOW(), $6, $7)
RETURNING id
`

type InsertMemoryParams struct {
	AgentName   sql.NullString
	UserID      sql.NullString
	Content     sql.NullString
	Embedding   pgvector_go.Vector
	Metadata    sql.NullString
	ExpiresAt   sql.NullTime
	AccessCount sql.NullInt32
}

// InsertMemory stores one memory row and returns the DB-generated id.
func (q *Queries) InsertMemory(ctx context.Context, arg InsertMemoryParams) (string, error) {
	row := q.db.QueryRowContext(ctx, insertMemory,
		arg.AgentName,
		arg.UserID,
		arg.Content,
		arg.Embedding,
		arg.Metadata,
		arg.ExpiresAt,
		arg.AccessCount,
	)
	var id string
	err := row.Scan(&id)
	return id, err
}

const listAgentMemories = `-- name: ListAgentMemories :many
SELECT id, agent_name, user_id, content, embedding, metadata, created_at, expires_at, access_count FROM memory
WHERE (agent_name = $1 OR agent_name = $2) AND user_id = $3
ORDER BY access_count DESC
`

// ListAgentMemoriesParams: AgentName_2 is sqlc's generated name for the
// second agent_name comparison — the query matches either of two agent names.
type ListAgentMemoriesParams struct {
	AgentName   sql.NullString
	AgentName_2 sql.NullString
	UserID      sql.NullString
}

// ListAgentMemories returns a user's memories for either agent name,
// most-accessed first.
func (q *Queries) ListAgentMemories(ctx context.Context, arg ListAgentMemoriesParams) ([]Memory, error) {
	rows, err := q.db.QueryContext(ctx, listAgentMemories, arg.AgentName, arg.AgentName_2, arg.UserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Memory
	for rows.Next() {
		var i Memory
		if err := rows.Scan(
			&i.ID,
			&i.AgentName,
			&i.UserID,
			&i.Content,
			&i.Embedding,
			&i.Metadata,
			&i.CreatedAt,
			&i.ExpiresAt,
			&i.AccessCount,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const searchAgentMemory = `-- name: SearchAgentMemory :many
SELECT id, agent_name, user_id, content, embedding, metadata, created_at, expires_at, access_count, COALESCE(1 - (embedding <=> $1), 0) AS score
FROM memory
WHERE agent_name = $2 AND user_id = $3
ORDER BY embedding <=> $1 ASC
LIMIT $4
`

type SearchAgentMemoryParams struct {
	Embedding pgvector_go.Vector
	AgentName sql.NullString
	UserID    sql.NullString
	Limit     int32
}

// SearchAgentMemoryRow adds the computed similarity score to the row model.
// Score is interface{} because sqlc cannot infer the type of the computed
// COALESCE(1 - (embedding <=> $1), 0) column; a ::float8 cast in
// queries/memory.sql would let it generate float64 here.
type SearchAgentMemoryRow struct {
	ID          string
	AgentName   sql.NullString
	UserID      sql.NullString
	Content     sql.NullString
	Embedding   pgvector_go.Vector
	Metadata    sql.NullString
	CreatedAt   time.Time
	ExpiresAt   sql.NullTime
	AccessCount sql.NullInt32
	Score       interface{}
}

// SearchAgentMemory returns the $4 nearest memories by pgvector cosine
// distance (<=>), closest first, with score = 1 - distance.
func (q *Queries) SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryParams) ([]SearchAgentMemoryRow, error) {
	rows, err := q.db.QueryContext(ctx, searchAgentMemory,
		arg.Embedding,
		arg.AgentName,
		arg.UserID,
		arg.Limit,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []SearchAgentMemoryRow
	for rows.Next() {
		var i SearchAgentMemoryRow
		if err := rows.Scan(
			&i.ID,
			&i.AgentName,
			&i.UserID,
			&i.Content,
			&i.Embedding,
			&i.Metadata,
			&i.CreatedAt,
			&i.ExpiresAt,
			&i.AccessCount,
			&i.Score,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
diff --git a/go/core/internal/database/gen/models.go b/go/core/internal/database/gen/models.go
new file mode 100644
index 000000000..8c8b6d9f8
--- /dev/null
+++ b/go/core/internal/database/gen/models.go
@@ -0,0 +1,155 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0

package dbgen

import (
	"database/sql"
	"time"

	"github.com/kagent-dev/kagent/go/api/adk"
	"github.com/kagent-dev/kagent/go/api/database"
	pgvector_go "github.com/pgvector/pgvector-go"
)

// Agent is the row model for the agent table; Config is stored via the
// adk.AgentConfig custom type.
type Agent struct {
	ID        string
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt sql.NullTime
	Type      string
	Config    *adk.AgentConfig
}

// CrewaiAgentMemory is keyed by (user_id, thread_id).
type CrewaiAgentMemory struct {
	UserID     string
	ThreadID   string
	CreatedAt  time.Time
	UpdatedAt  time.Time
	DeletedAt  sql.NullTime
	MemoryData string
}

// CrewaiFlowState is keyed by (user_id, thread_id, method_name).
type CrewaiFlowState struct {
	UserID     string
	ThreadID   string
	MethodName string
	CreatedAt  time.Time
	UpdatedAt  time.Time
	DeletedAt  sql.NullTime
	StateData  string
}

// Event is an ADK event row; Data holds the serialized event payload.
type Event struct {
	ID        string
	UserID    string
	SessionID sql.NullString
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt sql.NullTime
	Data      string
}

// Feedback is a user feedback row; IssueType maps to the project enum.
type Feedback struct {
	ID           int64
	CreatedAt    sql.NullTime
	UpdatedAt    sql.NullTime
	DeletedAt    sql.NullTime
	UserID       string
	MessageID    sql.NullInt64
	IsPositive   sql.NullBool
	FeedbackText string
	IssueType    *database.FeedbackIssueType
}

// LgCheckpoint is a LangGraph checkpoint, keyed by
// (user_id, thread_id, checkpoint_ns, checkpoint_id).
type LgCheckpoint struct {
	UserID             string
	ThreadID           string
	CheckpointNs       string
	CheckpointID       string
	ParentCheckpointID sql.NullString
	CreatedAt          time.Time
	UpdatedAt          time.Time
	DeletedAt          sql.NullTime
	Metadata           string
	Checkpoint         string
	CheckpointType     string
	Version            sql.NullInt32
}

// LgCheckpointWrite is one pending write for a checkpoint, keyed by the
// checkpoint key plus WriteIdx.
type LgCheckpointWrite struct {
	UserID       string
	ThreadID     string
	CheckpointNs string
	CheckpointID string
	WriteIdx     int32
	Value        string
	ValueType    string
	Channel      string
	TaskID       string
	CreatedAt    time.Time
	UpdatedAt    time.Time
	DeletedAt    sql.NullTime
}

// Memory is a vector-memory row (no soft-delete column; rows expire via
// ExpiresAt/AccessCount instead).
type Memory struct {
	ID          string
	AgentName   sql.NullString
	UserID      sql.NullString
	Content     sql.NullString
	Embedding   pgvector_go.Vector
	Metadata    sql.NullString
	CreatedAt   time.Time
	ExpiresAt   sql.NullTime
	AccessCount sql.NullInt32
}

// PushNotification stores a serialized push-notification config per task.
type PushNotification struct {
	ID        string
	TaskID    string
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt sql.NullTime
	Data      string
}

// Session is keyed by (id, user_id); Source distinguishes how the session
// was created (see the 000002_add_session_source migration).
type Session struct {
	ID        string
	UserID    string
	Name      sql.NullString
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt sql.NullTime
	AgentID   sql.NullString
	Source    sql.NullString
}

// Task holds a serialized A2A task payload linked to an optional session.
type Task struct {
	ID        string
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt sql.NullTime
	Data      string
	SessionID sql.NullString
}

// Tool belongs to a tool server identified by (ServerName, GroupKind).
type Tool struct {
	ID          string
	ServerName  string
	GroupKind   string
	CreatedAt   time.Time
	UpdatedAt   time.Time
	DeletedAt   sql.NullTime
	Description sql.NullString
}

// Toolserver tracks a tool server and when it was last reachable.
type Toolserver struct {
	Name          string
	GroupKind     string
	CreatedAt     time.Time
	UpdatedAt     time.Time
	DeletedAt     sql.NullTime
	Description   sql.NullString
	LastConnected sql.NullTime
}
diff --git a/go/core/internal/database/gen/push_notifications.sql.go b/go/core/internal/database/gen/push_notifications.sql.go
new file mode 100644
index 000000000..cdedf3555
--- /dev/null
+++ b/go/core/internal/database/gen/push_notifications.sql.go
@@ -0,0 +1,100 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: push_notifications.sql

package dbgen

import (
	"context"
)

const getPushNotification = `-- name: GetPushNotification :one
SELECT id, task_id, created_at, updated_at, deleted_at, data FROM push_notification
WHERE task_id = $1 AND id = $2 AND deleted_at IS NULL
LIMIT 1
`

type GetPushNotificationParams struct {
	TaskID string
	ID     string
}

// GetPushNotification fetches one live notification config by task and id.
func (q *Queries) GetPushNotification(ctx context.Context, arg GetPushNotificationParams) (PushNotification, error) {
	row := q.db.QueryRowContext(ctx, getPushNotification, arg.TaskID, arg.ID)
	var i PushNotification
	err := row.Scan(
		&i.ID,
		&i.TaskID,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.Data,
	)
	return i, err
}

const listPushNotifications = `-- name: ListPushNotifications :many
SELECT id, task_id, created_at, updated_at, deleted_at, data FROM push_notification
WHERE task_id = $1 AND deleted_at IS NULL
ORDER BY created_at ASC
`

// ListPushNotifications returns all live notification configs for a task,
// oldest first.
func (q *Queries) ListPushNotifications(ctx context.Context, taskID string) ([]PushNotification, error) {
	rows, err := q.db.QueryContext(ctx, listPushNotifications, taskID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []PushNotification
	for rows.Next() {
		var i PushNotification
		if err := rows.Scan(
			&i.ID,
			&i.TaskID,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Data,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const softDeletePushNotification = `-- name: SoftDeletePushNotification :exec
UPDATE push_notification SET deleted_at = NOW()
WHERE task_id = $1 AND deleted_at IS NULL
`

// SoftDeletePushNotification soft-deletes every live notification config
// for the task (filters on task_id, not a single id).
func (q *Queries) SoftDeletePushNotification(ctx context.Context, taskID string) error {
	_, err := q.db.ExecContext(ctx, softDeletePushNotification, taskID)
	return err
}

const upsertPushNotification = `-- name: UpsertPushNotification :exec
INSERT INTO push_notification (id, task_id, data, created_at, updated_at)
VALUES ($1, $2, $3, NOW(), NOW())
ON CONFLICT (id) DO UPDATE SET
    data = EXCLUDED.data,
    updated_at = NOW()
`

type UpsertPushNotificationParams struct {
	ID     string
	TaskID string
	Data   string
}

// UpsertPushNotification inserts or, on id conflict, refreshes the payload.
func (q *Queries) UpsertPushNotification(ctx context.Context, arg UpsertPushNotificationParams) error {
	_, err := q.db.ExecContext(ctx, upsertPushNotification, arg.ID, arg.TaskID, arg.Data)
	return err
}
diff --git a/go/core/internal/database/gen/querier.go b/go/core/internal/database/gen/querier.go
new file mode 100644
index 000000000..152a08988
--- /dev/null
+++ b/go/core/internal/database/gen/querier.go
@@ -0,0 +1,74 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0

package dbgen

import (
	"context"
	"database/sql"
)

// Querier is the generated interface covering every sqlc query in the
// package; *Queries implements it (compile-time assertion below).
type Querier interface {
	DeleteAgentMemory(ctx context.Context, arg DeleteAgentMemoryParams) error
	DeleteExpiredMemories(ctx context.Context) error
	ExtendMemoryTTL(ctx context.Context) error
	GetAgent(ctx context.Context, id string) (Agent, error)
	GetCheckpoint(ctx context.Context, arg GetCheckpointParams) (LgCheckpoint, error)
	GetEvent(ctx context.Context, arg GetEventParams) (Event, error)
	GetLatestCrewAIFlowState(ctx context.Context, arg GetLatestCrewAIFlowStateParams) (CrewaiFlowState, error)
	GetPushNotification(ctx context.Context, arg GetPushNotificationParams) (PushNotification, error)
	GetSession(ctx context.Context, arg GetSessionParams) (Session, error)
	GetTask(ctx context.Context, id string) (Task, error)
	GetTool(ctx context.Context, id string) (Tool, error)
	GetToolServer(ctx context.Context, name string) (Toolserver, error)
	HardDeleteCrewAIMemory(ctx context.Context, arg HardDeleteCrewAIMemoryParams) error
	IncrementMemoryAccessCount(ctx context.Context, dollar_1 []string) error
	InsertEvent(ctx context.Context, arg InsertEventParams) error
	InsertFeedback(ctx context.Context, arg InsertFeedbackParams) (Feedback, error)
	InsertMemory(ctx context.Context, arg InsertMemoryParams) (string, error)
	ListAgentMemories(ctx context.Context, arg ListAgentMemoriesParams) ([]Memory, error)
	ListAgents(ctx context.Context) ([]Agent, error)
	ListCheckpointWrites(ctx context.Context, arg ListCheckpointWritesParams) ([]LgCheckpointWrite, error)
	ListCheckpoints(ctx context.Context, arg ListCheckpointsParams) ([]LgCheckpoint, error)
	ListCheckpointsLimit(ctx context.Context, arg ListCheckpointsLimitParams) ([]LgCheckpoint, error)
	ListEventsByContextID(ctx context.Context, sessionID sql.NullString) ([]Event, error)
	ListEventsByContextIDLimit(ctx context.Context, arg ListEventsByContextIDLimitParams) ([]Event, error)
	ListEventsForSessionAsc(ctx context.Context, arg ListEventsForSessionAscParams) ([]Event, error)
	ListEventsForSessionAscLimit(ctx context.Context, arg ListEventsForSessionAscLimitParams) ([]Event, error)
	ListEventsForSessionDesc(ctx context.Context, arg ListEventsForSessionDescParams) ([]Event, error)
	ListEventsForSessionDescLimit(ctx context.Context, arg ListEventsForSessionDescLimitParams) ([]Event, error)
	ListFeedback(ctx context.Context, userID string) ([]Feedback, error)
	ListPushNotifications(ctx context.Context, taskID string) ([]PushNotification, error)
	ListSessions(ctx context.Context, userID string) ([]Session, error)
	ListSessionsForAgent(ctx context.Context, arg ListSessionsForAgentParams) ([]Session, error)
	ListTasksForSession(ctx context.Context, sessionID sql.NullString) ([]Task, error)
	ListToolServers(ctx context.Context) ([]Toolserver, error)
	ListTools(ctx context.Context) ([]Tool, error)
	ListToolsForServer(ctx context.Context, arg ListToolsForServerParams) ([]Tool, error)
	SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryParams) ([]SearchAgentMemoryRow, error)
	SearchCrewAIMemoryByTask(ctx context.Context, arg SearchCrewAIMemoryByTaskParams) ([]CrewaiAgentMemory, error)
	SearchCrewAIMemoryByTaskLimit(ctx context.Context, arg SearchCrewAIMemoryByTaskLimitParams) ([]CrewaiAgentMemory, error)
	SoftDeleteAgent(ctx context.Context, id string) error
	SoftDeleteCheckpointWrites(ctx context.Context, arg SoftDeleteCheckpointWritesParams) error
	SoftDeleteCheckpoints(ctx context.Context, arg SoftDeleteCheckpointsParams) error
	SoftDeleteEvent(ctx context.Context, id string) error
	SoftDeletePushNotification(ctx context.Context, taskID string) error
	SoftDeleteSession(ctx context.Context, arg SoftDeleteSessionParams) error
	SoftDeleteTask(ctx context.Context, id string) error
	SoftDeleteToolServer(ctx context.Context, arg SoftDeleteToolServerParams) error
	SoftDeleteToolsForServer(ctx context.Context, arg SoftDeleteToolsForServerParams) error
	TaskExists(ctx context.Context, id string) (bool, error)
	UpsertAgent(ctx context.Context, arg UpsertAgentParams) error
	UpsertCheckpoint(ctx context.Context, arg UpsertCheckpointParams) error
	UpsertCheckpointWrite(ctx context.Context, arg UpsertCheckpointWriteParams) error
	UpsertCrewAIFlowState(ctx context.Context, arg UpsertCrewAIFlowStateParams) error
	UpsertCrewAIMemory(ctx context.Context, arg UpsertCrewAIMemoryParams) error
	UpsertPushNotification(ctx context.Context, arg UpsertPushNotificationParams) error
	UpsertSession(ctx context.Context, arg UpsertSessionParams) error
	UpsertTask(ctx context.Context, arg UpsertTaskParams) error
	UpsertTool(ctx context.Context, arg UpsertToolParams) error
	UpsertToolServer(ctx context.Context, arg UpsertToolServerParams) (Toolserver, error)
}

var _ Querier = (*Queries)(nil)
diff --git a/go/core/internal/database/gen/sessions.sql.go b/go/core/internal/database/gen/sessions.sql.go
new file mode 100644
index 000000000..b7f58b70a
--- /dev/null
+++ b/go/core/internal/database/gen/sessions.sql.go
@@ -0,0 +1,164 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: sessions.sql

package dbgen

import (
	"context"
	"database/sql"
)

const getSession = `-- name: GetSession :one
SELECT id, user_id, name, created_at, updated_at, deleted_at, agent_id, source FROM session
WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL
LIMIT 1
`

type GetSessionParams struct {
	ID     string
	UserID string
}

// GetSession fetches one live session by its (id, user_id) key.
func (q *Queries) GetSession(ctx context.Context, arg GetSessionParams) (Session, error) {
	row := q.db.QueryRowContext(ctx, getSession, arg.ID, arg.UserID)
	var i Session
	err := row.Scan(
		&i.ID,
		&i.UserID,
		&i.Name,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.AgentID,
		&i.Source,
	)
	return i, err
}

const listSessions = `-- name: ListSessions :many
SELECT id, user_id, name, created_at, updated_at, deleted_at, agent_id, source FROM session
WHERE user_id = $1 AND deleted_at IS NULL
ORDER BY created_at ASC
`

// ListSessions returns all live sessions for a user, oldest first.
func (q *Queries) ListSessions(ctx context.Context, userID string) ([]Session, error) {
	rows, err := q.db.QueryContext(ctx, listSessions, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Session
	for rows.Next() {
		var i Session
		if err := rows.Scan(
			&i.ID,
			&i.UserID,
			&i.Name,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.AgentID,
			&i.Source,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listSessionsForAgent = `-- name: ListSessionsForAgent :many
SELECT id, user_id, name, created_at, updated_at, deleted_at, agent_id, source FROM session
WHERE agent_id = $1 AND user_id = $2 AND deleted_at IS NULL
  AND (source IS NULL OR source != 'agent')
ORDER BY created_at ASC
`

type ListSessionsForAgentParams struct {
	AgentID sql.NullString
	UserID  string
}

// ListSessionsForAgent returns a user's live sessions for one agent,
// excluding rows whose source = 'agent' (NULL source is included).
func (q *Queries) ListSessionsForAgent(ctx context.Context, arg ListSessionsForAgentParams) ([]Session, error) {
	rows, err := q.db.QueryContext(ctx, listSessionsForAgent, arg.AgentID, arg.UserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Session
	for rows.Next() {
		var i Session
		if err := rows.Scan(
			&i.ID,
			&i.UserID,
			&i.Name,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.AgentID,
			&i.Source,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const softDeleteSession = `-- name: SoftDeleteSession :exec
UPDATE session SET deleted_at = NOW()
WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL
`

type SoftDeleteSessionParams struct {
	ID     string
	UserID string
}

// SoftDeleteSession stamps deleted_at on one session.
func (q *Queries) SoftDeleteSession(ctx context.Context, arg SoftDeleteSessionParams) error {
	_, err := q.db.ExecContext(ctx, softDeleteSession, arg.ID, arg.UserID)
	return err
}

const upsertSession = `-- name: UpsertSession :exec
INSERT INTO session (id, user_id, name, agent_id, source, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
ON CONFLICT (id, user_id) DO UPDATE SET
    name = EXCLUDED.name,
    agent_id = EXCLUDED.agent_id,
    source = EXCLUDED.source,
    updated_at = NOW()
`

type UpsertSessionParams struct {
	ID      string
	UserID  string
	Name    sql.NullString
	AgentID sql.NullString
	Source  sql.NullString
}

// UpsertSession inserts or updates a session keyed on (id, user_id).
func (q *Queries) UpsertSession(ctx context.Context, arg UpsertSessionParams) error {
	_, err := q.db.ExecContext(ctx, upsertSession,
		arg.ID,
		arg.UserID,
		arg.Name,
		arg.AgentID,
		arg.Source,
	)
	return err
}
diff --git a/go/core/internal/database/gen/tasks.sql.go b/go/core/internal/database/gen/tasks.sql.go
new file mode 100644
index 000000000..c3e59638d
--- /dev/null
+++ b/go/core/internal/database/gen/tasks.sql.go
@@ -0,0 +1,109 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: tasks.sql

package dbgen

import (
	"context"
	"database/sql"
)

const getTask = `-- name: GetTask :one
SELECT id, created_at, updated_at, deleted_at, data, session_id FROM task
WHERE id = $1 AND deleted_at IS NULL
LIMIT 1
`

// GetTask fetches one live task by id.
func (q *Queries) GetTask(ctx context.Context, id string) (Task, error) {
	row := q.db.QueryRowContext(ctx, getTask, id)
	var i Task
	err := row.Scan(
		&i.ID,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.Data,
		&i.SessionID,
	)
	return i, err
}

const listTasksForSession = `-- name: ListTasksForSession :many
SELECT id, created_at, updated_at, deleted_at, data, session_id FROM task
WHERE session_id = $1 AND deleted_at IS NULL
ORDER BY created_at ASC
`

// ListTasksForSession returns all live tasks for one session, oldest first.
func (q *Queries) ListTasksForSession(ctx context.Context, sessionID sql.NullString) ([]Task, error) {
	rows, err := q.db.QueryContext(ctx, listTasksForSession, sessionID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Task
	for rows.Next() {
		var i Task
		if err := rows.Scan(
			&i.ID,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Data,
			&i.SessionID,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const softDeleteTask = `-- name: SoftDeleteTask :exec
UPDATE task SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL
`

// SoftDeleteTask stamps deleted_at on one task.
func (q *Queries) SoftDeleteTask(ctx context.Context, id string) error {
	_, err := q.db.ExecContext(ctx, softDeleteTask, id)
	return err
}

const taskExists = `-- name: TaskExists :one
SELECT EXISTS (
    SELECT 1 FROM task WHERE id = $1 AND deleted_at IS NULL
) AS exists
`

// TaskExists reports whether a live task with the given id exists.
func (q *Queries) TaskExists(ctx context.Context, id string) (bool, error) {
	row := q.db.QueryRowContext(ctx, taskExists, id)
	var exists bool
	err := row.Scan(&exists)
	return exists, err
}

const upsertTask = `-- name: UpsertTask :exec
INSERT INTO task (id, data, session_id, created_at, updated_at)
VALUES ($1, $2, $3, NOW(), NOW())
ON CONFLICT (id) DO UPDATE SET
    data = EXCLUDED.data,
    session_id = EXCLUDED.session_id,
    updated_at = NOW()
`

type UpsertTaskParams struct {
	ID        string
	Data      string
	SessionID sql.NullString
}

// UpsertTask inserts or, on id conflict, refreshes payload and session link.
func (q *Queries) UpsertTask(ctx context.Context, arg UpsertTaskParams) error {
	_, err := q.db.ExecContext(ctx, upsertTask, arg.ID, arg.Data, arg.SessionID)
	return err
}
diff --git a/go/core/internal/database/gen/tools.sql.go b/go/core/internal/database/gen/tools.sql.go
new file mode 100644
index 000000000..67be7412e
--- /dev/null
+++ b/go/core/internal/database/gen/tools.sql.go
@@ -0,0 +1,263 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: tools.sql

package dbgen

import (
	"context"
	"database/sql"
)

const getTool = `-- name: GetTool :one
SELECT id, server_name, group_kind, created_at, updated_at, deleted_at, description FROM tool
WHERE id = $1 AND deleted_at IS NULL
LIMIT 1
`

// GetTool fetches one live tool by id.
func (q *Queries) GetTool(ctx context.Context, id string) (Tool, error) {
	row := q.db.QueryRowContext(ctx, getTool, id)
	var i Tool
	err := row.Scan(
		&i.ID,
		&i.ServerName,
		&i.GroupKind,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.Description,
	)
	return i, err
}

const getToolServer = `-- name: GetToolServer :one
SELECT name, group_kind, created_at, updated_at, deleted_at, description, last_connected FROM toolserver
WHERE name = $1 AND deleted_at IS NULL
LIMIT 1
`

// GetToolServer fetches one live tool server by name.
func (q *Queries) GetToolServer(ctx context.Context, name string) (Toolserver, error) {
	row := q.db.QueryRowContext(ctx, getToolServer, name)
	var i Toolserver
	err := row.Scan(
		&i.Name,
		&i.GroupKind,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.DeletedAt,
		&i.Description,
		&i.LastConnected,
	)
	return i, err
}

const listToolServers = `-- name: ListToolServers :many
SELECT name, group_kind, created_at, updated_at, deleted_at, description, last_connected FROM toolserver
WHERE deleted_at IS NULL
ORDER BY created_at ASC
`

// ListToolServers returns all live tool servers, oldest first.
func (q *Queries) ListToolServers(ctx context.Context) ([]Toolserver, error) {
	rows, err := q.db.QueryContext(ctx, listToolServers)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Toolserver
	for rows.Next() {
		var i Toolserver
		if err := rows.Scan(
			&i.Name,
			&i.GroupKind,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Description,
			&i.LastConnected,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listTools = `-- name: ListTools :many
SELECT id, server_name, group_kind, created_at, updated_at, deleted_at, description FROM tool
WHERE deleted_at IS NULL
ORDER BY created_at ASC
`

// ListTools returns every live tool across all servers, oldest first.
func (q *Queries) ListTools(ctx context.Context) ([]Tool, error) {
	rows, err := q.db.QueryContext(ctx, listTools)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Tool
	for rows.Next() {
		var i Tool
		if err := rows.Scan(
			&i.ID,
			&i.ServerName,
			&i.GroupKind,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Description,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listToolsForServer = `-- name: ListToolsForServer :many
SELECT id, server_name, group_kind, created_at, updated_at, deleted_at, description FROM tool
WHERE server_name = $1 AND group_kind = $2 AND deleted_at IS NULL
ORDER BY created_at ASC
`

type ListToolsForServerParams struct {
	ServerName string
	GroupKind  string
}

// ListToolsForServer returns the live tools of one (server, group/kind) pair.
func (q *Queries) ListToolsForServer(ctx context.Context, arg ListToolsForServerParams) ([]Tool, error) {
	rows, err := q.db.QueryContext(ctx, listToolsForServer, arg.ServerName, arg.GroupKind)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Tool
	for rows.Next() {
		var i Tool
		if err := rows.Scan(
			&i.ID,
			&i.ServerName,
			&i.GroupKind,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.DeletedAt,
			&i.Description,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const softDeleteToolServer = `-- name: SoftDeleteToolServer :exec
UPDATE toolserver SET deleted_at = NOW()
WHERE name = $1 AND group_kind = $2 AND deleted_at IS NULL
`

type SoftDeleteToolServerParams struct {
	Name      string
	GroupKind string
}

// SoftDeleteToolServer stamps deleted_at on one tool server.
func (q *Queries) SoftDeleteToolServer(ctx context.Context, arg SoftDeleteToolServerParams) error {
	_, err := q.db.ExecContext(ctx, softDeleteToolServer, arg.Name, arg.GroupKind)
	return err
}

const softDeleteToolsForServer = `-- name: SoftDeleteToolsForServer :exec
UPDATE tool SET deleted_at = NOW()
WHERE server_name = $1 AND group_kind = $2 AND deleted_at IS NULL
`

type SoftDeleteToolsForServerParams struct {
	ServerName string
	GroupKind  string
}

// SoftDeleteToolsForServer soft-deletes every live tool of one server.
func (q *Queries) SoftDeleteToolsForServer(ctx context.Context, arg SoftDeleteToolsForServerParams) error {
	_, err := q.db.ExecContext(ctx, softDeleteToolsForServer, arg.ServerName, arg.GroupKind)
	return err
}

const upsertTool = `-- name: UpsertTool :exec
INSERT INTO tool (id, server_name, group_kind, description, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (id, server_name, group_kind) DO UPDATE SET
    description = EXCLUDED.description,
    updated_at = NOW(),
    deleted_at = NULL
`

type UpsertToolParams struct {
	ID          string
	ServerName  string
	GroupKind   string
	Description sql.NullString
}

// UpsertTool inserts or updates a tool; on conflict it also clears
// deleted_at, resurrecting a previously soft-deleted tool.
func (q *Queries) UpsertTool(ctx context.Context, arg UpsertToolParams) error {
	_, err
:= q.db.ExecContext(ctx, upsertTool, + arg.ID, + arg.ServerName, + arg.GroupKind, + arg.Description, + ) + return err +} + +const upsertToolServer = `-- name: UpsertToolServer :one +INSERT INTO toolserver (name, group_kind, description, last_connected, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +ON CONFLICT (name, group_kind) DO UPDATE SET + description = EXCLUDED.description, + last_connected = EXCLUDED.last_connected, + updated_at = NOW(), + deleted_at = NULL +RETURNING name, group_kind, created_at, updated_at, deleted_at, description, last_connected +` + +type UpsertToolServerParams struct { + Name string + GroupKind string + Description sql.NullString + LastConnected sql.NullTime +} + +func (q *Queries) UpsertToolServer(ctx context.Context, arg UpsertToolServerParams) (Toolserver, error) { + row := q.db.QueryRowContext(ctx, upsertToolServer, + arg.Name, + arg.GroupKind, + arg.Description, + arg.LastConnected, + ) + var i Toolserver + err := row.Scan( + &i.Name, + &i.GroupKind, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Description, + &i.LastConnected, + ) + return i, err +} diff --git a/go/core/internal/database/manager.go b/go/core/internal/database/manager.go deleted file mode 100644 index 3fbdf5176..000000000 --- a/go/core/internal/database/manager.go +++ /dev/null @@ -1,207 +0,0 @@ -package database - -import ( - "context" - "fmt" - "os" - "strings" - "sync" - "time" - - dbpkg "github.com/kagent-dev/kagent/go/api/database" - "github.com/kagent-dev/kagent/go/core/pkg/env" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/logger" - ctrl "sigs.k8s.io/controller-runtime" -) - -var dbLog = ctrl.Log.WithName("database") - -// Manager handles database connection and initialization -type Manager struct { - db *gorm.DB - config *Config - initLock sync.Mutex -} - -type PostgresConfig struct { - URL string - URLFile string - VectorEnabled bool -} - -type Config struct { - PostgresConfig *PostgresConfig -} - -const ( - 
defaultMaxTimeout = 120 * time.Second - defaultInitialDelay = 500 * time.Millisecond - defaultMaxDelay = 5 * time.Second -) - -// retryDBConnection attempts to open a GORM connection, retrying with backoff until -// the context deadline or defaultMaxTimeout is reached. -func retryDBConnection(ctx context.Context, url string, cfg *gorm.Config) (*gorm.DB, error) { - ctx, cancel := context.WithTimeout(ctx, defaultMaxTimeout) - defer cancel() - - start := time.Now() - delay := defaultInitialDelay - for attempt := 1; ; attempt++ { - db, err := gorm.Open(postgres.Open(url), cfg) - if err == nil { - return db, nil - } - dbLog.Info("database not ready, retrying", "attempt", attempt, "delay", delay, "error", err.Error()) - select { - case <-ctx.Done(): - return nil, fmt.Errorf("database not ready after %s: %w", time.Since(start).Round(time.Second), ctx.Err()) - case <-time.After(delay): - } - delay *= 2 - if delay > defaultMaxDelay { - delay = defaultMaxDelay - } - } -} - -// NewManager creates a new database manager -func NewManager(ctx context.Context, config *Config) (*Manager, error) { - logLevel := logger.Silent - switch env.GormLogLevel.Get() { - case "error": - logLevel = logger.Error - case "warn": - logLevel = logger.Warn - case "info": - logLevel = logger.Info - case "silent": - logLevel = logger.Silent - } - - url := config.PostgresConfig.URL - if config.PostgresConfig.URLFile != "" { - resolved, err := resolveURLFile(config.PostgresConfig.URLFile) - if err != nil { - return nil, fmt.Errorf("failed to resolve postgres URL from file: %w", err) - } - url = resolved - } - - gormCfg := &gorm.Config{ - Logger: logger.Default.LogMode(logLevel), - TranslateError: true, - } - - db, err := retryDBConnection(ctx, url, gormCfg) - if err != nil { - return nil, err - } - - return &Manager{db: db, config: config}, nil -} - -// Initialize sets up the database tables -func (m *Manager) Initialize() error { - if m.config.PostgresConfig.VectorEnabled { - if err := 
m.db.Exec("CREATE EXTENSION IF NOT EXISTS vector").Error; err != nil { - return fmt.Errorf("failed to create vector extension: %w", err) - } - } - - // AutoMigrate all models - err := m.db.AutoMigrate( - &dbpkg.Agent{}, - &dbpkg.Session{}, - &dbpkg.Task{}, - &dbpkg.Event{}, - &dbpkg.PushNotification{}, - &dbpkg.Feedback{}, - &dbpkg.Tool{}, - &dbpkg.ToolServer{}, - &dbpkg.LangGraphCheckpoint{}, - &dbpkg.LangGraphCheckpointWrite{}, - &dbpkg.CrewAIAgentMemory{}, - &dbpkg.CrewAIFlowState{}, - ) - if err != nil { - return fmt.Errorf("failed to migrate database: %w", err) - } - - if m.config.PostgresConfig.VectorEnabled { - if err := m.db.AutoMigrate(&dbpkg.Memory{}); err != nil { - return fmt.Errorf("failed to migrate memory table: %w", err) - } - - // Manually create the HNSW index with the correct operator class — - // GORM doesn't support adding "op class" in struct tags for Postgres vectors. - indexQuery := `CREATE INDEX IF NOT EXISTS idx_memory_embedding_hnsw ON memory USING hnsw (embedding vector_cosine_ops)` - if err := m.db.Exec(indexQuery).Error; err != nil { - return fmt.Errorf("failed to create hnsw index: %w", err) - } - } - - return nil -} - -// Reset drops all tables and optionally recreates them -func (m *Manager) Reset(recreateTables bool) error { - if !m.initLock.TryLock() { - return fmt.Errorf("database reset already in progress") - } - defer m.initLock.Unlock() - - err := m.db.Migrator().DropTable( - &dbpkg.Agent{}, - &dbpkg.Session{}, - &dbpkg.Task{}, - &dbpkg.Event{}, - &dbpkg.PushNotification{}, - &dbpkg.Feedback{}, - &dbpkg.Tool{}, - &dbpkg.ToolServer{}, - &dbpkg.LangGraphCheckpoint{}, - &dbpkg.LangGraphCheckpointWrite{}, - &dbpkg.CrewAIAgentMemory{}, - &dbpkg.CrewAIFlowState{}, - &dbpkg.Memory{}, - ) - if err != nil { - return fmt.Errorf("failed to drop tables: %w", err) - } - - if recreateTables { - return m.Initialize() - } - - return nil -} - -// resolveURLFile reads a database connection URL from a file and returns the -// trimmed contents. 
Returns an error if the file cannot be read or is empty. -func resolveURLFile(path string) (string, error) { - content, err := os.ReadFile(path) - if err != nil { - return "", fmt.Errorf("reading URL file: %w", err) - } - url := strings.TrimSpace(string(content)) - if url == "" { - return "", fmt.Errorf("URL file %s is empty or contains only whitespace", path) - } - return url, nil -} - -// Close closes the database connection -func (m *Manager) Close() error { - if m.db == nil { - return nil - } - - sqlDB, err := m.db.DB() - if err != nil { - return err - } - return sqlDB.Close() -} diff --git a/go/core/internal/database/queries/agents.sql b/go/core/internal/database/queries/agents.sql new file mode 100644 index 000000000..135fbb1ec --- /dev/null +++ b/go/core/internal/database/queries/agents.sql @@ -0,0 +1,21 @@ +-- name: GetAgent :one +SELECT * FROM agent +WHERE id = $1 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListAgents :many +SELECT * FROM agent +WHERE deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: UpsertAgent :exec +INSERT INTO agent (id, type, config, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (id) DO UPDATE SET + type = EXCLUDED.type, + config = EXCLUDED.config, + updated_at = NOW(), + deleted_at = NULL; + +-- name: SoftDeleteAgent :exec +UPDATE agent SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL; diff --git a/go/core/internal/database/queries/crewai.sql b/go/core/internal/database/queries/crewai.sql new file mode 100644 index 000000000..b9b2783a7 --- /dev/null +++ b/go/core/internal/database/queries/crewai.sql @@ -0,0 +1,38 @@ +-- name: UpsertCrewAIMemory :exec +INSERT INTO crewai_agent_memory (user_id, thread_id, memory_data, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (user_id, thread_id) DO UPDATE SET + memory_data = EXCLUDED.memory_data, + updated_at = NOW(), + deleted_at = NULL; + +-- name: SearchCrewAIMemoryByTask :many +SELECT * FROM crewai_agent_memory +WHERE 
user_id = $1 AND thread_id = $2 AND deleted_at IS NULL + AND (memory_data ILIKE $3 OR (memory_data::jsonb)->>'task_description' ILIKE $3) +ORDER BY created_at DESC, (memory_data::jsonb)->>'score' ASC; + +-- name: SearchCrewAIMemoryByTaskLimit :many +SELECT * FROM crewai_agent_memory +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL + AND (memory_data ILIKE $3 OR (memory_data::jsonb)->>'task_description' ILIKE $3) +ORDER BY created_at DESC, (memory_data::jsonb)->>'score' ASC +LIMIT $4; + +-- name: HardDeleteCrewAIMemory :exec +DELETE FROM crewai_agent_memory +WHERE user_id = $1 AND thread_id = $2; + +-- name: UpsertCrewAIFlowState :exec +INSERT INTO crewai_flow_state (user_id, thread_id, method_name, state_data, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +ON CONFLICT (user_id, thread_id, method_name) DO UPDATE SET + state_data = EXCLUDED.state_data, + updated_at = NOW(), + deleted_at = NULL; + +-- name: GetLatestCrewAIFlowState :one +SELECT * FROM crewai_flow_state +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL +ORDER BY created_at DESC +LIMIT 1; diff --git a/go/core/internal/database/queries/events.sql b/go/core/internal/database/queries/events.sql new file mode 100644 index 000000000..ae6fd88dc --- /dev/null +++ b/go/core/internal/database/queries/events.sql @@ -0,0 +1,49 @@ +-- name: InsertEvent :exec +INSERT INTO event (id, user_id, session_id, data, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()); + +-- name: GetEvent :one +SELECT * FROM event +WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListEventsForSessionAsc :many +SELECT * FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at ASC; + +-- name: ListEventsForSessionDesc :many +SELECT * FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY 
created_at DESC; + +-- name: ListEventsForSessionAscLimit :many +SELECT * FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at ASC +LIMIT $4; + +-- name: ListEventsForSessionDescLimit :many +SELECT * FROM event +WHERE session_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND ($3::timestamptz IS NULL OR created_at > $3) +ORDER BY created_at DESC +LIMIT $4; + +-- name: ListEventsByContextID :many +SELECT * FROM event +WHERE session_id = $1 AND deleted_at IS NULL +ORDER BY created_at DESC; + +-- name: ListEventsByContextIDLimit :many +SELECT * FROM event +WHERE session_id = $1 AND deleted_at IS NULL +ORDER BY created_at DESC +LIMIT $2; + +-- name: SoftDeleteEvent :exec +UPDATE event SET deleted_at = NOW() +WHERE id = $1 AND deleted_at IS NULL; diff --git a/go/core/internal/database/queries/feedback.sql b/go/core/internal/database/queries/feedback.sql new file mode 100644 index 000000000..e5f9a48b2 --- /dev/null +++ b/go/core/internal/database/queries/feedback.sql @@ -0,0 +1,9 @@ +-- name: InsertFeedback :one +INSERT INTO feedback (user_id, message_id, is_positive, feedback_text, issue_type, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) +RETURNING *; + +-- name: ListFeedback :many +SELECT * FROM feedback +WHERE user_id = $1 AND deleted_at IS NULL +ORDER BY created_at ASC; diff --git a/go/core/internal/database/queries/langgraph.sql b/go/core/internal/database/queries/langgraph.sql new file mode 100644 index 000000000..bd20481cf --- /dev/null +++ b/go/core/internal/database/queries/langgraph.sql @@ -0,0 +1,58 @@ +-- name: UpsertCheckpoint :exec +INSERT INTO lg_checkpoint ( + user_id, thread_id, checkpoint_ns, checkpoint_id, + parent_checkpoint_id, metadata, checkpoint, checkpoint_type, version, + created_at, updated_at +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW()) +ON CONFLICT (user_id, thread_id, checkpoint_ns, checkpoint_id) DO UPDATE 
SET + parent_checkpoint_id = EXCLUDED.parent_checkpoint_id, + metadata = EXCLUDED.metadata, + checkpoint = EXCLUDED.checkpoint, + checkpoint_type = EXCLUDED.checkpoint_type, + version = EXCLUDED.version, + updated_at = NOW(); + +-- name: ListCheckpoints :many +SELECT * FROM lg_checkpoint +WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3 + AND deleted_at IS NULL +ORDER BY checkpoint_id DESC; + +-- name: ListCheckpointsLimit :many +SELECT * FROM lg_checkpoint +WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3 + AND deleted_at IS NULL +ORDER BY checkpoint_id DESC +LIMIT $4; + +-- name: GetCheckpoint :one +SELECT * FROM lg_checkpoint +WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3 + AND checkpoint_id = $4 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListCheckpointWrites :many +SELECT * FROM lg_checkpoint_write +WHERE user_id = $1 AND thread_id = $2 AND checkpoint_ns = $3 + AND checkpoint_id = $4 AND deleted_at IS NULL +ORDER BY task_id, write_idx; + +-- name: UpsertCheckpointWrite :exec +INSERT INTO lg_checkpoint_write ( + user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx, + value, value_type, channel, task_id, created_at, updated_at +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW()) +ON CONFLICT (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) DO UPDATE SET + value = EXCLUDED.value, + value_type = EXCLUDED.value_type, + channel = EXCLUDED.channel, + task_id = EXCLUDED.task_id, + updated_at = NOW(); + +-- name: SoftDeleteCheckpoints :exec +UPDATE lg_checkpoint SET deleted_at = NOW() +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL; + +-- name: SoftDeleteCheckpointWrites :exec +UPDATE lg_checkpoint_write SET deleted_at = NOW() +WHERE user_id = $1 AND thread_id = $2 AND deleted_at IS NULL; diff --git a/go/core/internal/database/queries/memory.sql b/go/core/internal/database/queries/memory.sql new file mode 100644 index 000000000..4cb88edb6 --- /dev/null +++ 
b/go/core/internal/database/queries/memory.sql @@ -0,0 +1,32 @@ +-- name: InsertMemory :one +INSERT INTO memory (agent_name, user_id, content, embedding, metadata, created_at, expires_at, access_count) +VALUES ($1, $2, $3, $4, $5, NOW(), $6, $7) +RETURNING id; + +-- name: SearchAgentMemory :many +SELECT *, COALESCE(1 - (embedding <=> $1), 0) AS score +FROM memory +WHERE agent_name = $2 AND user_id = $3 +ORDER BY embedding <=> $1 ASC +LIMIT $4; + +-- name: IncrementMemoryAccessCount :exec +UPDATE memory SET access_count = access_count + 1 +WHERE id = ANY($1::text[]); + +-- name: ListAgentMemories :many +SELECT * FROM memory +WHERE (agent_name = $1 OR agent_name = $2) AND user_id = $3 +ORDER BY access_count DESC; + +-- name: DeleteAgentMemory :exec +DELETE FROM memory WHERE agent_name = $1 AND user_id = $2; + +-- name: ExtendMemoryTTL :exec +UPDATE memory +SET expires_at = NOW() + INTERVAL '15 days', access_count = 0 +WHERE expires_at < NOW() AND access_count >= 10; + +-- name: DeleteExpiredMemories :exec +DELETE FROM memory +WHERE expires_at < NOW() AND access_count < 10; diff --git a/go/core/internal/database/queries/push_notifications.sql b/go/core/internal/database/queries/push_notifications.sql new file mode 100644 index 000000000..ccc7553f6 --- /dev/null +++ b/go/core/internal/database/queries/push_notifications.sql @@ -0,0 +1,20 @@ +-- name: GetPushNotification :one +SELECT * FROM push_notification +WHERE task_id = $1 AND id = $2 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListPushNotifications :many +SELECT * FROM push_notification +WHERE task_id = $1 AND deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: UpsertPushNotification :exec +INSERT INTO push_notification (id, task_id, data, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (id) DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW(); + +-- name: SoftDeletePushNotification :exec +UPDATE push_notification SET deleted_at = NOW() +WHERE task_id = $1 AND deleted_at IS 
NULL; diff --git a/go/core/internal/database/queries/sessions.sql b/go/core/internal/database/queries/sessions.sql new file mode 100644 index 000000000..9dcf9bbbb --- /dev/null +++ b/go/core/internal/database/queries/sessions.sql @@ -0,0 +1,28 @@ +-- name: GetSession :one +SELECT * FROM session +WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListSessions :many +SELECT * FROM session +WHERE user_id = $1 AND deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: ListSessionsForAgent :many +SELECT * FROM session +WHERE agent_id = $1 AND user_id = $2 AND deleted_at IS NULL + AND (source IS NULL OR source != 'agent') +ORDER BY created_at ASC; + +-- name: UpsertSession :exec +INSERT INTO session (id, user_id, name, agent_id, source, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) +ON CONFLICT (id, user_id) DO UPDATE SET + name = EXCLUDED.name, + agent_id = EXCLUDED.agent_id, + source = EXCLUDED.source, + updated_at = NOW(); + +-- name: SoftDeleteSession :exec +UPDATE session SET deleted_at = NOW() +WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL; diff --git a/go/core/internal/database/queries/tasks.sql b/go/core/internal/database/queries/tasks.sql new file mode 100644 index 000000000..ae72627c1 --- /dev/null +++ b/go/core/internal/database/queries/tasks.sql @@ -0,0 +1,25 @@ +-- name: GetTask :one +SELECT * FROM task +WHERE id = $1 AND deleted_at IS NULL +LIMIT 1; + +-- name: TaskExists :one +SELECT EXISTS ( + SELECT 1 FROM task WHERE id = $1 AND deleted_at IS NULL +) AS exists; + +-- name: ListTasksForSession :many +SELECT * FROM task +WHERE session_id = $1 AND deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: UpsertTask :exec +INSERT INTO task (id, data, session_id, created_at, updated_at) +VALUES ($1, $2, $3, NOW(), NOW()) +ON CONFLICT (id) DO UPDATE SET + data = EXCLUDED.data, + session_id = EXCLUDED.session_id, + updated_at = NOW(); + +-- name: SoftDeleteTask :exec +UPDATE task SET deleted_at = NOW() 
WHERE id = $1 AND deleted_at IS NULL; diff --git a/go/core/internal/database/queries/tools.sql b/go/core/internal/database/queries/tools.sql new file mode 100644 index 000000000..90e119760 --- /dev/null +++ b/go/core/internal/database/queries/tools.sql @@ -0,0 +1,50 @@ +-- name: GetTool :one +SELECT * FROM tool +WHERE id = $1 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListTools :many +SELECT * FROM tool +WHERE deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: ListToolsForServer :many +SELECT * FROM tool +WHERE server_name = $1 AND group_kind = $2 AND deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: UpsertTool :exec +INSERT INTO tool (id, server_name, group_kind, description, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +ON CONFLICT (id, server_name, group_kind) DO UPDATE SET + description = EXCLUDED.description, + updated_at = NOW(), + deleted_at = NULL; + +-- name: SoftDeleteToolsForServer :exec +UPDATE tool SET deleted_at = NOW() +WHERE server_name = $1 AND group_kind = $2 AND deleted_at IS NULL; + +-- name: GetToolServer :one +SELECT * FROM toolserver +WHERE name = $1 AND deleted_at IS NULL +LIMIT 1; + +-- name: ListToolServers :many +SELECT * FROM toolserver +WHERE deleted_at IS NULL +ORDER BY created_at ASC; + +-- name: UpsertToolServer :one +INSERT INTO toolserver (name, group_kind, description, last_connected, created_at, updated_at) +VALUES ($1, $2, $3, $4, NOW(), NOW()) +ON CONFLICT (name, group_kind) DO UPDATE SET + description = EXCLUDED.description, + last_connected = EXCLUDED.last_connected, + updated_at = NOW(), + deleted_at = NULL +RETURNING *; + +-- name: SoftDeleteToolServer :exec +UPDATE toolserver SET deleted_at = NOW() +WHERE name = $1 AND group_kind = $2 AND deleted_at IS NULL; diff --git a/go/core/internal/database/service.go b/go/core/internal/database/service.go deleted file mode 100644 index d96b79366..000000000 --- a/go/core/internal/database/service.go +++ /dev/null @@ -1,88 +0,0 @@ -package database - 
-import ( - "fmt" - "strings" - - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -type Model interface { - TableName() string -} - -type Clause struct { - Key string - Value any -} - -func list[T Model](db *gorm.DB, clauses ...Clause) ([]T, error) { - var models []T - query := db - - for _, clause := range clauses { - query = query.Where(fmt.Sprintf("%s = ?", clause.Key), clause.Value) - } - - err := query.Order("created_at ASC").Find(&models).Error - if err != nil { - return nil, fmt.Errorf("failed to list models: %w", err) - } - return models, nil -} - -func get[T Model](db *gorm.DB, clauses ...Clause) (*T, error) { - var model T - query := db - - for _, clause := range clauses { - query = query.Where(fmt.Sprintf("%s = ?", clause.Key), clause.Value) - } - - err := query.First(&model).Error - if err != nil { - return nil, fmt.Errorf("failed to get model: %w", err) - } - return &model, nil -} - -// save performs an upsert operation (INSERT ON CONFLICT DO UPDATE) -// args: -// - db: the database connection -// - model: the model to save -func save[T Model](db *gorm.DB, model *T) error { - if err := db.Clauses(clause.OnConflict{ - UpdateAll: true, - }).Create(model).Error; err != nil { - return fmt.Errorf("failed to upsert model: %w", err) - } - return nil -} - -func delete[T Model](db *gorm.DB, clauses ...Clause) error { - t := new(T) - query := db - - for _, clause := range clauses { - query = query.Where(fmt.Sprintf("%s = ?", clause.Key), clause.Value) - } - - result := query.Delete(t) - if result.Error != nil { - return fmt.Errorf("failed to delete model: %w", result.Error) - } - return nil -} - -// BuildWhereClause is deprecated, use individual Where clauses instead -func BuildWhereClause(clauses ...Clause) string { - var clausesStr strings.Builder - for idx, clause := range clauses { - if idx > 0 { - clausesStr.WriteString(" AND ") - } - fmt.Fprintf(&clausesStr, "%s = %v", clause.Key, clause.Value) - } - return clausesStr.String() -} diff --git 
a/go/core/internal/database/sqlc.yaml b/go/core/internal/database/sqlc.yaml new file mode 100644 index 000000000..bef4f26da --- /dev/null +++ b/go/core/internal/database/sqlc.yaml @@ -0,0 +1,27 @@ +version: "2" +sql: + - schema: ["../../pkg/migrations/core", "../../pkg/migrations/vector"] + queries: "queries" + engine: "postgresql" + gen: + go: + package: "dbgen" + out: "gen" + emit_interface: true + emit_pointers_for_null_types: true + overrides: + # Use domain types for columns that would otherwise be plain strings/bytes. + - column: "agent.config" + go_type: + import: "github.com/kagent-dev/kagent/go/api/adk" + type: "AgentConfig" + pointer: true + - column: "feedback.issue_type" + go_type: + import: "github.com/kagent-dev/kagent/go/api/database" + type: "FeedbackIssueType" + pointer: true + - column: "memory.embedding" + go_type: + import: "github.com/pgvector/pgvector-go" + type: "Vector" diff --git a/go/core/internal/database/testhelpers_test.go b/go/core/internal/database/testhelpers_test.go index e24cb700d..cb7d6c56b 100644 --- a/go/core/internal/database/testhelpers_test.go +++ b/go/core/internal/database/testhelpers_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "flag" "fmt" "os" @@ -10,7 +11,10 @@ import ( "github.com/kagent-dev/kagent/go/core/internal/dbtest" ) -var sharedManager *Manager +var ( + sharedDB *sql.DB + sharedConnStr string +) func TestMain(m *testing.M) { flag.Parse() @@ -23,17 +27,19 @@ func TestMain(m *testing.M) { fmt.Fprintf(os.Stderr, "failed to start postgres container: %v\n", err) os.Exit(1) } + sharedConnStr = connStr + + if err := dbtest.Migrate(connStr, true); err != nil { + fmt.Fprintf(os.Stderr, "failed to migrate test database: %v\n", err) + os.Exit(1) + } - sharedManager, err = NewManager(context.Background(), &Config{ - PostgresConfig: &PostgresConfig{ - URL: connStr, - VectorEnabled: true, - }, - }) + db, err := Connect(context.Background(), &PostgresConfig{URL: connStr}) if err != nil { - 
fmt.Fprintf(os.Stderr, "failed to create shared manager: %v\n", err) + fmt.Fprintf(os.Stderr, "failed to connect to test database: %v\n", err) os.Exit(1) } + sharedDB = db os.Exit(m.Run()) } diff --git a/go/core/internal/database/upgrade_test.go b/go/core/internal/database/upgrade_test.go new file mode 100644 index 000000000..54d0e8ba4 --- /dev/null +++ b/go/core/internal/database/upgrade_test.go @@ -0,0 +1,358 @@ +package database + +// TestUpgradeFromGORM validates that the golang-migrate migrations run cleanly +// against a database that was previously managed by GORM AutoMigrate, and that +// pre-existing data is accessible via the new sqlc client afterwards. +// +// It simulates an existing deployment by: +// 1. Creating the schema that GORM AutoMigrate would have produced (no migration +// tracking tables, no gen_random_uuid() default on memory.id). +// 2. Seeding representative rows, including soft-deleted CrewAI rows that GORM's +// Delete() hook would have left behind. +// 3. Running the new golang-migrate migrations. +// 4. Verifying that all pre-existing data is readable and that new writes work. + +import ( + "context" + "database/sql" + "testing" + "time" + + dbpkg "github.com/kagent-dev/kagent/go/api/database" + "github.com/kagent-dev/kagent/go/core/internal/dbtest" + "github.com/pgvector/pgvector-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// gormSchema reproduces the DDL that GORM AutoMigrate emitted for the kagent +// models. Key differences from the current migrations: +// - No schema_migrations / vector_schema_migrations tracking tables. +// - memory.id has no DEFAULT (GORM relied on the BeforeCreate hook). +// - Indexes may have different names (GORM derives them from the struct name). 
+const gormSchema = ` +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE IF NOT EXISTS agent ( + id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + type TEXT NOT NULL, + config JSON +); +CREATE INDEX IF NOT EXISTS idx_agent_deleted_at ON agent(deleted_at); + +CREATE TABLE IF NOT EXISTS session ( + id TEXT NOT NULL, + user_id TEXT NOT NULL, + name TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + agent_id TEXT, + source TEXT, + PRIMARY KEY (id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_session_name ON session(name); +CREATE INDEX IF NOT EXISTS idx_session_agent_id ON session(agent_id); +CREATE INDEX IF NOT EXISTS idx_session_deleted_at ON session(deleted_at); +CREATE INDEX IF NOT EXISTS idx_session_source ON session(source); + +CREATE TABLE IF NOT EXISTS event ( + id TEXT NOT NULL, + user_id TEXT NOT NULL, + session_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL, + PRIMARY KEY (id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_event_session_id ON event(session_id); +CREATE INDEX IF NOT EXISTS idx_event_deleted_at ON event(deleted_at); + +CREATE TABLE IF NOT EXISTS task ( + id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL, + session_id TEXT +); +CREATE INDEX IF NOT EXISTS idx_task_session_id ON task(session_id); +CREATE INDEX IF NOT EXISTS idx_task_deleted_at ON task(deleted_at); + +CREATE TABLE IF NOT EXISTS push_notification ( + id TEXT PRIMARY KEY, + task_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS 
idx_push_notification_task_id ON push_notification(task_id); +CREATE INDEX IF NOT EXISTS idx_push_notification_deleted_at ON push_notification(deleted_at); + +CREATE TABLE IF NOT EXISTS feedback ( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, + deleted_at TIMESTAMPTZ, + user_id TEXT NOT NULL, + message_id BIGINT, + is_positive BOOLEAN DEFAULT false, + feedback_text TEXT NOT NULL, + issue_type TEXT +); +CREATE INDEX IF NOT EXISTS idx_feedback_deleted_at ON feedback(deleted_at); +CREATE INDEX IF NOT EXISTS idx_feedback_user_id ON feedback(user_id); + +CREATE TABLE IF NOT EXISTS tool ( + id TEXT NOT NULL, + server_name TEXT NOT NULL, + group_kind TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + description TEXT, + PRIMARY KEY (id, server_name, group_kind) +); +CREATE INDEX IF NOT EXISTS idx_tool_deleted_at ON tool(deleted_at); + +CREATE TABLE IF NOT EXISTS toolserver ( + name TEXT NOT NULL, + group_kind TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + description TEXT, + last_connected TIMESTAMPTZ, + PRIMARY KEY (name, group_kind) +); +CREATE INDEX IF NOT EXISTS idx_toolserver_deleted_at ON toolserver(deleted_at); + +CREATE TABLE IF NOT EXISTS lg_checkpoint ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + parent_checkpoint_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + metadata TEXT NOT NULL, + checkpoint TEXT NOT NULL, + checkpoint_type TEXT NOT NULL, + version INTEGER DEFAULT 1, + PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id) +); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); +CREATE INDEX IF NOT EXISTS 
idx_lgcp_list ON lg_checkpoint(created_at); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); + +CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + write_idx INTEGER NOT NULL, + value TEXT NOT NULL, + value_type TEXT NOT NULL, + channel TEXT NOT NULL, + task_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) +); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_write_deleted_at ON lg_checkpoint_write(deleted_at); + +CREATE TABLE IF NOT EXISTS crewai_agent_memory ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + memory_data TEXT NOT NULL, + PRIMARY KEY (user_id, thread_id) +); + +CREATE TABLE IF NOT EXISTS crewai_flow_state ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + method_name TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + state_data TEXT NOT NULL, + PRIMARY KEY (user_id, thread_id, method_name) +); + +CREATE TABLE IF NOT EXISTS memory ( + id TEXT PRIMARY KEY, + agent_name TEXT, + user_id TEXT, + content TEXT, + embedding vector(768), + metadata TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ, + access_count INTEGER DEFAULT 0 +); +CREATE INDEX IF NOT EXISTS idx_memory_embedding_hnsw ON memory USING hnsw (embedding vector_cosine_ops); +` + +func TestUpgradeFromGORM(t *testing.T) { + if testing.Short() { + t.Skip("skipping upgrade test in short mode") + } + + ctx := context.Background() + connStr := dbtest.StartT(ctx, t) + + // ── Step 1: apply the GORM-era schema 
──────────────────────────────────── + rawDB, err := sql.Open("postgres", connStr) + require.NoError(t, err) + t.Cleanup(func() { rawDB.Close() }) + + _, err = rawDB.ExecContext(ctx, gormSchema) + require.NoError(t, err, "GORM schema setup failed") + + // ── Step 2: seed pre-migration data ────────────────────────────────────── + now := time.Now().UTC().Truncate(time.Millisecond) + softDeleted := now.Add(-24 * time.Hour) + + seeds := []struct { + name string + query string + args []any + }{ + { + "agent", + `INSERT INTO agent (id, type, created_at, updated_at) VALUES ($1, $2, $3, $4)`, + []any{"agent-1", "autogen", now, now}, + }, + { + "session", + `INSERT INTO session (id, user_id, name, agent_id, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, + []any{"session-1", "user-1", "test session", "agent-1", now, now}, + }, + { + "event", + `INSERT INTO event (id, user_id, session_id, data, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, + []any{"event-1", "user-1", "session-1", `{"role":"user"}`, now, now}, + }, + { + "task", + `INSERT INTO task (id, session_id, data, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)`, + []any{"task-1", "session-1", `{"id":"task-1"}`, now, now}, + }, + { + "toolserver", + `INSERT INTO toolserver (name, group_kind, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)`, + []any{"server-1", "MCPServer.kagent.dev", "test server", now, now}, + }, + { + "tool", + `INSERT INTO tool (id, server_name, group_kind, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, + []any{"tool-1", "server-1", "MCPServer.kagent.dev", "a tool", now, now}, + }, + // Soft-deleted CrewAI memory row — simulates GORM's Delete() behaviour. + // After migration the upsert must revive it (deleted_at = NULL). 
+ { + "crewai_agent_memory (soft-deleted)", + `INSERT INTO crewai_agent_memory (user_id, thread_id, memory_data, created_at, updated_at, deleted_at) VALUES ($1, $2, $3, $4, $5, $6)`, + []any{"user-1", "thread-1", `{"task_description":"old task"}`, now, now, softDeleted}, + }, + // Soft-deleted CrewAI flow state row — same scenario. + { + "crewai_flow_state (soft-deleted)", + `INSERT INTO crewai_flow_state (user_id, thread_id, method_name, state_data, created_at, updated_at, deleted_at) VALUES ($1, $2, $3, $4, $5, $6, $7)`, + []any{"user-1", "thread-1", "kickoff", `{"status":"done"}`, now, now, softDeleted}, + }, + // Memory row with a manually supplied ID (old GORM BeforeCreate behaviour). + { + "memory", + `INSERT INTO memory (id, agent_name, user_id, content, embedding, metadata, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7)`, + []any{"mem-1", "agent-1", "user-1", "hello world", pgvector.NewVector(make([]float32, 768)), "{}", now}, + }, + } + + for _, s := range seeds { + _, err := rawDB.ExecContext(ctx, s.query, s.args...) 
+ require.NoError(t, err, "seeding %s failed", s.name) + } + + // ── Step 3: run the new migrations ─────────────────────────────────────── + dbtest.MigrateT(t, connStr, true) + + // ── Step 4: connect via the new client ─────────────────────────────────── + db, err := Connect(ctx, &PostgresConfig{URL: connStr}) + require.NoError(t, err) + client := NewClient(db) + + // ── Step 5: verify pre-existing data is readable ───────────────────────── + + agent, err := client.GetAgent(ctx, "agent-1") + require.NoError(t, err) + assert.Equal(t, "agent-1", agent.ID) + assert.Equal(t, "autogen", agent.Type) + + session, err := client.GetSession(ctx, "session-1", "user-1") + require.NoError(t, err) + assert.Equal(t, "session-1", session.ID) + assert.Equal(t, "agent-1", *session.AgentID) + + events, err := client.ListEventsForSession(ctx, "session-1", "user-1", dbpkg.QueryOptions{}) + require.NoError(t, err) + require.Len(t, events, 1) + assert.Equal(t, "event-1", events[0].ID) + + tasks, err := client.ListTasksForSession(ctx, "session-1") + require.NoError(t, err) + require.Len(t, tasks, 1) + + toolServer, err := client.GetToolServer(ctx, "server-1") + require.NoError(t, err) + assert.Equal(t, "server-1", toolServer.Name) + + tools, err := client.ListToolsForServer(ctx, "server-1", "MCPServer.kagent.dev") + require.NoError(t, err) + require.Len(t, tools, 1) + assert.Equal(t, "tool-1", tools[0].ID) + + // ── Step 6: verify soft-deleted CrewAI rows are revived by upsert ──────── + // Before upsert both rows are invisible (deleted_at IS NOT NULL). 
+ results, err := client.SearchCrewAIMemoryByTask(ctx, "user-1", "thread-1", "old task", 10) + require.NoError(t, err) + assert.Empty(t, results, "soft-deleted memory should be invisible before upsert") + + err = client.StoreCrewAIMemory(ctx, &dbpkg.CrewAIAgentMemory{ + UserID: "user-1", + ThreadID: "thread-1", + MemoryData: `{"task_description":"old task"}`, + }) + require.NoError(t, err) + + results, err = client.SearchCrewAIMemoryByTask(ctx, "user-1", "thread-1", "old task", 10) + require.NoError(t, err) + assert.Len(t, results, 1, "upsert should revive soft-deleted memory row") + + // ── Step 7: verify new writes work (gen_random_uuid() default) ─────────── + embedding := pgvector.NewVector(make([]float32, 768)) + mem := &dbpkg.Memory{ + AgentName: "agent-1", + UserID: "user-1", + Content: "new memory content", + Embedding: embedding, + Metadata: "{}", + } + err = client.StoreAgentMemory(ctx, mem) + require.NoError(t, err) + assert.NotEmpty(t, mem.ID, "StoreAgentMemory should populate ID via gen_random_uuid()") + + memories, err := client.ListAgentMemories(ctx, "agent-1", "user-1") + require.NoError(t, err) + assert.Len(t, memories, 2, "should see the seeded memory row and the new one") +} diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index 59c879b33..b6ea42f4f 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -3,10 +3,17 @@ package dbtest import ( "context" + "database/sql" + "errors" "fmt" "testing" "time" + "github.com/golang-migrate/migrate/v4" + migratepg "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/kagent-dev/kagent/go/core/pkg/migrations" + _ "github.com/lib/pq" testcontainers "github.com/testcontainers/testcontainers-go" tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" @@ -58,3 +65,109 @@ func StartT(ctx context.Context, t 
*testing.T) string { return connStr } + +// Migrate runs the embedded OSS migrations against connStr and returns any error. +// If vectorEnabled is true the vector pass is also applied. +// Use MigrateT in tests that have a *testing.T; use Migrate in TestMain where no T is available. +func Migrate(connStr string, vectorEnabled bool) error { + if err := runMigrationDir(connStr, "core", "schema_migrations"); err != nil { + return fmt.Errorf("core migrations: %w", err) + } + if vectorEnabled { + if err := runMigrationDir(connStr, "vector", "vector_schema_migrations"); err != nil { + return fmt.Errorf("vector migrations: %w", err) + } + } + return nil +} + +// MigrateT runs the embedded OSS migrations against connStr and calls t.Fatal on error. +// If vectorEnabled is true the vector pass is also applied. +func MigrateT(t *testing.T, connStr string, vectorEnabled bool) { + t.Helper() + if err := Migrate(connStr, vectorEnabled); err != nil { + t.Fatalf("dbtest.MigrateT: %v", err) + } +} + +// MigrateDown runs the embedded OSS down-migrations against connStr and returns any error. +// If vectorEnabled is true the vector pass is also rolled back first. 
+func MigrateDown(connStr string, vectorEnabled bool) error { + if vectorEnabled { + if err := downMigrationDir(connStr, "vector", "vector_schema_migrations"); err != nil { + return fmt.Errorf("vector down migrations: %w", err) + } + } + return downMigrationDir(connStr, "core", "schema_migrations") +} + +func runMigrationDir(connStr, dir, migrationsTable string) error { + db, err := sql.Open("postgres", connStr) + if err != nil { + return fmt.Errorf("open db for %s: %w", dir, err) + } + + src, err := iofs.New(migrations.FS, dir) + if err != nil { + return fmt.Errorf("load migration files from %s: %w", dir, err) + } + + driver, err := migratepg.WithInstance(db, &migratepg.Config{ + MigrationsTable: migrationsTable, + }) + if err != nil { + return fmt.Errorf("create migration driver for %s: %w", dir, err) + } + + mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) + if err != nil { + return fmt.Errorf("create migrator for %s: %w", dir, err) + } + defer closeMigrate(dir, mg) + + if err := mg.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("run migrations for %s: %w", dir, err) + } + return nil +} + +func downMigrationDir(connStr, dir, migrationsTable string) error { + db, err := sql.Open("postgres", connStr) + if err != nil { + return fmt.Errorf("open db for %s: %w", dir, err) + } + + src, err := iofs.New(migrations.FS, dir) + if err != nil { + return fmt.Errorf("load migration files from %s: %w", dir, err) + } + + driver, err := migratepg.WithInstance(db, &migratepg.Config{ + MigrationsTable: migrationsTable, + }) + if err != nil { + return fmt.Errorf("create migration driver for %s: %w", dir, err) + } + + mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) + if err != nil { + return fmt.Errorf("create migrator for %s: %w", dir, err) + } + defer closeMigrate(dir, mg) + + if err := mg.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("down migrations for %s: %w", dir, err) + 
} + return nil +} + +// closeMigrate closes mg, logging source and database close errors separately. +func closeMigrate(dir string, mg *migrate.Migrate) { + srcErr, dbErr := mg.Close() + if srcErr != nil { + fmt.Printf("warning: closing migration source for %s: %v\n", dir, srcErr) + } + if dbErr != nil { + fmt.Printf("warning: closing migration database for %s: %v\n", dir, dbErr) + } +} diff --git a/go/core/internal/httpserver/handlers/checkpoints.go b/go/core/internal/httpserver/handlers/checkpoints.go index f0b652724..9664b4727 100644 --- a/go/core/internal/httpserver/handlers/checkpoints.go +++ b/go/core/internal/httpserver/handlers/checkpoints.go @@ -110,7 +110,7 @@ func (h *CheckpointsHandler) HandlePutCheckpoint(w ErrorResponseWriter, r *http. ParentCheckpointID: req.ParentCheckpointID, Metadata: req.Metadata, Checkpoint: req.Checkpoint, - Version: req.Version, + Version: int32(req.Version), CheckpointType: req.Type, } // Store checkpoint and writes atomically @@ -171,7 +171,7 @@ func (h *CheckpointsHandler) HandleListCheckpoints(w ErrorResponseWriter, r *htt for j, write := range tuple.Writes { taskID = write.TaskID writes[j] = KagentCheckpointWrite{ - Idx: write.WriteIdx, + Idx: int(write.WriteIdx), Channel: write.Channel, Type: write.ValueType, Value: write.Value, @@ -232,7 +232,7 @@ func (h *CheckpointsHandler) HandlePutWrites(w ErrorResponseWriter, r *http.Requ ThreadID: req.ThreadID, CheckpointNS: req.CheckpointNS, CheckpointID: req.CheckpointID, - WriteIdx: writeReq.Idx, + WriteIdx: int32(writeReq.Idx), Value: writeReq.Value, ValueType: writeReq.Type, Channel: writeReq.Channel, diff --git a/go/core/internal/httpserver/handlers/memory.go b/go/core/internal/httpserver/handlers/memory.go index 532b42e33..607ad0e7e 100644 --- a/go/core/internal/httpserver/handlers/memory.go +++ b/go/core/internal/httpserver/handlers/memory.go @@ -262,7 +262,7 @@ func (h *MemoryHandler) List(w ErrorResponseWriter, r *http.Request) { item := ListMemoryResponse{ ID: m.ID, 
Content: m.Content, - AccessCount: m.AccessCount, + AccessCount: int(m.AccessCount), CreatedAt: m.CreatedAt.Format(time.RFC3339), } if m.ExpiresAt != nil { diff --git a/go/core/internal/httpserver/middleware_error.go b/go/core/internal/httpserver/middleware_error.go index 51dd52a4e..8b1ba74ec 100644 --- a/go/core/internal/httpserver/middleware_error.go +++ b/go/core/internal/httpserver/middleware_error.go @@ -1,13 +1,13 @@ package httpserver import ( + "database/sql" "encoding/json" "errors" "net/http" apierrors "github.com/kagent-dev/kagent/go/core/internal/httpserver/errors" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" - "gorm.io/gorm" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -56,7 +56,7 @@ func (w *errorResponseWriter) RespondWithError(err error) { } } - if !errors.Is(err, gorm.ErrRecordNotFound) { + if !errors.Is(err, sql.ErrNoRows) { log.Error(err, message) } else { log.Info(message) diff --git a/go/core/internal/httpserver/server.go b/go/core/internal/httpserver/server.go index bdbfde190..9be889836 100644 --- a/go/core/internal/httpserver/server.go +++ b/go/core/internal/httpserver/server.go @@ -10,7 +10,6 @@ import ( api "github.com/kagent-dev/kagent/go/api/httpapi" "github.com/kagent-dev/kagent/go/core/internal/a2a" "github.com/kagent-dev/kagent/go/core/internal/controller/reconciler" - "github.com/kagent-dev/kagent/go/core/internal/database" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" "github.com/kagent-dev/kagent/go/core/internal/mcp" common "github.com/kagent-dev/kagent/go/core/internal/utils" @@ -71,7 +70,6 @@ type HTTPServer struct { config ServerConfig router *mux.Router handlers *handlers.Handlers - dbManager *database.Manager authenticator auth.AuthProvider } @@ -125,12 +123,6 @@ func (s *HTTPServer) Start(ctx context.Context) error { if err := s.httpServer.Shutdown(shutdownCtx); err != nil { log.Error(err, "Failed to properly shutdown HTTP server") } - // Close database connection - if 
s.dbManager != nil { - if err := s.dbManager.Close(); err != nil { - log.Error(err, "Failed to close database connection") - } - } }() return nil diff --git a/go/core/pkg/app/app.go b/go/core/pkg/app/app.go index fd141336b..73994f88d 100644 --- a/go/core/pkg/app/app.go +++ b/go/core/pkg/app/app.go @@ -410,26 +410,18 @@ func Start(getExtensionConfig GetExtensionConfig) { os.Exit(1) } - // Initialize database - dbManager, err := database.NewManager(ctx, &database.Config{ - PostgresConfig: &database.PostgresConfig{ - URL: cfg.Database.Url, - URLFile: cfg.Database.UrlFile, - VectorEnabled: cfg.Database.VectorEnabled, - }, + // Connect to database + db, err := database.Connect(ctx, &database.PostgresConfig{ + URL: cfg.Database.Url, + URLFile: cfg.Database.UrlFile, + VectorEnabled: cfg.Database.VectorEnabled, }) if err != nil { - setupLog.Error(err, "unable to initialize database") - os.Exit(1) - } - - // Initialize database tables - if err := dbManager.Initialize(); err != nil { - setupLog.Error(err, "unable to initialize database") + setupLog.Error(err, "unable to connect to database") os.Exit(1) } - dbClient := database.NewClient(dbManager) + dbClient := database.NewClient(db) router := mux.NewRouter() extensionCfg, err := getExtensionConfig(BootstrapConfig{ Ctx: ctx, diff --git a/go/core/pkg/env/database.go b/go/core/pkg/env/database.go deleted file mode 100644 index 312839345..000000000 --- a/go/core/pkg/env/database.go +++ /dev/null @@ -1,11 +0,0 @@ -package env - -// Database environment variables. -var ( - GormLogLevel = RegisterStringVar( - "GORM_LOG_LEVEL", - "silent", - "GORM database logging level. 
Valid values: error, warn, info, silent.", - ComponentDatabase, - ) -) diff --git a/go/core/pkg/migrations/core/000001_initial.down.sql b/go/core/pkg/migrations/core/000001_initial.down.sql new file mode 100644 index 000000000..fd1f2bc68 --- /dev/null +++ b/go/core/pkg/migrations/core/000001_initial.down.sql @@ -0,0 +1,12 @@ +DROP TABLE IF EXISTS crewai_flow_state; +DROP TABLE IF EXISTS crewai_agent_memory; +DROP TABLE IF EXISTS lg_checkpoint_write; +DROP TABLE IF EXISTS lg_checkpoint; +DROP TABLE IF EXISTS toolserver; +DROP TABLE IF EXISTS tool; +DROP TABLE IF EXISTS feedback; +DROP TABLE IF EXISTS push_notification; +DROP TABLE IF EXISTS task; +DROP TABLE IF EXISTS event; +DROP TABLE IF EXISTS session; +DROP TABLE IF EXISTS agent; diff --git a/go/core/pkg/migrations/core/000001_initial.up.sql b/go/core/pkg/migrations/core/000001_initial.up.sql new file mode 100644 index 000000000..193f076a4 --- /dev/null +++ b/go/core/pkg/migrations/core/000001_initial.up.sql @@ -0,0 +1,161 @@ +-- Baseline migration: matches the schema produced by GORM AutoMigrate in the +-- prior kagent release. Each subsequent migration applies additive changes. 
+ +CREATE TABLE IF NOT EXISTS agent ( + id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + type TEXT NOT NULL, + config JSONB +); +CREATE INDEX IF NOT EXISTS idx_agent_deleted_at ON agent(deleted_at); + +CREATE TABLE IF NOT EXISTS session ( + id TEXT NOT NULL, + user_id TEXT NOT NULL, + name TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + agent_id TEXT, + PRIMARY KEY (id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_session_name ON session(name); +CREATE INDEX IF NOT EXISTS idx_session_agent_id ON session(agent_id); +CREATE INDEX IF NOT EXISTS idx_session_deleted_at ON session(deleted_at); + +CREATE TABLE IF NOT EXISTS event ( + id TEXT NOT NULL, + user_id TEXT NOT NULL, + session_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL, + PRIMARY KEY (id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_event_session_id ON event(session_id); +CREATE INDEX IF NOT EXISTS idx_event_deleted_at ON event(deleted_at); + +CREATE TABLE IF NOT EXISTS task ( + id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL, + session_id TEXT +); +CREATE INDEX IF NOT EXISTS idx_task_session_id ON task(session_id); +CREATE INDEX IF NOT EXISTS idx_task_deleted_at ON task(deleted_at); + +CREATE TABLE IF NOT EXISTS push_notification ( + id TEXT PRIMARY KEY, + task_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + data TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_push_notification_task_id ON push_notification(task_id); +CREATE INDEX IF NOT EXISTS idx_push_notification_deleted_at ON push_notification(deleted_at); + 
+CREATE TABLE IF NOT EXISTS feedback ( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, + deleted_at TIMESTAMPTZ, + user_id TEXT NOT NULL, + message_id BIGINT, + is_positive BOOLEAN DEFAULT false, + feedback_text TEXT NOT NULL, + issue_type TEXT +); +CREATE INDEX IF NOT EXISTS idx_feedback_deleted_at ON feedback(deleted_at); +CREATE INDEX IF NOT EXISTS idx_feedback_user_id ON feedback(user_id); +CREATE INDEX IF NOT EXISTS idx_feedback_message_id ON feedback(message_id); + +CREATE TABLE IF NOT EXISTS tool ( + id TEXT NOT NULL, + server_name TEXT NOT NULL, + group_kind TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + description TEXT, + PRIMARY KEY (id, server_name, group_kind) +); +CREATE INDEX IF NOT EXISTS idx_tool_deleted_at ON tool(deleted_at); + +CREATE TABLE IF NOT EXISTS toolserver ( + name TEXT NOT NULL, + group_kind TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + description TEXT, + last_connected TIMESTAMPTZ, + PRIMARY KEY (name, group_kind) +); +CREATE INDEX IF NOT EXISTS idx_toolserver_deleted_at ON toolserver(deleted_at); + +CREATE TABLE IF NOT EXISTS lg_checkpoint ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + parent_checkpoint_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + metadata TEXT NOT NULL, + checkpoint TEXT NOT NULL, + checkpoint_type TEXT NOT NULL, + version INTEGER DEFAULT 1, + PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id) +); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); +CREATE INDEX IF NOT EXISTS idx_lgcp_list ON lg_checkpoint(created_at); +CREATE INDEX IF NOT EXISTS 
idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); + +CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + write_idx INTEGER NOT NULL, + value TEXT NOT NULL, + value_type TEXT NOT NULL, + channel TEXT NOT NULL, + task_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) +); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_write_deleted_at ON lg_checkpoint_write(deleted_at); + +CREATE TABLE IF NOT EXISTS crewai_agent_memory ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + memory_data TEXT NOT NULL, + PRIMARY KEY (user_id, thread_id) +); +CREATE INDEX IF NOT EXISTS idx_crewai_memory_list ON crewai_agent_memory(created_at); +CREATE INDEX IF NOT EXISTS idx_crewai_agent_memory_deleted_at ON crewai_agent_memory(deleted_at); + +CREATE TABLE IF NOT EXISTS crewai_flow_state ( + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + method_name TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ, + state_data TEXT NOT NULL, + PRIMARY KEY (user_id, thread_id, method_name) +); +CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_list ON crewai_flow_state(created_at); +CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_deleted_at ON crewai_flow_state(deleted_at); diff --git a/go/core/pkg/migrations/core/000002_add_session_source.down.sql b/go/core/pkg/migrations/core/000002_add_session_source.down.sql new file mode 100644 index 000000000..0ef080534 --- /dev/null +++ b/go/core/pkg/migrations/core/000002_add_session_source.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS 
idx_session_source;
+ALTER TABLE session DROP COLUMN IF EXISTS source;
diff --git a/go/core/pkg/migrations/core/000002_add_session_source.up.sql b/go/core/pkg/migrations/core/000002_add_session_source.up.sql
new file mode 100644
index 000000000..ca940b2f3
--- /dev/null
+++ b/go/core/pkg/migrations/core/000002_add_session_source.up.sql
@@ -0,0 +1,6 @@
+-- Add session.source column, introduced after the initial GORM-managed schema.
+-- ALTER TABLE ... ADD COLUMN IF NOT EXISTS is idempotent: it adds the column on fresh
+-- installs, where migration 000001 created the table without it, and is a no-op on
+-- existing GORM-managed deployments (whose schema already has it) upgrading to golang-migrate.
+ALTER TABLE session ADD COLUMN IF NOT EXISTS source TEXT;
+CREATE INDEX IF NOT EXISTS idx_session_source ON session(source);
diff --git a/go/core/pkg/migrations/migrations.go b/go/core/pkg/migrations/migrations.go
new file mode 100644
index 000000000..48746a7f6
--- /dev/null
+++ b/go/core/pkg/migrations/migrations.go
@@ -0,0 +1,9 @@
+// Package migrations exports the embedded SQL migration files for the kagent OSS
+// database schema. Enterprise builds import this FS to bundle OSS migrations
+// alongside enterprise-specific ones at build time.
+package migrations + +import "embed" + +//go:embed core vector +var FS embed.FS diff --git a/go/core/pkg/migrations/vector/000001_vector_support.down.sql b/go/core/pkg/migrations/vector/000001_vector_support.down.sql new file mode 100644 index 000000000..b403931d8 --- /dev/null +++ b/go/core/pkg/migrations/vector/000001_vector_support.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS memory; +DROP EXTENSION IF EXISTS vector; diff --git a/go/core/pkg/migrations/vector/000001_vector_support.up.sql b/go/core/pkg/migrations/vector/000001_vector_support.up.sql new file mode 100644 index 000000000..97cc13d98 --- /dev/null +++ b/go/core/pkg/migrations/vector/000001_vector_support.up.sql @@ -0,0 +1,17 @@ +CREATE EXTENSION IF NOT EXISTS vector; + +-- Matches the schema GORM AutoMigrate produced for the Memory struct. +-- GORM does not create HNSW indexes automatically; that is added in migration 000002. +CREATE TABLE IF NOT EXISTS memory ( + id TEXT PRIMARY KEY, + agent_name TEXT, + user_id TEXT, + content TEXT, + embedding vector(768), + metadata TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ, + access_count INTEGER DEFAULT 0 +); +CREATE INDEX IF NOT EXISTS idx_memory_agent_user ON memory(agent_name, user_id); +CREATE INDEX IF NOT EXISTS idx_memory_expires_at ON memory(expires_at); diff --git a/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.down.sql b/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.down.sql new file mode 100644 index 000000000..ee9bc90ed --- /dev/null +++ b/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_memory_embedding_hnsw; diff --git a/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.up.sql b/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.up.sql new file mode 100644 index 000000000..ad4ddfdc1 --- /dev/null +++ b/go/core/pkg/migrations/vector/000002_add_memory_hnsw_index.up.sql @@ -0,0 +1,4 @@ +-- Add HNSW index for 
fast approximate nearest-neighbor vector similarity search. +-- GORM did not create this index automatically; pgvector's HNSW index significantly +-- outperforms IVFFlat for production workloads (better recall, no reindex on insert). +CREATE INDEX IF NOT EXISTS idx_memory_embedding_hnsw ON memory USING hnsw (embedding vector_cosine_ops); diff --git a/go/core/pkg/migrations/vector/000003_memory_uuid_default.down.sql b/go/core/pkg/migrations/vector/000003_memory_uuid_default.down.sql new file mode 100644 index 000000000..d636b8066 --- /dev/null +++ b/go/core/pkg/migrations/vector/000003_memory_uuid_default.down.sql @@ -0,0 +1 @@ +ALTER TABLE memory ALTER COLUMN id DROP DEFAULT; diff --git a/go/core/pkg/migrations/vector/000003_memory_uuid_default.up.sql b/go/core/pkg/migrations/vector/000003_memory_uuid_default.up.sql new file mode 100644 index 000000000..f04772e30 --- /dev/null +++ b/go/core/pkg/migrations/vector/000003_memory_uuid_default.up.sql @@ -0,0 +1 @@ +ALTER TABLE memory ALTER COLUMN id SET DEFAULT gen_random_uuid(); diff --git a/go/go.mod b/go/go.mod index ae67a02b6..b002eb934 100644 --- a/go/go.mod +++ b/go/go.mod @@ -20,8 +20,10 @@ require ( github.com/go-logr/zapr v1.3.0 // api dependencies github.com/google/uuid v1.6.0 + github.com/golang-migrate/migrate/v4 v4.19.1 github.com/gorilla/mux v1.8.1 github.com/hashicorp/go-multierror v1.1.1 + github.com/lib/pq v1.11.2 github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/kagent-dev/kmcp v0.2.7 github.com/kagent-dev/mockllm v0.0.5 @@ -48,8 +50,6 @@ require ( google.golang.org/genai v1.40.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.31.1 k8s.io/api v0.35.1 k8s.io/apimachinery v0.35.1 k8s.io/client-go v0.35.1 @@ -140,12 +140,6 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jackc/pgpassfile v1.0.0 // 
indirect - github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.2 // indirect - github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/go/go.sum b/go/go.sum index 876582ce6..606bde40f 100644 --- a/go/go.sum +++ b/go/go.sum @@ -124,6 +124,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4= +github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= @@ -193,6 +195,8 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= +github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/golang/protobuf v1.5.4 
h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= @@ -280,8 +284,8 @@ github.com/lestrrat-go/jwx/v2 v2.1.4 h1:uBCMmJX8oRZStmKuMMOFb0Yh9xmEMgNJLgjuKKt4 github.com/lestrrat-go/jwx/v2 v2.1.4/go.mod h1:nWRbDFR1ALG2Z6GJbBXzfQaYyvn751KuuyySN2yR6is= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= +github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= @@ -340,6 +344,7 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= github.com/onsi/ginkgo/v2 v2.27.2 h1:LzwLj0b89qtIy6SSASkzlNvX6WktqurSHwkk2ipF/Ns= github.com/onsi/ginkgo/v2 v2.27.2/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A= @@ -593,10 +598,10 @@ gopkg.in/yaml.v2 
v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= -gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= -gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= k8s.io/api v0.35.1 h1:0PO/1FhlK/EQNVK5+txc4FuhQibV25VLSdLMmGpDE/Q= diff --git a/helm/kagent/templates/controller-deployment.yaml b/helm/kagent/templates/controller-deployment.yaml index 3d264913a..eace760ea 100644 --- a/helm/kagent/templates/controller-deployment.yaml +++ b/helm/kagent/templates/controller-deployment.yaml @@ -44,6 +44,36 @@ spec: tolerations: {{- toYaml . 
| nindent 8 }} {{- end }} + initContainers: + - name: migrate + image: "{{ .Values.controller.migrate.image.registry | default .Values.registry }}/{{ .Values.controller.migrate.image.repository }}:{{ coalesce .Values.tag .Values.controller.migrate.image.tag .Chart.Version }}" + imagePullPolicy: {{ .Values.controller.migrate.image.pullPolicy | default .Values.imagePullPolicy }} + env: + {{- if .Values.database.postgres.urlFile }} + - name: POSTGRES_DATABASE_URL_FILE + value: {{ .Values.database.postgres.urlFile | quote }} + {{- else if .Values.database.postgres.url }} + - name: POSTGRES_DATABASE_URL + value: {{ .Values.database.postgres.url | quote }} + {{- else if .Values.database.postgres.bundled.enabled }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ include "kagent.passwordSecretName" . }} + key: POSTGRES_PASSWORD + - name: POSTGRES_DATABASE_URL + value: {{ printf "postgres://kagent:$(POSTGRES_PASSWORD)@%s.%s.svc.cluster.local:5432/kagent?sslmode=disable" (include "kagent.postgresqlServiceName" .) (include "kagent.namespace" .) | quote }} + {{- else }} + {{ fail "No database connection configured. Set database.postgres.url, database.postgres.urlFile, or enable database.postgres.bundled." }} + {{- end }} + - name: KAGENT_DATABASE_VECTOR_ENABLED + value: {{ .Values.database.postgres.vectorEnabled | default false | quote }} + {{- if gt (len .Values.controller.volumeMounts) 0 }} + volumeMounts: + {{- with .Values.controller.volumeMounts }} + {{- toYaml . 
| nindent 12 }} + {{- end }} + {{- end }} containers: - name: controller image: "{{ .Values.controller.image.registry | default .Values.registry }}/{{ .Values.controller.image.repository }}:{{ coalesce .Values.tag .Values.controller.image.tag .Chart.Version }}" diff --git a/helm/kagent/values.yaml b/helm/kagent/values.yaml index d8605aaf2..1875b70e4 100644 --- a/helm/kagent/values.yaml +++ b/helm/kagent/values.yaml @@ -165,6 +165,12 @@ controller: repository: kagent-dev/kagent/controller tag: "" # Will default to global, then Chart version pullPolicy: "" + migrate: + image: + registry: "" + repository: kagent-dev/kagent/migrate + tag: "" # Will default to global, then Chart version + pullPolicy: "" resources: requests: cpu: 100m diff --git a/python/samples/langgraph/currency/currency/agent.py b/python/samples/langgraph/currency/currency/agent.py index a7964cee5..890c5139e 100644 --- a/python/samples/langgraph/currency/currency/agent.py +++ b/python/samples/langgraph/currency/currency/agent.py @@ -38,6 +38,7 @@ def _get_exchange_rate( response = httpx.get( f"https://api.frankfurter.app/{currency_date}", params={"from": currency_from, "to": currency_to}, + timeout=30.0, ) response.raise_for_status() From fbde2c711688f413fc4646500dacfa28fdca2e5b Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 27 Mar 2026 14:55:26 -0700 Subject: [PATCH 02/16] Remove unnecessary timout change Signed-off-by: Jeremy Alvis --- python/samples/langgraph/currency/currency/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/samples/langgraph/currency/currency/agent.py b/python/samples/langgraph/currency/currency/agent.py index 890c5139e..a7964cee5 100644 --- a/python/samples/langgraph/currency/currency/agent.py +++ b/python/samples/langgraph/currency/currency/agent.py @@ -38,7 +38,6 @@ def _get_exchange_rate( response = httpx.get( f"https://api.frankfurter.app/{currency_date}", params={"from": currency_from, "to": currency_to}, - timeout=30.0, ) response.raise_for_status() 
From 7f5d7144aaae185a6edbf9644d4fd42b63739333 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Mon, 30 Mar 2026 10:26:12 -0700 Subject: [PATCH 03/16] Use pgx instead of lib/sql Signed-off-by: Jeremy Alvis --- go/core/cmd/migrate/main.go | 8 +- go/core/cmd/migrate/main_test.go | 4 +- go/core/internal/database/client_postgres.go | 475 ++++++++---------- go/core/internal/database/client_test.go | 18 +- go/core/internal/database/connect.go | 38 +- go/core/internal/database/connect_test.go | 2 +- go/core/internal/database/fake/client.go | 14 +- go/core/internal/database/gen/agents.sql.go | 11 +- go/core/internal/database/gen/crewai.sql.go | 18 +- go/core/internal/database/gen/db.go | 13 +- go/core/internal/database/gen/events.sql.go | 51 +- go/core/internal/database/gen/feedback.sql.go | 12 +- .../internal/database/gen/langgraph.sql.go | 30 +- go/core/internal/database/gen/memory.sql.go | 60 +-- go/core/internal/database/gen/models.go | 65 ++- .../database/gen/push_notifications.sql.go | 11 +- go/core/internal/database/gen/querier.go | 5 +- go/core/internal/database/gen/sessions.sql.go | 25 +- go/core/internal/database/gen/tasks.sql.go | 18 +- go/core/internal/database/gen/tools.sql.go | 35 +- go/core/internal/database/sqlc.yaml | 12 + go/core/internal/database/testhelpers_test.go | 6 +- go/core/internal/database/upgrade_test.go | 2 +- go/core/internal/dbtest/dbtest.go | 12 +- .../internal/httpserver/middleware_error.go | 4 +- go/go.mod | 9 +- go/go.sum | 6 +- 27 files changed, 447 insertions(+), 517 deletions(-) diff --git a/go/core/cmd/migrate/main.go b/go/core/cmd/migrate/main.go index f0f25e373..3d4717aea 100644 --- a/go/core/cmd/migrate/main.go +++ b/go/core/cmd/migrate/main.go @@ -37,10 +37,10 @@ import ( "strings" "github.com/golang-migrate/migrate/v4" - migratepg "github.com/golang-migrate/migrate/v4/database/postgres" + migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/iofs" + _ 
"github.com/jackc/pgx/v5/stdlib" "github.com/kagent-dev/kagent/go/core/pkg/migrations" - _ "github.com/lib/pq" ) func main() { @@ -248,7 +248,7 @@ func resolveURL() (string, error) { // newMigrate opens a database connection and constructs a migrate.Migrate for the given dir/table. // The caller is responsible for calling closeMigrate on the returned instance. func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*migrate.Migrate, error) { - db, err := sql.Open("postgres", url) + db, err := sql.Open("pgx", url) if err != nil { return nil, fmt.Errorf("open database for %s: %w", dir, err) } @@ -258,7 +258,7 @@ func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*m return nil, fmt.Errorf("load migration files from %s: %w", dir, err) } - driver, err := migratepg.WithInstance(db, &migratepg.Config{ + driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ MigrationsTable: migrationsTable, }) if err != nil { diff --git a/go/core/cmd/migrate/main_test.go b/go/core/cmd/migrate/main_test.go index 0e17321a4..87055d65f 100644 --- a/go/core/cmd/migrate/main_test.go +++ b/go/core/cmd/migrate/main_test.go @@ -8,8 +8,8 @@ import ( "testing" "testing/fstest" + _ "github.com/jackc/pgx/v5/stdlib" "github.com/kagent-dev/kagent/go/core/internal/dbtest" - _ "github.com/lib/pq" ) // --- migration fixtures --- @@ -61,7 +61,7 @@ func mergeFS(fsMaps ...fstest.MapFS) fstest.MapFS { // Returns 0 if the table is empty or does not exist (fully rolled back). 
func trackVersion(t *testing.T, connStr, table string) uint { t.Helper() - db, err := sql.Open("postgres", connStr) + db, err := sql.Open("pgx", connStr) if err != nil { t.Fatalf("trackVersion: open db: %v", err) } diff --git a/go/core/internal/database/client_postgres.go b/go/core/internal/database/client_postgres.go index cd8e424e2..a80a7f002 100644 --- a/go/core/internal/database/client_postgres.go +++ b/go/core/internal/database/client_postgres.go @@ -2,13 +2,13 @@ package database import ( "context" - "database/sql" "encoding/json" "errors" "fmt" "strings" - "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" dbpkg "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/api/v1alpha2" dbgen "github.com/kagent-dev/kagent/go/core/internal/database/gen" @@ -18,16 +18,28 @@ import ( type postgresClient struct { q *dbgen.Queries - db *sql.DB + db *pgxpool.Pool } -func NewClient(db *sql.DB) dbpkg.Client { +func NewClient(db *pgxpool.Pool) dbpkg.Client { return &postgresClient{ q: dbgen.New(db), db: db, } } +func (c *postgresClient) withTx(ctx context.Context, fn func(*dbgen.Queries) error) error { + tx, err := c.db.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback(ctx) //nolint:errcheck + if err := fn(c.q.WithTx(tx)); err != nil { + return err + } + return tx.Commit(ctx) +} + // ── Agents ──────────────────────────────────────────────────────────────────── func (c *postgresClient) StoreAgent(ctx context.Context, agent *dbpkg.Agent) error { @@ -68,11 +80,12 @@ func (c *postgresClient) StoreSession(ctx context.Context, session *dbpkg.Sessio params := dbgen.UpsertSessionParams{ ID: session.ID, UserID: session.UserID, - Name: ptrToNullString(session.Name), - AgentID: ptrToNullString(session.AgentID), + Name: session.Name, + AgentID: session.AgentID, } if session.Source != nil { - params.Source = sql.NullString{String: string(*session.Source), Valid: true} + src := 
string(*session.Source) + params.Source = &src } return c.q.UpsertSession(ctx, params) } @@ -99,7 +112,7 @@ func (c *postgresClient) ListSessions(ctx context.Context, userID string) ([]dbp func (c *postgresClient) ListSessionsForAgent(ctx context.Context, agentID, userID string) ([]dbpkg.Session, error) { rows, err := c.q.ListSessionsForAgent(ctx, dbgen.ListSessionsForAgentParams{ - AgentID: ptrToNullString(&agentID), + AgentID: &agentID, UserID: userID, }) if err != nil { @@ -123,7 +136,7 @@ func (c *postgresClient) StoreEvents(ctx context.Context, events ...*dbpkg.Event if err := c.q.InsertEvent(ctx, dbgen.InsertEventParams{ ID: e.ID, UserID: e.UserID, - SessionID: sql.NullString{String: e.SessionID, Valid: e.SessionID != ""}, + SessionID: strPtrIfNotEmpty(e.SessionID), Data: e.Data, }); err != nil { return fmt.Errorf("failed to store event %s: %w", e.ID, err) @@ -135,24 +148,24 @@ func (c *postgresClient) StoreEvents(ctx context.Context, events ...*dbpkg.Event func (c *postgresClient) ListEventsForSession(ctx context.Context, sessionID, userID string, opts dbpkg.QueryOptions) ([]*dbpkg.Event, error) { var rows []dbgen.Event var err error - nullSessionID := sql.NullString{String: sessionID, Valid: sessionID != ""} + sessionIDPtr := strPtrIfNotEmpty(sessionID) switch { case opts.OrderAsc && opts.Limit > 0: rows, err = c.q.ListEventsForSessionAscLimit(ctx, dbgen.ListEventsForSessionAscLimitParams{ - SessionID: nullSessionID, UserID: userID, Column3: opts.After, Limit: int32(opts.Limit), + SessionID: sessionIDPtr, UserID: userID, Column3: opts.After, Limit: int32(opts.Limit), }) case opts.OrderAsc: rows, err = c.q.ListEventsForSessionAsc(ctx, dbgen.ListEventsForSessionAscParams{ - SessionID: nullSessionID, UserID: userID, Column3: opts.After, + SessionID: sessionIDPtr, UserID: userID, Column3: opts.After, }) case opts.Limit > 0: rows, err = c.q.ListEventsForSessionDescLimit(ctx, dbgen.ListEventsForSessionDescLimitParams{ - SessionID: nullSessionID, UserID: userID, 
Column3: opts.After, Limit: int32(opts.Limit), + SessionID: sessionIDPtr, UserID: userID, Column3: opts.After, Limit: int32(opts.Limit), }) default: rows, err = c.q.ListEventsForSessionDesc(ctx, dbgen.ListEventsForSessionDescParams{ - SessionID: nullSessionID, UserID: userID, Column3: opts.After, + SessionID: sessionIDPtr, UserID: userID, Column3: opts.After, }) } if err != nil { @@ -176,7 +189,7 @@ func (c *postgresClient) StoreTask(ctx context.Context, task *protocol.Task) err return c.q.UpsertTask(ctx, dbgen.UpsertTaskParams{ ID: task.ID, Data: string(data), - SessionID: sql.NullString{String: task.ContextID, Valid: task.ContextID != ""}, + SessionID: strPtrIfNotEmpty(task.ContextID), }) } @@ -193,7 +206,7 @@ func (c *postgresClient) GetTask(ctx context.Context, taskID string) (*protocol. } func (c *postgresClient) ListTasksForSession(ctx context.Context, sessionID string) ([]*protocol.Task, error) { - rows, err := c.q.ListTasksForSession(ctx, sql.NullString{String: sessionID, Valid: true}) + rows, err := c.q.ListTasksForSession(ctx, &sessionID) if err != nil { return nil, fmt.Errorf("failed to list tasks for session: %w", err) } @@ -257,10 +270,11 @@ func (c *postgresClient) DeletePushNotification(ctx context.Context, taskID stri // ── Feedback ────────────────────────────────────────────────────────────────── func (c *postgresClient) StoreFeedback(ctx context.Context, feedback *dbpkg.Feedback) error { + isPositive := feedback.IsPositive _, err := c.q.InsertFeedback(ctx, dbgen.InsertFeedbackParams{ UserID: feedback.UserID, - MessageID: ptrToNullInt64(feedback.MessageID), - IsPositive: sql.NullBool{Bool: feedback.IsPositive, Valid: true}, + MessageID: feedback.MessageID, + IsPositive: &isPositive, FeedbackText: feedback.FeedbackText, IssueType: feedback.IssueType, }) @@ -318,32 +332,24 @@ func (c *postgresClient) DeleteToolsForServer(ctx context.Context, serverName, g } func (c *postgresClient) RefreshToolsForServer(ctx context.Context, serverName, groupKind 
string, tools ...*v1alpha2.MCPTool) error { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - - if err := q.SoftDeleteToolsForServer(ctx, dbgen.SoftDeleteToolsForServerParams{ - ServerName: serverName, GroupKind: groupKind, - }); err != nil { - return fmt.Errorf("failed to delete existing tools: %w", err) - } - - for _, tool := range tools { - if err := q.UpsertTool(ctx, dbgen.UpsertToolParams{ - ID: tool.Name, - ServerName: serverName, - GroupKind: groupKind, - Description: sql.NullString{String: tool.Description, Valid: true}, + return c.withTx(ctx, func(q *dbgen.Queries) error { + if err := q.SoftDeleteToolsForServer(ctx, dbgen.SoftDeleteToolsForServerParams{ + ServerName: serverName, GroupKind: groupKind, }); err != nil { - return fmt.Errorf("failed to upsert tool %s: %w", tool.Name, err) + return fmt.Errorf("failed to delete existing tools: %w", err) } - } - - return tx.Commit() + for _, tool := range tools { + if err := q.UpsertTool(ctx, dbgen.UpsertToolParams{ + ID: tool.Name, + ServerName: serverName, + GroupKind: groupKind, + Description: &tool.Description, + }); err != nil { + return fmt.Errorf("failed to upsert tool %s: %w", tool.Name, err) + } + } + return nil + }) } func (c *postgresClient) GetToolServer(ctx context.Context, name string) (*dbpkg.ToolServer, error) { @@ -370,8 +376,8 @@ func (c *postgresClient) StoreToolServer(ctx context.Context, ts *dbpkg.ToolServ row, err := c.q.UpsertToolServer(ctx, dbgen.UpsertToolServerParams{ Name: ts.Name, GroupKind: ts.GroupKind, - Description: sql.NullString{String: ts.Description, Valid: true}, - LastConnected: ptrToNullTime(ts.LastConnected), + Description: &ts.Description, + LastConnected: ts.LastConnected, }) if err != nil { return nil, fmt.Errorf("failed to store tool server: %w", err) @@ -386,115 +392,102 @@ func (c *postgresClient) DeleteToolServer(ctx context.Context, 
serverName, group // ── LangGraph Checkpoints ───────────────────────────────────────────────────── func (c *postgresClient) StoreCheckpoint(ctx context.Context, cp *dbpkg.LangGraphCheckpoint) error { + version := cp.Version return c.q.UpsertCheckpoint(ctx, dbgen.UpsertCheckpointParams{ UserID: cp.UserID, ThreadID: cp.ThreadID, CheckpointNs: cp.CheckpointNS, CheckpointID: cp.CheckpointID, - ParentCheckpointID: ptrToNullString(cp.ParentCheckpointID), + ParentCheckpointID: cp.ParentCheckpointID, Metadata: cp.Metadata, Checkpoint: cp.Checkpoint, CheckpointType: cp.CheckpointType, - Version: sql.NullInt32{Int32: cp.Version, Valid: true}, + Version: &version, }) } func (c *postgresClient) StoreCheckpointWrites(ctx context.Context, writes []*dbpkg.LangGraphCheckpointWrite) error { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - for _, w := range writes { - if err := q.UpsertCheckpointWrite(ctx, dbgen.UpsertCheckpointWriteParams{ - UserID: w.UserID, - ThreadID: w.ThreadID, - CheckpointNs: w.CheckpointNS, - CheckpointID: w.CheckpointID, - WriteIdx: w.WriteIdx, - Value: w.Value, - ValueType: w.ValueType, - Channel: w.Channel, - TaskID: w.TaskID, - }); err != nil { - return fmt.Errorf("failed to store checkpoint write: %w", err) + return c.withTx(ctx, func(q *dbgen.Queries) error { + for _, w := range writes { + if err := q.UpsertCheckpointWrite(ctx, dbgen.UpsertCheckpointWriteParams{ + UserID: w.UserID, + ThreadID: w.ThreadID, + CheckpointNs: w.CheckpointNS, + CheckpointID: w.CheckpointID, + WriteIdx: w.WriteIdx, + Value: w.Value, + ValueType: w.ValueType, + Channel: w.Channel, + TaskID: w.TaskID, + }); err != nil { + return fmt.Errorf("failed to store checkpoint write: %w", err) + } } - } - return tx.Commit() + return nil + }) } func (c *postgresClient) ListCheckpoints(ctx context.Context, userID, threadID, checkpointNS string, 
checkpointID *string, limit int) ([]*dbpkg.LangGraphCheckpointTuple, error) { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return nil, fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - - var checkpoints []dbgen.LgCheckpoint - if checkpointID != nil { - cp, err := q.GetCheckpoint(ctx, dbgen.GetCheckpointParams{ - UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: *checkpointID, - }) - if err != nil { - return nil, fmt.Errorf("failed to get checkpoint: %w", err) - } - checkpoints = []dbgen.LgCheckpoint{cp} - } else if limit > 0 { - checkpoints, err = q.ListCheckpointsLimit(ctx, dbgen.ListCheckpointsLimitParams{ - UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, Limit: int32(limit), - }) - if err != nil { - return nil, fmt.Errorf("failed to list checkpoints: %w", err) + var tuples []*dbpkg.LangGraphCheckpointTuple + err := c.withTx(ctx, func(q *dbgen.Queries) error { + var checkpoints []dbgen.LgCheckpoint + var err error + if checkpointID != nil { + cp, err := q.GetCheckpoint(ctx, dbgen.GetCheckpointParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: *checkpointID, + }) + if err != nil { + return fmt.Errorf("failed to get checkpoint: %w", err) + } + checkpoints = []dbgen.LgCheckpoint{cp} + } else if limit > 0 { + checkpoints, err = q.ListCheckpointsLimit(ctx, dbgen.ListCheckpointsLimitParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, Limit: int32(limit), + }) + if err != nil { + return fmt.Errorf("failed to list checkpoints: %w", err) + } + } else { + checkpoints, err = q.ListCheckpoints(ctx, dbgen.ListCheckpointsParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, + }) + if err != nil { + return fmt.Errorf("failed to list checkpoints: %w", err) + } } - } else { - checkpoints, err = q.ListCheckpoints(ctx, dbgen.ListCheckpointsParams{ - UserID: userID, ThreadID: 
threadID, CheckpointNs: checkpointNS, - }) - if err != nil { - return nil, fmt.Errorf("failed to list checkpoints: %w", err) - } - } - tuples := make([]*dbpkg.LangGraphCheckpointTuple, 0, len(checkpoints)) - for _, cp := range checkpoints { - writes, err := q.ListCheckpointWrites(ctx, dbgen.ListCheckpointWritesParams{ - UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: cp.CheckpointID, - }) - if err != nil { - return nil, fmt.Errorf("failed to get checkpoint writes: %w", err) - } - dbWrites := make([]*dbpkg.LangGraphCheckpointWrite, len(writes)) - for i, w := range writes { - dbWrites[i] = toCheckpointWrite(w) + tuples = make([]*dbpkg.LangGraphCheckpointTuple, 0, len(checkpoints)) + for _, cp := range checkpoints { + writes, err := q.ListCheckpointWrites(ctx, dbgen.ListCheckpointWritesParams{ + UserID: userID, ThreadID: threadID, CheckpointNs: checkpointNS, CheckpointID: cp.CheckpointID, + }) + if err != nil { + return fmt.Errorf("failed to get checkpoint writes: %w", err) + } + dbWrites := make([]*dbpkg.LangGraphCheckpointWrite, len(writes)) + for i, w := range writes { + dbWrites[i] = toCheckpointWrite(w) + } + tuples = append(tuples, &dbpkg.LangGraphCheckpointTuple{ + Checkpoint: toCheckpoint(cp), + Writes: dbWrites, + }) } - tuples = append(tuples, &dbpkg.LangGraphCheckpointTuple{ - Checkpoint: toCheckpoint(cp), - Writes: dbWrites, - }) - } - - return tuples, tx.Commit() + return nil + }) + return tuples, err } func (c *postgresClient) DeleteCheckpoint(ctx context.Context, userID, threadID string) error { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - if err := q.SoftDeleteCheckpoints(ctx, dbgen.SoftDeleteCheckpointsParams{UserID: userID, ThreadID: threadID}); err != nil { - return fmt.Errorf("failed to delete checkpoints: %w", err) - } - if err := q.SoftDeleteCheckpointWrites(ctx, 
dbgen.SoftDeleteCheckpointWritesParams{UserID: userID, ThreadID: threadID}); err != nil { - return fmt.Errorf("failed to delete checkpoint writes: %w", err) - } - return tx.Commit() + return c.withTx(ctx, func(q *dbgen.Queries) error { + if err := q.SoftDeleteCheckpoints(ctx, dbgen.SoftDeleteCheckpointsParams{UserID: userID, ThreadID: threadID}); err != nil { + return fmt.Errorf("failed to delete checkpoints: %w", err) + } + if err := q.SoftDeleteCheckpointWrites(ctx, dbgen.SoftDeleteCheckpointWritesParams{UserID: userID, ThreadID: threadID}); err != nil { + return fmt.Errorf("failed to delete checkpoint writes: %w", err) + } + return nil + }) } // ── CrewAI ──────────────────────────────────────────────────────────────────── @@ -548,7 +541,7 @@ func (c *postgresClient) StoreCrewAIFlowState(ctx context.Context, state *dbpkg. func (c *postgresClient) GetCrewAIFlowState(ctx context.Context, userID, threadID string) (*dbpkg.CrewAIFlowState, error) { row, err := c.q.GetLatestCrewAIFlowState(ctx, dbgen.GetLatestCrewAIFlowStateParams{UserID: userID, ThreadID: threadID}) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("failed to get CrewAI flow state: %w", err) @@ -560,13 +553,13 @@ func (c *postgresClient) GetCrewAIFlowState(ctx context.Context, userID, threadI func (c *postgresClient) StoreAgentMemory(ctx context.Context, memory *dbpkg.Memory) error { id, err := c.q.InsertMemory(ctx, dbgen.InsertMemoryParams{ - AgentName: sql.NullString{String: memory.AgentName, Valid: true}, - UserID: sql.NullString{String: memory.UserID, Valid: true}, - Content: sql.NullString{String: memory.Content, Valid: true}, + AgentName: &memory.AgentName, + UserID: &memory.UserID, + Content: &memory.Content, Embedding: memory.Embedding, - Metadata: sql.NullString{String: memory.Metadata, Valid: true}, - ExpiresAt: ptrToNullTime(memory.ExpiresAt), - AccessCount: sql.NullInt32{Int32: memory.AccessCount, Valid: 
true}, + Metadata: &memory.Metadata, + ExpiresAt: memory.ExpiresAt, + AccessCount: &memory.AccessCount, }) if err != nil { return err @@ -576,36 +569,31 @@ func (c *postgresClient) StoreAgentMemory(ctx context.Context, memory *dbpkg.Mem } func (c *postgresClient) StoreAgentMemories(ctx context.Context, memories []*dbpkg.Memory) error { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - for _, m := range memories { - id, err := q.InsertMemory(ctx, dbgen.InsertMemoryParams{ - AgentName: sql.NullString{String: m.AgentName, Valid: true}, - UserID: sql.NullString{String: m.UserID, Valid: true}, - Content: sql.NullString{String: m.Content, Valid: true}, - Embedding: m.Embedding, - Metadata: sql.NullString{String: m.Metadata, Valid: true}, - ExpiresAt: ptrToNullTime(m.ExpiresAt), - AccessCount: sql.NullInt32{Int32: m.AccessCount, Valid: true}, - }) - if err != nil { - return fmt.Errorf("failed to store memory: %w", err) + return c.withTx(ctx, func(q *dbgen.Queries) error { + for _, m := range memories { + id, err := q.InsertMemory(ctx, dbgen.InsertMemoryParams{ + AgentName: &m.AgentName, + UserID: &m.UserID, + Content: &m.Content, + Embedding: m.Embedding, + Metadata: &m.Metadata, + ExpiresAt: m.ExpiresAt, + AccessCount: &m.AccessCount, + }) + if err != nil { + return fmt.Errorf("failed to store memory: %w", err) + } + m.ID = id } - m.ID = id - } - return tx.Commit() + return nil + }) } func (c *postgresClient) SearchAgentMemory(ctx context.Context, agentName, userID string, embedding pgvector.Vector, limit int) ([]dbpkg.AgentMemorySearchResult, error) { rows, err := c.q.SearchAgentMemory(ctx, dbgen.SearchAgentMemoryParams{ Embedding: embedding, - AgentName: sql.NullString{String: agentName, Valid: true}, - UserID: sql.NullString{String: userID, Valid: true}, + AgentName: &agentName, + UserID: &userID, Limit: int32(limit), }) if err != nil { 
@@ -618,14 +606,14 @@ func (c *postgresClient) SearchAgentMemory(ctx context.Context, agentName, userI results[i] = dbpkg.AgentMemorySearchResult{ Memory: dbpkg.Memory{ ID: r.ID, - AgentName: r.AgentName.String, - UserID: r.UserID.String, - Content: r.Content.String, + AgentName: derefStr(r.AgentName), + UserID: derefStr(r.UserID), + Content: derefStr(r.Content), Embedding: r.Embedding, - Metadata: r.Metadata.String, + Metadata: derefStr(r.Metadata), CreatedAt: r.CreatedAt, - ExpiresAt: nullTimeToPtr(r.ExpiresAt), - AccessCount: nullInt32ToVal(r.AccessCount), + ExpiresAt: r.ExpiresAt, + AccessCount: derefInt32(r.AccessCount), }, Score: score, } @@ -647,9 +635,9 @@ func (c *postgresClient) SearchAgentMemory(ctx context.Context, agentName, userI func (c *postgresClient) ListAgentMemories(ctx context.Context, agentName, userID string) ([]dbpkg.Memory, error) { normalized := strings.ReplaceAll(agentName, "-", "_") rows, err := c.q.ListAgentMemories(ctx, dbgen.ListAgentMemoriesParams{ - AgentName: sql.NullString{String: agentName, Valid: true}, - AgentName_2: sql.NullString{String: normalized, Valid: true}, - UserID: sql.NullString{String: userID, Valid: true}, + AgentName: &agentName, + AgentName_2: &normalized, + UserID: &userID, }) if err != nil { return nil, fmt.Errorf("failed to list agent memories: %w", err) @@ -663,16 +651,16 @@ func (c *postgresClient) ListAgentMemories(ctx context.Context, agentName, userI func (c *postgresClient) DeleteAgentMemory(ctx context.Context, agentName, userID string) error { if err := c.q.DeleteAgentMemory(ctx, dbgen.DeleteAgentMemoryParams{ - AgentName: sql.NullString{String: agentName, Valid: true}, - UserID: sql.NullString{String: userID, Valid: true}, + AgentName: &agentName, + UserID: &userID, }); err != nil { return fmt.Errorf("failed to delete agent memory: %w", err) } normalized := strings.ReplaceAll(agentName, "-", "_") if normalized != agentName { if err := c.q.DeleteAgentMemory(ctx, dbgen.DeleteAgentMemoryParams{ - 
AgentName: sql.NullString{String: normalized, Valid: true}, - UserID: sql.NullString{String: userID, Valid: true}, + AgentName: &normalized, + UserID: &userID, }); err != nil { return fmt.Errorf("failed to delete normalized agent memory: %w", err) } @@ -681,20 +669,15 @@ func (c *postgresClient) DeleteAgentMemory(ctx context.Context, agentName, userI } func (c *postgresClient) PruneExpiredMemories(ctx context.Context) error { - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() //nolint:errcheck - - q := c.q.WithTx(tx) - if err := q.ExtendMemoryTTL(ctx); err != nil { - return fmt.Errorf("failed to extend TTL for popular memories: %w", err) - } - if err := q.DeleteExpiredMemories(ctx); err != nil { - return fmt.Errorf("failed to delete expired memories: %w", err) - } - return tx.Commit() + return c.withTx(ctx, func(q *dbgen.Queries) error { + if err := q.ExtendMemoryTTL(ctx); err != nil { + return fmt.Errorf("failed to extend TTL for popular memories: %w", err) + } + if err := q.DeleteExpiredMemories(ctx); err != nil { + return fmt.Errorf("failed to delete expired memories: %w", err) + } + return nil + }) } // ── Conversion helpers ──────────────────────────────────────────────────────── @@ -704,7 +687,7 @@ func toAgent(r dbgen.Agent) *dbpkg.Agent { ID: r.ID, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, Type: r.Type, Config: r.Config, } @@ -714,14 +697,14 @@ func toSession(r dbgen.Session) *dbpkg.Session { s := &dbpkg.Session{ ID: r.ID, UserID: r.UserID, - Name: nullStringToPtr(r.Name), + Name: r.Name, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), - AgentID: nullStringToPtr(r.AgentID), + DeletedAt: r.DeletedAt, + AgentID: r.AgentID, } - if r.Source.Valid { - src := dbpkg.SessionSource(r.Source.String) + if r.Source != nil { + src := dbpkg.SessionSource(*r.Source) 
s.Source = &src } return s @@ -731,10 +714,10 @@ func toEvent(r dbgen.Event) *dbpkg.Event { return &dbpkg.Event{ ID: r.ID, UserID: r.UserID, - SessionID: r.SessionID.String, + SessionID: derefStr(r.SessionID), CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, Data: r.Data, } } @@ -744,21 +727,21 @@ func toTask(r dbgen.Task) *dbpkg.Task { ID: r.ID, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, Data: r.Data, - SessionID: r.SessionID.String, + SessionID: derefStr(r.SessionID), } } func toFeedback(r dbgen.Feedback) *dbpkg.Feedback { return &dbpkg.Feedback{ ID: r.ID, - CreatedAt: nullTimeToPtr(r.CreatedAt), - UpdatedAt: nullTimeToPtr(r.UpdatedAt), - DeletedAt: nullTimeToPtr(r.DeletedAt), + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + DeletedAt: r.DeletedAt, UserID: r.UserID, - MessageID: nullInt64ToPtr(r.MessageID), - IsPositive: r.IsPositive.Bool, + MessageID: r.MessageID, + IsPositive: derefBool(r.IsPositive), FeedbackText: r.FeedbackText, IssueType: r.IssueType, } @@ -771,8 +754,8 @@ func toTool(r dbgen.Tool) *dbpkg.Tool { GroupKind: r.GroupKind, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), - Description: r.Description.String, + DeletedAt: r.DeletedAt, + Description: derefStr(r.Description), } } @@ -782,9 +765,9 @@ func toToolServer(r dbgen.Toolserver) *dbpkg.ToolServer { GroupKind: r.GroupKind, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), - Description: r.Description.String, - LastConnected: nullTimeToPtr(r.LastConnected), + DeletedAt: r.DeletedAt, + Description: derefStr(r.Description), + LastConnected: r.LastConnected, } } @@ -794,14 +777,14 @@ func toCheckpoint(r dbgen.LgCheckpoint) *dbpkg.LangGraphCheckpoint { ThreadID: r.ThreadID, CheckpointNS: r.CheckpointNs, CheckpointID: r.CheckpointID, - ParentCheckpointID: 
nullStringToPtr(r.ParentCheckpointID), + ParentCheckpointID: r.ParentCheckpointID, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, Metadata: r.Metadata, Checkpoint: r.Checkpoint, CheckpointType: r.CheckpointType, - Version: r.Version.Int32, + Version: derefInt32(r.Version), } } @@ -818,7 +801,7 @@ func toCheckpointWrite(r dbgen.LgCheckpointWrite) *dbpkg.LangGraphCheckpointWrit TaskID: r.TaskID, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, } } @@ -828,7 +811,7 @@ func toCrewAIMemory(r dbgen.CrewaiAgentMemory) *dbpkg.CrewAIAgentMemory { ThreadID: r.ThreadID, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, MemoryData: r.MemoryData, } } @@ -840,7 +823,7 @@ func toCrewAIFlowState(r dbgen.CrewaiFlowState) *dbpkg.CrewAIFlowState { MethodName: r.MethodName, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, - DeletedAt: nullTimeToPtr(r.DeletedAt), + DeletedAt: r.DeletedAt, StateData: r.StateData, } } @@ -848,61 +831,43 @@ func toCrewAIFlowState(r dbgen.CrewaiFlowState) *dbpkg.CrewAIFlowState { func toMemory(r dbgen.Memory) *dbpkg.Memory { return &dbpkg.Memory{ ID: r.ID, - AgentName: r.AgentName.String, - UserID: r.UserID.String, - Content: r.Content.String, + AgentName: derefStr(r.AgentName), + UserID: derefStr(r.UserID), + Content: derefStr(r.Content), Embedding: r.Embedding, - Metadata: r.Metadata.String, + Metadata: derefStr(r.Metadata), CreatedAt: r.CreatedAt, - ExpiresAt: nullTimeToPtr(r.ExpiresAt), - AccessCount: nullInt32ToVal(r.AccessCount), - } -} - -// ── sql.Null* helpers ───────────────────────────────────────────────────────── - -func nullStringToPtr(s sql.NullString) *string { - if s.Valid { - return &s.String + ExpiresAt: r.ExpiresAt, + AccessCount: derefInt32(r.AccessCount), } - return nil } -func nullTimeToPtr(t sql.NullTime) *time.Time { - if t.Valid { - 
return &t.Time - } - return nil -} +// ── Pointer helpers ─────────────────────────────────────────────────────────── -func nullInt64ToPtr(n sql.NullInt64) *int64 { - if n.Valid { - return &n.Int64 +func strPtrIfNotEmpty(s string) *string { + if s == "" { + return nil } - return nil -} - -func nullInt32ToVal(n sql.NullInt32) int32 { - return n.Int32 + return &s } -func ptrToNullString(s *string) sql.NullString { +func derefStr(s *string) string { if s != nil { - return sql.NullString{String: *s, Valid: true} + return *s } - return sql.NullString{} + return "" } -func ptrToNullTime(t *time.Time) sql.NullTime { - if t != nil { - return sql.NullTime{Time: *t, Valid: true} +func derefInt32(n *int32) int32 { + if n != nil { + return *n } - return sql.NullTime{} + return 0 } -func ptrToNullInt64(n *int64) sql.NullInt64 { - if n != nil { - return sql.NullInt64{Int64: *n, Valid: true} +func derefBool(b *bool) bool { + if b != nil { + return *b } - return sql.NullInt64{} + return false } diff --git a/go/core/internal/database/client_test.go b/go/core/internal/database/client_test.go index e24639983..ab57a95bd 100644 --- a/go/core/internal/database/client_test.go +++ b/go/core/internal/database/client_test.go @@ -2,15 +2,14 @@ package database import ( "context" - "database/sql" "fmt" "sync" "testing" "time" + "github.com/jackc/pgx/v5/pgxpool" dbpkg "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/api/v1alpha2" - "github.com/kagent-dev/kagent/go/core/internal/dbtest" "github.com/pgvector/pgvector-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,14 +209,23 @@ func TestStoreToolServerIdempotence(t *testing.T) { } // setupTestDB resets the shared Postgres database's tables for test isolation. 
-func setupTestDB(t *testing.T) *sql.DB { +func setupTestDB(t *testing.T) *pgxpool.Pool { t.Helper() if testing.Short() { t.Skip("skipping database test in short mode") } - require.NoError(t, dbtest.MigrateDown(sharedConnStr, true), "Failed to reset test database (down)") - require.NoError(t, dbtest.Migrate(sharedConnStr, true), "Failed to reset test database (up)") + // Truncate application tables instead of full down+up migrations. + // Full down migration drops and recreates the pgvector extension, which + // changes type OIDs and breaks existing pool connections. + _, err := sharedDB.Exec(context.Background(), ` + TRUNCATE TABLE + agent, session, event, task, push_notification, feedback, + tool, toolserver, lg_checkpoint, lg_checkpoint_write, + crewai_agent_memory, crewai_flow_state, memory + RESTART IDENTITY CASCADE + `) + require.NoError(t, err, "Failed to truncate test tables") return sharedDB } diff --git a/go/core/internal/database/connect.go b/go/core/internal/database/connect.go index 169207073..8fd3e2ee0 100644 --- a/go/core/internal/database/connect.go +++ b/go/core/internal/database/connect.go @@ -2,14 +2,15 @@ package database import ( "context" - "database/sql" "fmt" "log" "os" "strings" "time" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + pgvectorpgx "github.com/pgvector/pgvector-go/pgx" ) // PostgresConfig holds the connection parameters for a Postgres database. @@ -26,9 +27,9 @@ const ( ) // Connect opens a Postgres connection using cfg, resolving the URL from a file -// if URLFile is set, and retries PingContext with exponential backoff until the +// if URLFile is set, and retries Ping with exponential backoff until the // connection succeeds or defaultMaxTimeout elapses. 
-func Connect(ctx context.Context, cfg *PostgresConfig) (*sql.DB, error) { +func Connect(ctx context.Context, cfg *PostgresConfig) (*pgxpool.Pool, error) { url := cfg.URL if cfg.URLFile != "" { resolved, err := resolveURLFile(cfg.URLFile) @@ -37,31 +38,42 @@ func Connect(ctx context.Context, cfg *PostgresConfig) (*sql.DB, error) { } url = resolved } - return retryDBConnection(ctx, url) + return retryDBConnection(ctx, url, cfg.VectorEnabled) } -// retryDBConnection opens a database connection and retries PingContext with -// exponential backoff until the connection succeeds or defaultMaxTimeout elapses. -func retryDBConnection(ctx context.Context, url string) (*sql.DB, error) { +// retryDBConnection opens a pgxpool connection, registering pgvector types when +// vectorEnabled is true, and retries Ping with exponential backoff until the +// connection succeeds or defaultMaxTimeout elapses. +func retryDBConnection(ctx context.Context, url string, vectorEnabled bool) (*pgxpool.Pool, error) { ctx, cancel := context.WithTimeout(ctx, defaultMaxTimeout) defer cancel() - db, err := sql.Open("postgres", url) + config, err := pgxpool.ParseConfig(url) + if err != nil { + return nil, fmt.Errorf("failed to parse database URL: %w", err) + } + if vectorEnabled { + config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + return pgvectorpgx.RegisterTypes(ctx, conn) + } + } + + pool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return nil, fmt.Errorf("failed to create database pool: %w", err) } start := time.Now() delay := defaultInitialDelay for attempt := 1; ; attempt++ { - if err := db.PingContext(ctx); err == nil { - return db, nil + if err := pool.Ping(ctx); err == nil { + return pool, nil } else { log.Printf("database not ready (attempt %d, elapsed %s): %v", attempt, time.Since(start).Round(time.Second), err) } select { case <-ctx.Done(): - _ = db.Close() + pool.Close() return nil, 
fmt.Errorf("database not ready after %s: %w", time.Since(start).Round(time.Second), ctx.Err()) case <-time.After(delay): } diff --git a/go/core/internal/database/connect_test.go b/go/core/internal/database/connect_test.go index ab60bb1ab..681e58f32 100644 --- a/go/core/internal/database/connect_test.go +++ b/go/core/internal/database/connect_test.go @@ -14,7 +14,7 @@ func TestRetryDBConnection_DeadlineExceeded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - _, err := retryDBConnection(ctx, "postgres://user:pass@localhost:1/nodb?connect_timeout=1") + _, err := retryDBConnection(ctx, "postgres://user:pass@localhost:1/nodb?connect_timeout=1", false) assert.ErrorIs(t, err, context.DeadlineExceeded) } diff --git a/go/core/internal/database/fake/client.go b/go/core/internal/database/fake/client.go index 838f59c6c..86005a61d 100644 --- a/go/core/internal/database/fake/client.go +++ b/go/core/internal/database/fake/client.go @@ -2,7 +2,6 @@ package fake import ( "context" - "database/sql" "encoding/json" "fmt" "math" @@ -11,6 +10,7 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/api/v1alpha2" "github.com/pgvector/pgvector-go" @@ -83,7 +83,7 @@ func (c *InMemoryFakeClient) GetTask(_ context.Context, taskID string) (*protoco task, exists := c.tasks[taskID] if !exists { - return nil, sql.ErrNoRows + return nil, pgx.ErrNoRows } parsedTask := &protocol.Task{} err := json.Unmarshal([]byte(task.Data), parsedTask) @@ -209,7 +209,7 @@ func (c *InMemoryFakeClient) DeleteAgent(_ context.Context, agentName string) er _, exists := c.agents[agentName] if !exists { - return sql.ErrNoRows + return pgx.ErrNoRows } delete(c.agents, agentName) @@ -248,7 +248,7 @@ func (c *InMemoryFakeClient) GetSession(_ context.Context, sessionID string, use key := c.sessionKey(sessionID, userID) session, exists := c.sessions[key] if !exists { - return nil, 
sql.ErrNoRows + return nil, pgx.ErrNoRows } return session, nil } @@ -260,7 +260,7 @@ func (c *InMemoryFakeClient) GetAgent(_ context.Context, agentName string) (*dat agent, exists := c.agents[agentName] if !exists { - return nil, sql.ErrNoRows + return nil, pgx.ErrNoRows } return agent, nil } @@ -272,7 +272,7 @@ func (c *InMemoryFakeClient) GetTool(_ context.Context, toolName string) (*datab tool, exists := c.tools[toolName] if !exists { - return nil, sql.ErrNoRows + return nil, pgx.ErrNoRows } return tool, nil } @@ -284,7 +284,7 @@ func (c *InMemoryFakeClient) GetToolServer(_ context.Context, serverName string) server, exists := c.toolServers[serverName] if !exists { - return nil, sql.ErrNoRows + return nil, pgx.ErrNoRows } return server, nil } diff --git a/go/core/internal/database/gen/agents.sql.go b/go/core/internal/database/gen/agents.sql.go index 61975dd4f..e987bbc63 100644 --- a/go/core/internal/database/gen/agents.sql.go +++ b/go/core/internal/database/gen/agents.sql.go @@ -18,7 +18,7 @@ LIMIT 1 ` func (q *Queries) GetAgent(ctx context.Context, id string) (Agent, error) { - row := q.db.QueryRowContext(ctx, getAgent, id) + row := q.db.QueryRow(ctx, getAgent, id) var i Agent err := row.Scan( &i.ID, @@ -38,7 +38,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListAgents(ctx context.Context) ([]Agent, error) { - rows, err := q.db.QueryContext(ctx, listAgents) + rows, err := q.db.Query(ctx, listAgents) if err != nil { return nil, err } @@ -58,9 +58,6 @@ func (q *Queries) ListAgents(ctx context.Context) ([]Agent, error) { } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -72,7 +69,7 @@ UPDATE agent SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL ` func (q *Queries) SoftDeleteAgent(ctx context.Context, id string) error { - _, err := q.db.ExecContext(ctx, softDeleteAgent, id) + _, err := q.db.Exec(ctx, softDeleteAgent, id) return err } @@ -93,6 +90,6 @@ 
type UpsertAgentParams struct { } func (q *Queries) UpsertAgent(ctx context.Context, arg UpsertAgentParams) error { - _, err := q.db.ExecContext(ctx, upsertAgent, arg.ID, arg.Type, arg.Config) + _, err := q.db.Exec(ctx, upsertAgent, arg.ID, arg.Type, arg.Config) return err } diff --git a/go/core/internal/database/gen/crewai.sql.go b/go/core/internal/database/gen/crewai.sql.go index a3fd02a75..7424fb6d7 100644 --- a/go/core/internal/database/gen/crewai.sql.go +++ b/go/core/internal/database/gen/crewai.sql.go @@ -22,7 +22,7 @@ type GetLatestCrewAIFlowStateParams struct { } func (q *Queries) GetLatestCrewAIFlowState(ctx context.Context, arg GetLatestCrewAIFlowStateParams) (CrewaiFlowState, error) { - row := q.db.QueryRowContext(ctx, getLatestCrewAIFlowState, arg.UserID, arg.ThreadID) + row := q.db.QueryRow(ctx, getLatestCrewAIFlowState, arg.UserID, arg.ThreadID) var i CrewaiFlowState err := row.Scan( &i.UserID, @@ -47,7 +47,7 @@ type HardDeleteCrewAIMemoryParams struct { } func (q *Queries) HardDeleteCrewAIMemory(ctx context.Context, arg HardDeleteCrewAIMemoryParams) error { - _, err := q.db.ExecContext(ctx, hardDeleteCrewAIMemory, arg.UserID, arg.ThreadID) + _, err := q.db.Exec(ctx, hardDeleteCrewAIMemory, arg.UserID, arg.ThreadID) return err } @@ -65,7 +65,7 @@ type SearchCrewAIMemoryByTaskParams struct { } func (q *Queries) SearchCrewAIMemoryByTask(ctx context.Context, arg SearchCrewAIMemoryByTaskParams) ([]CrewaiAgentMemory, error) { - rows, err := q.db.QueryContext(ctx, searchCrewAIMemoryByTask, arg.UserID, arg.ThreadID, arg.MemoryData) + rows, err := q.db.Query(ctx, searchCrewAIMemoryByTask, arg.UserID, arg.ThreadID, arg.MemoryData) if err != nil { return nil, err } @@ -85,9 +85,6 @@ func (q *Queries) SearchCrewAIMemoryByTask(ctx context.Context, arg SearchCrewAI } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -110,7 +107,7 @@ type 
SearchCrewAIMemoryByTaskLimitParams struct { } func (q *Queries) SearchCrewAIMemoryByTaskLimit(ctx context.Context, arg SearchCrewAIMemoryByTaskLimitParams) ([]CrewaiAgentMemory, error) { - rows, err := q.db.QueryContext(ctx, searchCrewAIMemoryByTaskLimit, + rows, err := q.db.Query(ctx, searchCrewAIMemoryByTaskLimit, arg.UserID, arg.ThreadID, arg.MemoryData, @@ -135,9 +132,6 @@ func (q *Queries) SearchCrewAIMemoryByTaskLimit(ctx context.Context, arg SearchC } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -161,7 +155,7 @@ type UpsertCrewAIFlowStateParams struct { } func (q *Queries) UpsertCrewAIFlowState(ctx context.Context, arg UpsertCrewAIFlowStateParams) error { - _, err := q.db.ExecContext(ctx, upsertCrewAIFlowState, + _, err := q.db.Exec(ctx, upsertCrewAIFlowState, arg.UserID, arg.ThreadID, arg.MethodName, @@ -186,6 +180,6 @@ type UpsertCrewAIMemoryParams struct { } func (q *Queries) UpsertCrewAIMemory(ctx context.Context, arg UpsertCrewAIMemoryParams) error { - _, err := q.db.ExecContext(ctx, upsertCrewAIMemory, arg.UserID, arg.ThreadID, arg.MemoryData) + _, err := q.db.Exec(ctx, upsertCrewAIMemory, arg.UserID, arg.ThreadID, arg.MemoryData) return err } diff --git a/go/core/internal/database/gen/db.go b/go/core/internal/database/gen/db.go index d0d3db9fb..67cd40fb1 100644 --- a/go/core/internal/database/gen/db.go +++ b/go/core/internal/database/gen/db.go @@ -6,14 +6,15 @@ package dbgen import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + 
Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row } func New(db DBTX) *Queries { @@ -24,7 +25,7 @@ type Queries struct { db DBTX } -func (q *Queries) WithTx(tx *sql.Tx) *Queries { +func (q *Queries) WithTx(tx pgx.Tx) *Queries { return &Queries{ db: tx, } diff --git a/go/core/internal/database/gen/events.sql.go b/go/core/internal/database/gen/events.sql.go index bc029f8a8..4c16cfd3e 100644 --- a/go/core/internal/database/gen/events.sql.go +++ b/go/core/internal/database/gen/events.sql.go @@ -7,7 +7,6 @@ package dbgen import ( "context" - "database/sql" "time" ) @@ -23,7 +22,7 @@ type GetEventParams struct { } func (q *Queries) GetEvent(ctx context.Context, arg GetEventParams) (Event, error) { - row := q.db.QueryRowContext(ctx, getEvent, arg.ID, arg.UserID) + row := q.db.QueryRow(ctx, getEvent, arg.ID, arg.UserID) var i Event err := row.Scan( &i.ID, @@ -45,12 +44,12 @@ VALUES ($1, $2, $3, $4, NOW(), NOW()) type InsertEventParams struct { ID string UserID string - SessionID sql.NullString + SessionID *string Data string } func (q *Queries) InsertEvent(ctx context.Context, arg InsertEventParams) error { - _, err := q.db.ExecContext(ctx, insertEvent, + _, err := q.db.Exec(ctx, insertEvent, arg.ID, arg.UserID, arg.SessionID, @@ -65,8 +64,8 @@ WHERE session_id = $1 AND deleted_at IS NULL ORDER BY created_at DESC ` -func (q *Queries) ListEventsByContextID(ctx context.Context, sessionID sql.NullString) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsByContextID, sessionID) +func (q *Queries) ListEventsByContextID(ctx context.Context, sessionID *string) ([]Event, error) { + rows, err := q.db.Query(ctx, listEventsByContextID, sessionID) if err != nil { return nil, err } @@ -87,9 +86,6 @@ func (q *Queries) ListEventsByContextID(ctx context.Context, sessionID sql.NullS } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != 
nil { return nil, err } @@ -104,12 +100,12 @@ LIMIT $2 ` type ListEventsByContextIDLimitParams struct { - SessionID sql.NullString + SessionID *string Limit int32 } func (q *Queries) ListEventsByContextIDLimit(ctx context.Context, arg ListEventsByContextIDLimitParams) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsByContextIDLimit, arg.SessionID, arg.Limit) + rows, err := q.db.Query(ctx, listEventsByContextIDLimit, arg.SessionID, arg.Limit) if err != nil { return nil, err } @@ -130,9 +126,6 @@ func (q *Queries) ListEventsByContextIDLimit(ctx context.Context, arg ListEvents } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -147,13 +140,13 @@ ORDER BY created_at ASC ` type ListEventsForSessionAscParams struct { - SessionID sql.NullString + SessionID *string UserID string Column3 time.Time } func (q *Queries) ListEventsForSessionAsc(ctx context.Context, arg ListEventsForSessionAscParams) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsForSessionAsc, arg.SessionID, arg.UserID, arg.Column3) + rows, err := q.db.Query(ctx, listEventsForSessionAsc, arg.SessionID, arg.UserID, arg.Column3) if err != nil { return nil, err } @@ -174,9 +167,6 @@ func (q *Queries) ListEventsForSessionAsc(ctx context.Context, arg ListEventsFor } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -192,14 +182,14 @@ LIMIT $4 ` type ListEventsForSessionAscLimitParams struct { - SessionID sql.NullString + SessionID *string UserID string Column3 time.Time Limit int32 } func (q *Queries) ListEventsForSessionAscLimit(ctx context.Context, arg ListEventsForSessionAscLimitParams) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsForSessionAscLimit, + rows, err := q.db.Query(ctx, listEventsForSessionAscLimit, arg.SessionID, arg.UserID, arg.Column3, @@ -225,9 +215,6 
@@ func (q *Queries) ListEventsForSessionAscLimit(ctx context.Context, arg ListEven } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -242,13 +229,13 @@ ORDER BY created_at DESC ` type ListEventsForSessionDescParams struct { - SessionID sql.NullString + SessionID *string UserID string Column3 time.Time } func (q *Queries) ListEventsForSessionDesc(ctx context.Context, arg ListEventsForSessionDescParams) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsForSessionDesc, arg.SessionID, arg.UserID, arg.Column3) + rows, err := q.db.Query(ctx, listEventsForSessionDesc, arg.SessionID, arg.UserID, arg.Column3) if err != nil { return nil, err } @@ -269,9 +256,6 @@ func (q *Queries) ListEventsForSessionDesc(ctx context.Context, arg ListEventsFo } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -287,14 +271,14 @@ LIMIT $4 ` type ListEventsForSessionDescLimitParams struct { - SessionID sql.NullString + SessionID *string UserID string Column3 time.Time Limit int32 } func (q *Queries) ListEventsForSessionDescLimit(ctx context.Context, arg ListEventsForSessionDescLimitParams) ([]Event, error) { - rows, err := q.db.QueryContext(ctx, listEventsForSessionDescLimit, + rows, err := q.db.Query(ctx, listEventsForSessionDescLimit, arg.SessionID, arg.UserID, arg.Column3, @@ -320,9 +304,6 @@ func (q *Queries) ListEventsForSessionDescLimit(ctx context.Context, arg ListEve } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -335,6 +316,6 @@ WHERE id = $1 AND deleted_at IS NULL ` func (q *Queries) SoftDeleteEvent(ctx context.Context, id string) error { - _, err := q.db.ExecContext(ctx, softDeleteEvent, id) + _, err := q.db.Exec(ctx, softDeleteEvent, id) return err } diff --git 
a/go/core/internal/database/gen/feedback.sql.go b/go/core/internal/database/gen/feedback.sql.go index 548aff05b..ed3df4ae3 100644 --- a/go/core/internal/database/gen/feedback.sql.go +++ b/go/core/internal/database/gen/feedback.sql.go @@ -7,7 +7,6 @@ package dbgen import ( "context" - "database/sql" "github.com/kagent-dev/kagent/go/api/database" ) @@ -20,14 +19,14 @@ RETURNING id, created_at, updated_at, deleted_at, user_id, message_id, is_positi type InsertFeedbackParams struct { UserID string - MessageID sql.NullInt64 - IsPositive sql.NullBool + MessageID *int64 + IsPositive *bool FeedbackText string IssueType *database.FeedbackIssueType } func (q *Queries) InsertFeedback(ctx context.Context, arg InsertFeedbackParams) (Feedback, error) { - row := q.db.QueryRowContext(ctx, insertFeedback, + row := q.db.QueryRow(ctx, insertFeedback, arg.UserID, arg.MessageID, arg.IsPositive, @@ -56,7 +55,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListFeedback(ctx context.Context, userID string) ([]Feedback, error) { - rows, err := q.db.QueryContext(ctx, listFeedback, userID) + rows, err := q.db.Query(ctx, listFeedback, userID) if err != nil { return nil, err } @@ -79,9 +78,6 @@ func (q *Queries) ListFeedback(ctx context.Context, userID string) ([]Feedback, } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/go/core/internal/database/gen/langgraph.sql.go b/go/core/internal/database/gen/langgraph.sql.go index d37b73b94..505db2a56 100644 --- a/go/core/internal/database/gen/langgraph.sql.go +++ b/go/core/internal/database/gen/langgraph.sql.go @@ -7,7 +7,6 @@ package dbgen import ( "context" - "database/sql" ) const getCheckpoint = `-- name: GetCheckpoint :one @@ -25,7 +24,7 @@ type GetCheckpointParams struct { } func (q *Queries) GetCheckpoint(ctx context.Context, arg GetCheckpointParams) (LgCheckpoint, error) { - row := q.db.QueryRowContext(ctx, getCheckpoint, + row := 
q.db.QueryRow(ctx, getCheckpoint, arg.UserID, arg.ThreadID, arg.CheckpointNs, @@ -64,7 +63,7 @@ type ListCheckpointWritesParams struct { } func (q *Queries) ListCheckpointWrites(ctx context.Context, arg ListCheckpointWritesParams) ([]LgCheckpointWrite, error) { - rows, err := q.db.QueryContext(ctx, listCheckpointWrites, + rows, err := q.db.Query(ctx, listCheckpointWrites, arg.UserID, arg.ThreadID, arg.CheckpointNs, @@ -95,9 +94,6 @@ func (q *Queries) ListCheckpointWrites(ctx context.Context, arg ListCheckpointWr } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -118,7 +114,7 @@ type ListCheckpointsParams struct { } func (q *Queries) ListCheckpoints(ctx context.Context, arg ListCheckpointsParams) ([]LgCheckpoint, error) { - rows, err := q.db.QueryContext(ctx, listCheckpoints, arg.UserID, arg.ThreadID, arg.CheckpointNs) + rows, err := q.db.Query(ctx, listCheckpoints, arg.UserID, arg.ThreadID, arg.CheckpointNs) if err != nil { return nil, err } @@ -144,9 +140,6 @@ func (q *Queries) ListCheckpoints(ctx context.Context, arg ListCheckpointsParams } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -169,7 +162,7 @@ type ListCheckpointsLimitParams struct { } func (q *Queries) ListCheckpointsLimit(ctx context.Context, arg ListCheckpointsLimitParams) ([]LgCheckpoint, error) { - rows, err := q.db.QueryContext(ctx, listCheckpointsLimit, + rows, err := q.db.Query(ctx, listCheckpointsLimit, arg.UserID, arg.ThreadID, arg.CheckpointNs, @@ -200,9 +193,6 @@ func (q *Queries) ListCheckpointsLimit(ctx context.Context, arg ListCheckpointsL } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -220,7 +210,7 @@ type SoftDeleteCheckpointWritesParams struct { } func (q *Queries) SoftDeleteCheckpointWrites(ctx 
context.Context, arg SoftDeleteCheckpointWritesParams) error { - _, err := q.db.ExecContext(ctx, softDeleteCheckpointWrites, arg.UserID, arg.ThreadID) + _, err := q.db.Exec(ctx, softDeleteCheckpointWrites, arg.UserID, arg.ThreadID) return err } @@ -235,7 +225,7 @@ type SoftDeleteCheckpointsParams struct { } func (q *Queries) SoftDeleteCheckpoints(ctx context.Context, arg SoftDeleteCheckpointsParams) error { - _, err := q.db.ExecContext(ctx, softDeleteCheckpoints, arg.UserID, arg.ThreadID) + _, err := q.db.Exec(ctx, softDeleteCheckpoints, arg.UserID, arg.ThreadID) return err } @@ -259,15 +249,15 @@ type UpsertCheckpointParams struct { ThreadID string CheckpointNs string CheckpointID string - ParentCheckpointID sql.NullString + ParentCheckpointID *string Metadata string Checkpoint string CheckpointType string - Version sql.NullInt32 + Version *int32 } func (q *Queries) UpsertCheckpoint(ctx context.Context, arg UpsertCheckpointParams) error { - _, err := q.db.ExecContext(ctx, upsertCheckpoint, + _, err := q.db.Exec(ctx, upsertCheckpoint, arg.UserID, arg.ThreadID, arg.CheckpointNs, @@ -307,7 +297,7 @@ type UpsertCheckpointWriteParams struct { } func (q *Queries) UpsertCheckpointWrite(ctx context.Context, arg UpsertCheckpointWriteParams) error { - _, err := q.db.ExecContext(ctx, upsertCheckpointWrite, + _, err := q.db.Exec(ctx, upsertCheckpointWrite, arg.UserID, arg.ThreadID, arg.CheckpointNs, diff --git a/go/core/internal/database/gen/memory.sql.go b/go/core/internal/database/gen/memory.sql.go index b5ee97f28..2ada25c70 100644 --- a/go/core/internal/database/gen/memory.sql.go +++ b/go/core/internal/database/gen/memory.sql.go @@ -7,10 +7,8 @@ package dbgen import ( "context" - "database/sql" "time" - "github.com/lib/pq" pgvector_go "github.com/pgvector/pgvector-go" ) @@ -19,12 +17,12 @@ DELETE FROM memory WHERE agent_name = $1 AND user_id = $2 ` type DeleteAgentMemoryParams struct { - AgentName sql.NullString - UserID sql.NullString + AgentName *string + UserID *string 
} func (q *Queries) DeleteAgentMemory(ctx context.Context, arg DeleteAgentMemoryParams) error { - _, err := q.db.ExecContext(ctx, deleteAgentMemory, arg.AgentName, arg.UserID) + _, err := q.db.Exec(ctx, deleteAgentMemory, arg.AgentName, arg.UserID) return err } @@ -34,7 +32,7 @@ WHERE expires_at < NOW() AND access_count < 10 ` func (q *Queries) DeleteExpiredMemories(ctx context.Context) error { - _, err := q.db.ExecContext(ctx, deleteExpiredMemories) + _, err := q.db.Exec(ctx, deleteExpiredMemories) return err } @@ -45,7 +43,7 @@ WHERE expires_at < NOW() AND access_count >= 10 ` func (q *Queries) ExtendMemoryTTL(ctx context.Context) error { - _, err := q.db.ExecContext(ctx, extendMemoryTTL) + _, err := q.db.Exec(ctx, extendMemoryTTL) return err } @@ -55,7 +53,7 @@ WHERE id = ANY($1::text[]) ` func (q *Queries) IncrementMemoryAccessCount(ctx context.Context, dollar_1 []string) error { - _, err := q.db.ExecContext(ctx, incrementMemoryAccessCount, pq.Array(dollar_1)) + _, err := q.db.Exec(ctx, incrementMemoryAccessCount, dollar_1) return err } @@ -66,17 +64,17 @@ RETURNING id ` type InsertMemoryParams struct { - AgentName sql.NullString - UserID sql.NullString - Content sql.NullString + AgentName *string + UserID *string + Content *string Embedding pgvector_go.Vector - Metadata sql.NullString - ExpiresAt sql.NullTime - AccessCount sql.NullInt32 + Metadata *string + ExpiresAt *time.Time + AccessCount *int32 } func (q *Queries) InsertMemory(ctx context.Context, arg InsertMemoryParams) (string, error) { - row := q.db.QueryRowContext(ctx, insertMemory, + row := q.db.QueryRow(ctx, insertMemory, arg.AgentName, arg.UserID, arg.Content, @@ -97,13 +95,13 @@ ORDER BY access_count DESC ` type ListAgentMemoriesParams struct { - AgentName sql.NullString - AgentName_2 sql.NullString - UserID sql.NullString + AgentName *string + AgentName_2 *string + UserID *string } func (q *Queries) ListAgentMemories(ctx context.Context, arg ListAgentMemoriesParams) ([]Memory, error) { - rows, err 
:= q.db.QueryContext(ctx, listAgentMemories, arg.AgentName, arg.AgentName_2, arg.UserID) + rows, err := q.db.Query(ctx, listAgentMemories, arg.AgentName, arg.AgentName_2, arg.UserID) if err != nil { return nil, err } @@ -126,9 +124,6 @@ func (q *Queries) ListAgentMemories(ctx context.Context, arg ListAgentMemoriesPa } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -145,26 +140,26 @@ LIMIT $4 type SearchAgentMemoryParams struct { Embedding pgvector_go.Vector - AgentName sql.NullString - UserID sql.NullString + AgentName *string + UserID *string Limit int32 } type SearchAgentMemoryRow struct { ID string - AgentName sql.NullString - UserID sql.NullString - Content sql.NullString + AgentName *string + UserID *string + Content *string Embedding pgvector_go.Vector - Metadata sql.NullString + Metadata *string CreatedAt time.Time - ExpiresAt sql.NullTime - AccessCount sql.NullInt32 + ExpiresAt *time.Time + AccessCount *int32 Score interface{} } func (q *Queries) SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryParams) ([]SearchAgentMemoryRow, error) { - rows, err := q.db.QueryContext(ctx, searchAgentMemory, + rows, err := q.db.Query(ctx, searchAgentMemory, arg.Embedding, arg.AgentName, arg.UserID, @@ -193,9 +188,6 @@ func (q *Queries) SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryPa } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/go/core/internal/database/gen/models.go b/go/core/internal/database/gen/models.go index 8c8b6d9f8..4b26661da 100644 --- a/go/core/internal/database/gen/models.go +++ b/go/core/internal/database/gen/models.go @@ -5,7 +5,6 @@ package dbgen import ( - "database/sql" "time" "github.com/kagent-dev/kagent/go/api/adk" @@ -17,7 +16,7 @@ type Agent struct { ID string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime 
+ DeletedAt *time.Time Type string Config *adk.AgentConfig } @@ -27,7 +26,7 @@ type CrewaiAgentMemory struct { ThreadID string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time MemoryData string } @@ -37,28 +36,28 @@ type CrewaiFlowState struct { MethodName string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time StateData string } type Event struct { ID string UserID string - SessionID sql.NullString + SessionID *string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time Data string } type Feedback struct { ID int64 - CreatedAt sql.NullTime - UpdatedAt sql.NullTime - DeletedAt sql.NullTime + CreatedAt *time.Time + UpdatedAt *time.Time + DeletedAt *time.Time UserID string - MessageID sql.NullInt64 - IsPositive sql.NullBool + MessageID *int64 + IsPositive *bool FeedbackText string IssueType *database.FeedbackIssueType } @@ -68,14 +67,14 @@ type LgCheckpoint struct { ThreadID string CheckpointNs string CheckpointID string - ParentCheckpointID sql.NullString + ParentCheckpointID *string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time Metadata string Checkpoint string CheckpointType string - Version sql.NullInt32 + Version *int32 } type LgCheckpointWrite struct { @@ -90,19 +89,19 @@ type LgCheckpointWrite struct { TaskID string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time } type Memory struct { ID string - AgentName sql.NullString - UserID sql.NullString - Content sql.NullString + AgentName *string + UserID *string + Content *string Embedding pgvector_go.Vector - Metadata sql.NullString + Metadata *string CreatedAt time.Time - ExpiresAt sql.NullTime - AccessCount sql.NullInt32 + ExpiresAt *time.Time + AccessCount *int32 } type PushNotification struct { @@ -110,28 +109,28 @@ type PushNotification struct { TaskID string CreatedAt time.Time UpdatedAt time.Time - DeletedAt 
sql.NullTime + DeletedAt *time.Time Data string } type Session struct { ID string UserID string - Name sql.NullString + Name *string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime - AgentID sql.NullString - Source sql.NullString + DeletedAt *time.Time + AgentID *string + Source *string } type Task struct { ID string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime + DeletedAt *time.Time Data string - SessionID sql.NullString + SessionID *string } type Tool struct { @@ -140,8 +139,8 @@ type Tool struct { GroupKind string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime - Description sql.NullString + DeletedAt *time.Time + Description *string } type Toolserver struct { @@ -149,7 +148,7 @@ type Toolserver struct { GroupKind string CreatedAt time.Time UpdatedAt time.Time - DeletedAt sql.NullTime - Description sql.NullString - LastConnected sql.NullTime + DeletedAt *time.Time + Description *string + LastConnected *time.Time } diff --git a/go/core/internal/database/gen/push_notifications.sql.go b/go/core/internal/database/gen/push_notifications.sql.go index cdedf3555..73a7a0069 100644 --- a/go/core/internal/database/gen/push_notifications.sql.go +++ b/go/core/internal/database/gen/push_notifications.sql.go @@ -21,7 +21,7 @@ type GetPushNotificationParams struct { } func (q *Queries) GetPushNotification(ctx context.Context, arg GetPushNotificationParams) (PushNotification, error) { - row := q.db.QueryRowContext(ctx, getPushNotification, arg.TaskID, arg.ID) + row := q.db.QueryRow(ctx, getPushNotification, arg.TaskID, arg.ID) var i PushNotification err := row.Scan( &i.ID, @@ -41,7 +41,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListPushNotifications(ctx context.Context, taskID string) ([]PushNotification, error) { - rows, err := q.db.QueryContext(ctx, listPushNotifications, taskID) + rows, err := q.db.Query(ctx, listPushNotifications, taskID) if err != nil { return nil, err } @@ -61,9 +61,6 @@ func (q *Queries) 
ListPushNotifications(ctx context.Context, taskID string) ([]P } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -76,7 +73,7 @@ WHERE task_id = $1 AND deleted_at IS NULL ` func (q *Queries) SoftDeletePushNotification(ctx context.Context, taskID string) error { - _, err := q.db.ExecContext(ctx, softDeletePushNotification, taskID) + _, err := q.db.Exec(ctx, softDeletePushNotification, taskID) return err } @@ -95,6 +92,6 @@ type UpsertPushNotificationParams struct { } func (q *Queries) UpsertPushNotification(ctx context.Context, arg UpsertPushNotificationParams) error { - _, err := q.db.ExecContext(ctx, upsertPushNotification, arg.ID, arg.TaskID, arg.Data) + _, err := q.db.Exec(ctx, upsertPushNotification, arg.ID, arg.TaskID, arg.Data) return err } diff --git a/go/core/internal/database/gen/querier.go b/go/core/internal/database/gen/querier.go index 152a08988..281d22edf 100644 --- a/go/core/internal/database/gen/querier.go +++ b/go/core/internal/database/gen/querier.go @@ -6,7 +6,6 @@ package dbgen import ( "context" - "database/sql" ) type Querier interface { @@ -32,7 +31,7 @@ type Querier interface { ListCheckpointWrites(ctx context.Context, arg ListCheckpointWritesParams) ([]LgCheckpointWrite, error) ListCheckpoints(ctx context.Context, arg ListCheckpointsParams) ([]LgCheckpoint, error) ListCheckpointsLimit(ctx context.Context, arg ListCheckpointsLimitParams) ([]LgCheckpoint, error) - ListEventsByContextID(ctx context.Context, sessionID sql.NullString) ([]Event, error) + ListEventsByContextID(ctx context.Context, sessionID *string) ([]Event, error) ListEventsByContextIDLimit(ctx context.Context, arg ListEventsByContextIDLimitParams) ([]Event, error) ListEventsForSessionAsc(ctx context.Context, arg ListEventsForSessionAscParams) ([]Event, error) ListEventsForSessionAscLimit(ctx context.Context, arg ListEventsForSessionAscLimitParams) ([]Event, error) @@ -42,7 +41,7 @@ 
type Querier interface { ListPushNotifications(ctx context.Context, taskID string) ([]PushNotification, error) ListSessions(ctx context.Context, userID string) ([]Session, error) ListSessionsForAgent(ctx context.Context, arg ListSessionsForAgentParams) ([]Session, error) - ListTasksForSession(ctx context.Context, sessionID sql.NullString) ([]Task, error) + ListTasksForSession(ctx context.Context, sessionID *string) ([]Task, error) ListToolServers(ctx context.Context) ([]Toolserver, error) ListTools(ctx context.Context) ([]Tool, error) ListToolsForServer(ctx context.Context, arg ListToolsForServerParams) ([]Tool, error) diff --git a/go/core/internal/database/gen/sessions.sql.go b/go/core/internal/database/gen/sessions.sql.go index b7f58b70a..8c432bfba 100644 --- a/go/core/internal/database/gen/sessions.sql.go +++ b/go/core/internal/database/gen/sessions.sql.go @@ -7,7 +7,6 @@ package dbgen import ( "context" - "database/sql" ) const getSession = `-- name: GetSession :one @@ -22,7 +21,7 @@ type GetSessionParams struct { } func (q *Queries) GetSession(ctx context.Context, arg GetSessionParams) (Session, error) { - row := q.db.QueryRowContext(ctx, getSession, arg.ID, arg.UserID) + row := q.db.QueryRow(ctx, getSession, arg.ID, arg.UserID) var i Session err := row.Scan( &i.ID, @@ -44,7 +43,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListSessions(ctx context.Context, userID string) ([]Session, error) { - rows, err := q.db.QueryContext(ctx, listSessions, userID) + rows, err := q.db.Query(ctx, listSessions, userID) if err != nil { return nil, err } @@ -66,9 +65,6 @@ func (q *Queries) ListSessions(ctx context.Context, userID string) ([]Session, e } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -83,12 +79,12 @@ ORDER BY created_at ASC ` type ListSessionsForAgentParams struct { - AgentID sql.NullString + AgentID *string UserID string } func (q *Queries) 
ListSessionsForAgent(ctx context.Context, arg ListSessionsForAgentParams) ([]Session, error) { - rows, err := q.db.QueryContext(ctx, listSessionsForAgent, arg.AgentID, arg.UserID) + rows, err := q.db.Query(ctx, listSessionsForAgent, arg.AgentID, arg.UserID) if err != nil { return nil, err } @@ -110,9 +106,6 @@ func (q *Queries) ListSessionsForAgent(ctx context.Context, arg ListSessionsForA } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -130,7 +123,7 @@ type SoftDeleteSessionParams struct { } func (q *Queries) SoftDeleteSession(ctx context.Context, arg SoftDeleteSessionParams) error { - _, err := q.db.ExecContext(ctx, softDeleteSession, arg.ID, arg.UserID) + _, err := q.db.Exec(ctx, softDeleteSession, arg.ID, arg.UserID) return err } @@ -147,13 +140,13 @@ ON CONFLICT (id, user_id) DO UPDATE SET type UpsertSessionParams struct { ID string UserID string - Name sql.NullString - AgentID sql.NullString - Source sql.NullString + Name *string + AgentID *string + Source *string } func (q *Queries) UpsertSession(ctx context.Context, arg UpsertSessionParams) error { - _, err := q.db.ExecContext(ctx, upsertSession, + _, err := q.db.Exec(ctx, upsertSession, arg.ID, arg.UserID, arg.Name, diff --git a/go/core/internal/database/gen/tasks.sql.go b/go/core/internal/database/gen/tasks.sql.go index c3e59638d..f5e8f8d2d 100644 --- a/go/core/internal/database/gen/tasks.sql.go +++ b/go/core/internal/database/gen/tasks.sql.go @@ -7,7 +7,6 @@ package dbgen import ( "context" - "database/sql" ) const getTask = `-- name: GetTask :one @@ -17,7 +16,7 @@ LIMIT 1 ` func (q *Queries) GetTask(ctx context.Context, id string) (Task, error) { - row := q.db.QueryRowContext(ctx, getTask, id) + row := q.db.QueryRow(ctx, getTask, id) var i Task err := row.Scan( &i.ID, @@ -36,8 +35,8 @@ WHERE session_id = $1 AND deleted_at IS NULL ORDER BY created_at ASC ` -func (q *Queries) ListTasksForSession(ctx 
context.Context, sessionID sql.NullString) ([]Task, error) { - rows, err := q.db.QueryContext(ctx, listTasksForSession, sessionID) +func (q *Queries) ListTasksForSession(ctx context.Context, sessionID *string) ([]Task, error) { + rows, err := q.db.Query(ctx, listTasksForSession, sessionID) if err != nil { return nil, err } @@ -57,9 +56,6 @@ func (q *Queries) ListTasksForSession(ctx context.Context, sessionID sql.NullStr } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -71,7 +67,7 @@ UPDATE task SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL ` func (q *Queries) SoftDeleteTask(ctx context.Context, id string) error { - _, err := q.db.ExecContext(ctx, softDeleteTask, id) + _, err := q.db.Exec(ctx, softDeleteTask, id) return err } @@ -82,7 +78,7 @@ SELECT EXISTS ( ` func (q *Queries) TaskExists(ctx context.Context, id string) (bool, error) { - row := q.db.QueryRowContext(ctx, taskExists, id) + row := q.db.QueryRow(ctx, taskExists, id) var exists bool err := row.Scan(&exists) return exists, err @@ -100,10 +96,10 @@ ON CONFLICT (id) DO UPDATE SET type UpsertTaskParams struct { ID string Data string - SessionID sql.NullString + SessionID *string } func (q *Queries) UpsertTask(ctx context.Context, arg UpsertTaskParams) error { - _, err := q.db.ExecContext(ctx, upsertTask, arg.ID, arg.Data, arg.SessionID) + _, err := q.db.Exec(ctx, upsertTask, arg.ID, arg.Data, arg.SessionID) return err } diff --git a/go/core/internal/database/gen/tools.sql.go b/go/core/internal/database/gen/tools.sql.go index 67be7412e..68bf4a69b 100644 --- a/go/core/internal/database/gen/tools.sql.go +++ b/go/core/internal/database/gen/tools.sql.go @@ -7,7 +7,7 @@ package dbgen import ( "context" - "database/sql" + "time" ) const getTool = `-- name: GetTool :one @@ -17,7 +17,7 @@ LIMIT 1 ` func (q *Queries) GetTool(ctx context.Context, id string) (Tool, error) { - row := q.db.QueryRowContext(ctx, 
getTool, id) + row := q.db.QueryRow(ctx, getTool, id) var i Tool err := row.Scan( &i.ID, @@ -38,7 +38,7 @@ LIMIT 1 ` func (q *Queries) GetToolServer(ctx context.Context, name string) (Toolserver, error) { - row := q.db.QueryRowContext(ctx, getToolServer, name) + row := q.db.QueryRow(ctx, getToolServer, name) var i Toolserver err := row.Scan( &i.Name, @@ -59,7 +59,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListToolServers(ctx context.Context) ([]Toolserver, error) { - rows, err := q.db.QueryContext(ctx, listToolServers) + rows, err := q.db.Query(ctx, listToolServers) if err != nil { return nil, err } @@ -80,9 +80,6 @@ func (q *Queries) ListToolServers(ctx context.Context) ([]Toolserver, error) { } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -96,7 +93,7 @@ ORDER BY created_at ASC ` func (q *Queries) ListTools(ctx context.Context) ([]Tool, error) { - rows, err := q.db.QueryContext(ctx, listTools) + rows, err := q.db.Query(ctx, listTools) if err != nil { return nil, err } @@ -117,9 +114,6 @@ func (q *Queries) ListTools(ctx context.Context) ([]Tool, error) { } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -138,7 +132,7 @@ type ListToolsForServerParams struct { } func (q *Queries) ListToolsForServer(ctx context.Context, arg ListToolsForServerParams) ([]Tool, error) { - rows, err := q.db.QueryContext(ctx, listToolsForServer, arg.ServerName, arg.GroupKind) + rows, err := q.db.Query(ctx, listToolsForServer, arg.ServerName, arg.GroupKind) if err != nil { return nil, err } @@ -159,9 +153,6 @@ func (q *Queries) ListToolsForServer(ctx context.Context, arg ListToolsForServer } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -179,7 +170,7 @@ type SoftDeleteToolServerParams struct { } 
func (q *Queries) SoftDeleteToolServer(ctx context.Context, arg SoftDeleteToolServerParams) error { - _, err := q.db.ExecContext(ctx, softDeleteToolServer, arg.Name, arg.GroupKind) + _, err := q.db.Exec(ctx, softDeleteToolServer, arg.Name, arg.GroupKind) return err } @@ -194,7 +185,7 @@ type SoftDeleteToolsForServerParams struct { } func (q *Queries) SoftDeleteToolsForServer(ctx context.Context, arg SoftDeleteToolsForServerParams) error { - _, err := q.db.ExecContext(ctx, softDeleteToolsForServer, arg.ServerName, arg.GroupKind) + _, err := q.db.Exec(ctx, softDeleteToolsForServer, arg.ServerName, arg.GroupKind) return err } @@ -211,11 +202,11 @@ type UpsertToolParams struct { ID string ServerName string GroupKind string - Description sql.NullString + Description *string } func (q *Queries) UpsertTool(ctx context.Context, arg UpsertToolParams) error { - _, err := q.db.ExecContext(ctx, upsertTool, + _, err := q.db.Exec(ctx, upsertTool, arg.ID, arg.ServerName, arg.GroupKind, @@ -238,12 +229,12 @@ RETURNING name, group_kind, created_at, updated_at, deleted_at, description, las type UpsertToolServerParams struct { Name string GroupKind string - Description sql.NullString - LastConnected sql.NullTime + Description *string + LastConnected *time.Time } func (q *Queries) UpsertToolServer(ctx context.Context, arg UpsertToolServerParams) (Toolserver, error) { - row := q.db.QueryRowContext(ctx, upsertToolServer, + row := q.db.QueryRow(ctx, upsertToolServer, arg.Name, arg.GroupKind, arg.Description, diff --git a/go/core/internal/database/sqlc.yaml b/go/core/internal/database/sqlc.yaml index bef4f26da..e9f1fe3cb 100644 --- a/go/core/internal/database/sqlc.yaml +++ b/go/core/internal/database/sqlc.yaml @@ -7,9 +7,21 @@ sql: go: package: "dbgen" out: "gen" + sql_package: "pgx/v5" emit_interface: true emit_pointers_for_null_types: true overrides: + # pgx/v5 native mode maps all timestamptz → pgtype.Timestamptz by default. 
+ # Override to keep idiomatic time.Time / *time.Time throughout. + - db_type: "timestamptz" + nullable: false + go_type: + type: "time.Time" + - db_type: "timestamptz" + nullable: true + go_type: + type: "time.Time" + pointer: true # Use domain types for columns that would otherwise be plain strings/bytes. - column: "agent.config" go_type: diff --git a/go/core/internal/database/testhelpers_test.go b/go/core/internal/database/testhelpers_test.go index cb7d6c56b..cd2c9ed1f 100644 --- a/go/core/internal/database/testhelpers_test.go +++ b/go/core/internal/database/testhelpers_test.go @@ -2,17 +2,17 @@ package database import ( "context" - "database/sql" "flag" "fmt" "os" "testing" + "github.com/jackc/pgx/v5/pgxpool" "github.com/kagent-dev/kagent/go/core/internal/dbtest" ) var ( - sharedDB *sql.DB + sharedDB *pgxpool.Pool sharedConnStr string ) @@ -34,7 +34,7 @@ func TestMain(m *testing.M) { os.Exit(1) } - db, err := Connect(context.Background(), &PostgresConfig{URL: connStr}) + db, err := Connect(context.Background(), &PostgresConfig{URL: connStr, VectorEnabled: true}) if err != nil { fmt.Fprintf(os.Stderr, "failed to connect to test database: %v\n", err) os.Exit(1) diff --git a/go/core/internal/database/upgrade_test.go b/go/core/internal/database/upgrade_test.go index 54d0e8ba4..208b4e29b 100644 --- a/go/core/internal/database/upgrade_test.go +++ b/go/core/internal/database/upgrade_test.go @@ -212,7 +212,7 @@ func TestUpgradeFromGORM(t *testing.T) { connStr := dbtest.StartT(ctx, t) // ── Step 1: apply the GORM-era schema ──────────────────────────────────── - rawDB, err := sql.Open("postgres", connStr) + rawDB, err := sql.Open("pgx", connStr) require.NoError(t, err) t.Cleanup(func() { rawDB.Close() }) diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index b6ea42f4f..71a448ed6 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -10,10 +10,10 @@ import ( "time" 
"github.com/golang-migrate/migrate/v4" - migratepg "github.com/golang-migrate/migrate/v4/database/postgres" + migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/jackc/pgx/v5/stdlib" "github.com/kagent-dev/kagent/go/core/pkg/migrations" - _ "github.com/lib/pq" testcontainers "github.com/testcontainers/testcontainers-go" tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" @@ -102,7 +102,7 @@ func MigrateDown(connStr string, vectorEnabled bool) error { } func runMigrationDir(connStr, dir, migrationsTable string) error { - db, err := sql.Open("postgres", connStr) + db, err := sql.Open("pgx", connStr) if err != nil { return fmt.Errorf("open db for %s: %w", dir, err) } @@ -112,7 +112,7 @@ func runMigrationDir(connStr, dir, migrationsTable string) error { return fmt.Errorf("load migration files from %s: %w", dir, err) } - driver, err := migratepg.WithInstance(db, &migratepg.Config{ + driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ MigrationsTable: migrationsTable, }) if err != nil { @@ -132,7 +132,7 @@ func runMigrationDir(connStr, dir, migrationsTable string) error { } func downMigrationDir(connStr, dir, migrationsTable string) error { - db, err := sql.Open("postgres", connStr) + db, err := sql.Open("pgx", connStr) if err != nil { return fmt.Errorf("open db for %s: %w", dir, err) } @@ -142,7 +142,7 @@ func downMigrationDir(connStr, dir, migrationsTable string) error { return fmt.Errorf("load migration files from %s: %w", dir, err) } - driver, err := migratepg.WithInstance(db, &migratepg.Config{ + driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ MigrationsTable: migrationsTable, }) if err != nil { diff --git a/go/core/internal/httpserver/middleware_error.go b/go/core/internal/httpserver/middleware_error.go index 8b1ba74ec..9860169c9 100644 --- 
a/go/core/internal/httpserver/middleware_error.go +++ b/go/core/internal/httpserver/middleware_error.go @@ -1,11 +1,11 @@ package httpserver import ( - "database/sql" "encoding/json" "errors" "net/http" + "github.com/jackc/pgx/v5" apierrors "github.com/kagent-dev/kagent/go/core/internal/httpserver/errors" "github.com/kagent-dev/kagent/go/core/internal/httpserver/handlers" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" @@ -56,7 +56,7 @@ func (w *errorResponseWriter) RespondWithError(err error) { } } - if !errors.Is(err, sql.ErrNoRows) { + if !errors.Is(err, pgx.ErrNoRows) { log.Error(err, message) } else { log.Info(message) diff --git a/go/go.mod b/go/go.mod index b002eb934..f45b37e8e 100644 --- a/go/go.mod +++ b/go/go.mod @@ -18,15 +18,15 @@ require ( github.com/fatih/color v1.18.0 github.com/go-logr/logr v1.4.3 github.com/go-logr/zapr v1.3.0 + github.com/golang-migrate/migrate/v4 v4.19.1 // api dependencies github.com/google/uuid v1.6.0 - github.com/golang-migrate/migrate/v4 v4.19.1 github.com/gorilla/mux v1.8.1 github.com/hashicorp/go-multierror v1.1.1 - github.com/lib/pq v1.11.2 github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/kagent-dev/kmcp v0.2.7 github.com/kagent-dev/mockllm v0.0.5 + github.com/lib/pq v1.11.2 // indirect github.com/modelcontextprotocol/go-sdk v1.4.1 github.com/muesli/reflow v0.3.0 github.com/openai/openai-go/v3 v3.26.0 @@ -59,6 +59,7 @@ require ( ) require ( + github.com/jackc/pgx/v5 v5.9.1 github.com/testcontainers/testcontainers-go v0.41.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.41.0 ) @@ -140,6 +141,10 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 
v2.2.2 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/go/go.sum b/go/go.sum index 606bde40f..52b661212 100644 --- a/go/go.sum +++ b/go/go.sum @@ -237,12 +237,14 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= -github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jedib0t/go-pretty/v6 v6.7.8 h1:BVYrDy5DPBA3Qn9ICT+PokP9cvCv1KaHv2i+Hc8sr5o= From 3421e26dcca2287d7bf403dfe8ab9a0893791b42 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Mon, 30 Mar 2026 
13:31:51 -0700 Subject: [PATCH 04/16] Simplify migrations by requiring kagent 0.8.0+ before upgrade Signed-off-by: Jeremy Alvis --- go/core/internal/database/upgrade_test.go | 12 ++++--- .../pkg/migrations/core/000001_initial.up.sql | 32 ++++++++++--------- .../core/000002_add_session_source.down.sql | 2 -- .../core/000002_add_session_source.up.sql | 6 ---- 4 files changed, 24 insertions(+), 28 deletions(-) delete mode 100644 go/core/pkg/migrations/core/000002_add_session_source.down.sql delete mode 100644 go/core/pkg/migrations/core/000002_add_session_source.up.sql diff --git a/go/core/internal/database/upgrade_test.go b/go/core/internal/database/upgrade_test.go index 208b4e29b..75e92ccd3 100644 --- a/go/core/internal/database/upgrade_test.go +++ b/go/core/internal/database/upgrade_test.go @@ -1,10 +1,12 @@ package database // TestUpgradeFromGORM validates that the golang-migrate migrations run cleanly -// against a database that was previously managed by GORM AutoMigrate, and that -// pre-existing data is accessible via the new sqlc client afterwards. +// against a database that was previously managed by GORM AutoMigrate (kagent +// v0.8.0), and that pre-existing data is accessible via the new sqlc client +// afterwards. v0.8.0 is the minimum required version before upgrading to this +// release. // -// It simulates an existing deployment by: +// It simulates an existing v0.8.0 deployment by: // 1. Creating the schema that GORM AutoMigrate would have produced (no migration // tracking tables, no gen_random_uuid() default on memory.id). // 2. Seeding representative rows, including soft-deleted CrewAI rows that GORM's @@ -25,8 +27,8 @@ import ( "github.com/stretchr/testify/require" ) -// gormSchema reproduces the DDL that GORM AutoMigrate emitted for the kagent -// models. Key differences from the current migrations: +// gormSchema reproduces the DDL that GORM AutoMigrate emitted for kagent v0.8.0. 
+// Key differences from the current migrations: // - No schema_migrations / vector_schema_migrations tracking tables. // - memory.id has no DEFAULT (GORM relied on the BeforeCreate hook). // - Indexes may have different names (GORM derives them from the struct name). diff --git a/go/core/pkg/migrations/core/000001_initial.up.sql b/go/core/pkg/migrations/core/000001_initial.up.sql index 193f076a4..b154ceedc 100644 --- a/go/core/pkg/migrations/core/000001_initial.up.sql +++ b/go/core/pkg/migrations/core/000001_initial.up.sql @@ -1,5 +1,5 @@ --- Baseline migration: matches the schema produced by GORM AutoMigrate in the --- prior kagent release. Each subsequent migration applies additive changes. +-- Baseline migration: matches the schema produced by GORM AutoMigrate as of +-- kagent v0.8.0. Upgrading to v0.8.0 before this version is required. CREATE TABLE IF NOT EXISTS agent ( id TEXT PRIMARY KEY, @@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS agent ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), deleted_at TIMESTAMPTZ, type TEXT NOT NULL, - config JSONB + config JSON ); CREATE INDEX IF NOT EXISTS idx_agent_deleted_at ON agent(deleted_at); @@ -19,11 +19,13 @@ CREATE TABLE IF NOT EXISTS session ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), deleted_at TIMESTAMPTZ, agent_id TEXT, + source TEXT, PRIMARY KEY (id, user_id) ); -CREATE INDEX IF NOT EXISTS idx_session_name ON session(name); -CREATE INDEX IF NOT EXISTS idx_session_agent_id ON session(agent_id); +CREATE INDEX IF NOT EXISTS idx_session_name ON session(name); +CREATE INDEX IF NOT EXISTS idx_session_agent_id ON session(agent_id); CREATE INDEX IF NOT EXISTS idx_session_deleted_at ON session(deleted_at); +CREATE INDEX IF NOT EXISTS idx_session_source ON session(source); CREATE TABLE IF NOT EXISTS event ( id TEXT NOT NULL, @@ -116,18 +118,18 @@ CREATE TABLE IF NOT EXISTS lg_checkpoint ( ); CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); CREATE INDEX IF NOT 
EXISTS idx_lgcp_list ON lg_checkpoint(created_at); -CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); +CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - write_idx INTEGER NOT NULL, - value TEXT NOT NULL, - value_type TEXT NOT NULL, - channel TEXT NOT NULL, - task_id TEXT NOT NULL, + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + write_idx INTEGER NOT NULL, + value TEXT NOT NULL, + value_type TEXT NOT NULL, + channel TEXT NOT NULL, + task_id TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), deleted_at TIMESTAMPTZ, diff --git a/go/core/pkg/migrations/core/000002_add_session_source.down.sql b/go/core/pkg/migrations/core/000002_add_session_source.down.sql deleted file mode 100644 index 0ef080534..000000000 --- a/go/core/pkg/migrations/core/000002_add_session_source.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -DROP INDEX IF EXISTS idx_session_source; -ALTER TABLE session DROP COLUMN IF EXISTS source; diff --git a/go/core/pkg/migrations/core/000002_add_session_source.up.sql b/go/core/pkg/migrations/core/000002_add_session_source.up.sql deleted file mode 100644 index ca940b2f3..000000000 --- a/go/core/pkg/migrations/core/000002_add_session_source.up.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Add session.source column, introduced after the initial GORM-managed schema. --- ALTER TABLE ... ADD COLUMN IF NOT EXISTS is idempotent: a no-op on fresh installs --- where migration 000001 created the table without this column, and adds the column --- on existing GORM-managed deployments upgrading to golang-migrate. 
-ALTER TABLE session ADD COLUMN IF NOT EXISTS source TEXT; -CREATE INDEX IF NOT EXISTS idx_session_source ON session(source); From 9951ef0d3aa35288db6e9b8d74aab31a8b5d0bb0 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Mon, 30 Mar 2026 14:40:44 -0700 Subject: [PATCH 05/16] Better matching of what GORM produces for initial migrations Signed-off-by: Jeremy Alvis --- go/core/internal/database/upgrade_test.go | 77 ++++++++++--------- .../pkg/migrations/core/000001_initial.up.sql | 53 +++++++------ .../vector/000001_vector_support.up.sql | 4 +- 3 files changed, 73 insertions(+), 61 deletions(-) diff --git a/go/core/internal/database/upgrade_test.go b/go/core/internal/database/upgrade_test.go index 75e92ccd3..8369c3c29 100644 --- a/go/core/internal/database/upgrade_test.go +++ b/go/core/internal/database/upgrade_test.go @@ -31,14 +31,16 @@ import ( // Key differences from the current migrations: // - No schema_migrations / vector_schema_migrations tracking tables. // - memory.id has no DEFAULT (GORM relied on the BeforeCreate hook). -// - Indexes may have different names (GORM derives them from the struct name). +// - created_at/updated_at are nullable: GORM sets them in Go code, not via a +// DB default or NOT NULL constraint. +// - version, write_idx, access_count are BIGINT: GORM maps Go `int` to bigint. 
const gormSchema = ` CREATE EXTENSION IF NOT EXISTS vector; CREATE TABLE IF NOT EXISTS agent ( id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, type TEXT NOT NULL, config JSON @@ -49,8 +51,8 @@ CREATE TABLE IF NOT EXISTS session ( id TEXT NOT NULL, user_id TEXT NOT NULL, name TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, agent_id TEXT, source TEXT, @@ -65,8 +67,8 @@ CREATE TABLE IF NOT EXISTS event ( id TEXT NOT NULL, user_id TEXT NOT NULL, session_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL, PRIMARY KEY (id, user_id) @@ -76,8 +78,8 @@ CREATE INDEX IF NOT EXISTS idx_event_deleted_at ON event(deleted_at); CREATE TABLE IF NOT EXISTS task ( id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL, session_id TEXT @@ -88,8 +90,8 @@ CREATE INDEX IF NOT EXISTS idx_task_deleted_at ON task(deleted_at); CREATE TABLE IF NOT EXISTS push_notification ( id TEXT PRIMARY KEY, task_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL ); @@ -109,13 +111,14 @@ CREATE TABLE IF NOT EXISTS feedback ( ); CREATE INDEX IF NOT EXISTS idx_feedback_deleted_at ON feedback(deleted_at); CREATE INDEX IF NOT EXISTS idx_feedback_user_id ON feedback(user_id); +CREATE INDEX IF NOT EXISTS idx_feedback_message_id ON feedback(message_id); CREATE TABLE IF NOT 
EXISTS tool ( id TEXT NOT NULL, server_name TEXT NOT NULL, group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, description TEXT, PRIMARY KEY (id, server_name, group_kind) @@ -125,8 +128,8 @@ CREATE INDEX IF NOT EXISTS idx_tool_deleted_at ON tool(deleted_at); CREATE TABLE IF NOT EXISTS toolserver ( name TEXT NOT NULL, group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, description TEXT, last_connected TIMESTAMPTZ, @@ -140,13 +143,13 @@ CREATE TABLE IF NOT EXISTS lg_checkpoint ( checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, metadata TEXT NOT NULL, checkpoint TEXT NOT NULL, checkpoint_type TEXT NOT NULL, - version INTEGER DEFAULT 1, + version BIGINT DEFAULT 1, PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id) ); CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); @@ -154,17 +157,17 @@ CREATE INDEX IF NOT EXISTS idx_lgcp_list ON lg_checkpoi CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - write_idx INTEGER NOT NULL, - value TEXT NOT NULL, - value_type TEXT NOT NULL, - channel TEXT NOT NULL, - task_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + user_id TEXT NOT NULL, + thread_id TEXT NOT NULL, + checkpoint_ns TEXT 
NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + write_idx BIGINT NOT NULL, + value TEXT NOT NULL, + value_type TEXT NOT NULL, + channel TEXT NOT NULL, + task_id TEXT NOT NULL, + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) ); @@ -173,23 +176,27 @@ CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_write_deleted_at ON lg_checkpoint_w CREATE TABLE IF NOT EXISTS crewai_agent_memory ( user_id TEXT NOT NULL, thread_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, memory_data TEXT NOT NULL, PRIMARY KEY (user_id, thread_id) ); +CREATE INDEX IF NOT EXISTS idx_crewai_memory_list ON crewai_agent_memory(created_at); +CREATE INDEX IF NOT EXISTS idx_crewai_agent_memory_deleted_at ON crewai_agent_memory(deleted_at); CREATE TABLE IF NOT EXISTS crewai_flow_state ( user_id TEXT NOT NULL, thread_id TEXT NOT NULL, method_name TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, state_data TEXT NOT NULL, PRIMARY KEY (user_id, thread_id, method_name) ); +CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_list ON crewai_flow_state(created_at); +CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_deleted_at ON crewai_flow_state(deleted_at); CREATE TABLE IF NOT EXISTS memory ( id TEXT PRIMARY KEY, @@ -198,9 +205,9 @@ CREATE TABLE IF NOT EXISTS memory ( content TEXT, embedding vector(768), metadata TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, expires_at TIMESTAMPTZ, - access_count INTEGER DEFAULT 0 + access_count BIGINT DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_memory_embedding_hnsw ON memory USING hnsw (embedding vector_cosine_ops); ` diff --git 
a/go/core/pkg/migrations/core/000001_initial.up.sql b/go/core/pkg/migrations/core/000001_initial.up.sql index b154ceedc..3b17c3f1c 100644 --- a/go/core/pkg/migrations/core/000001_initial.up.sql +++ b/go/core/pkg/migrations/core/000001_initial.up.sql @@ -1,10 +1,15 @@ -- Baseline migration: matches the schema produced by GORM AutoMigrate as of -- kagent v0.8.0. Upgrading to v0.8.0 before this version is required. +-- +-- Notes on column definitions vs. what you might expect: +-- - created_at/updated_at are nullable: GORM sets these in Go code, not via a +-- DB default or NOT NULL constraint. +-- - version, write_idx, access_count are BIGINT: GORM maps Go `int` to bigint. CREATE TABLE IF NOT EXISTS agent ( id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, type TEXT NOT NULL, config JSON @@ -15,8 +20,8 @@ CREATE TABLE IF NOT EXISTS session ( id TEXT NOT NULL, user_id TEXT NOT NULL, name TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, agent_id TEXT, source TEXT, @@ -31,8 +36,8 @@ CREATE TABLE IF NOT EXISTS event ( id TEXT NOT NULL, user_id TEXT NOT NULL, session_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL, PRIMARY KEY (id, user_id) @@ -42,8 +47,8 @@ CREATE INDEX IF NOT EXISTS idx_event_deleted_at ON event(deleted_at); CREATE TABLE IF NOT EXISTS task ( id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL, session_id TEXT @@ -54,8 +59,8 @@ CREATE INDEX IF NOT EXISTS idx_task_deleted_at 
ON task(deleted_at); CREATE TABLE IF NOT EXISTS push_notification ( id TEXT PRIMARY KEY, task_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, data TEXT NOT NULL ); @@ -81,8 +86,8 @@ CREATE TABLE IF NOT EXISTS tool ( id TEXT NOT NULL, server_name TEXT NOT NULL, group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, description TEXT, PRIMARY KEY (id, server_name, group_kind) @@ -92,8 +97,8 @@ CREATE INDEX IF NOT EXISTS idx_tool_deleted_at ON tool(deleted_at); CREATE TABLE IF NOT EXISTS toolserver ( name TEXT NOT NULL, group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, description TEXT, last_connected TIMESTAMPTZ, @@ -107,13 +112,13 @@ CREATE TABLE IF NOT EXISTS lg_checkpoint ( checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, metadata TEXT NOT NULL, checkpoint TEXT NOT NULL, checkpoint_type TEXT NOT NULL, - version INTEGER DEFAULT 1, + version BIGINT DEFAULT 1, PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id) ); CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); @@ -125,13 +130,13 @@ CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, - write_idx INTEGER NOT NULL, + write_idx BIGINT NOT NULL, value TEXT NOT NULL, value_type TEXT NOT NULL, channel TEXT NOT NULL, 
task_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) ); @@ -140,8 +145,8 @@ CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_write_deleted_at ON lg_checkpoint_w CREATE TABLE IF NOT EXISTS crewai_agent_memory ( user_id TEXT NOT NULL, thread_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, memory_data TEXT NOT NULL, PRIMARY KEY (user_id, thread_id) @@ -153,8 +158,8 @@ CREATE TABLE IF NOT EXISTS crewai_flow_state ( user_id TEXT NOT NULL, thread_id TEXT NOT NULL, method_name TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ, state_data TEXT NOT NULL, PRIMARY KEY (user_id, thread_id, method_name) diff --git a/go/core/pkg/migrations/vector/000001_vector_support.up.sql b/go/core/pkg/migrations/vector/000001_vector_support.up.sql index 97cc13d98..3b1e66234 100644 --- a/go/core/pkg/migrations/vector/000001_vector_support.up.sql +++ b/go/core/pkg/migrations/vector/000001_vector_support.up.sql @@ -9,9 +9,9 @@ CREATE TABLE IF NOT EXISTS memory ( content TEXT, embedding vector(768), metadata TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ, expires_at TIMESTAMPTZ, - access_count INTEGER DEFAULT 0 + access_count BIGINT DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_memory_agent_user ON memory(agent_name, user_id); CREATE INDEX IF NOT EXISTS idx_memory_expires_at ON memory(expires_at); From 81c041d8a8183c1b9f7458c7401ed99836afb745 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Mon, 30 Mar 2026 15:09:12 -0700 Subject: [PATCH 06/16] Remove github workflow check for 
migration changes on 0.7.x release branch Signed-off-by: Jeremy Alvis --- .github/workflows/migration-immutability.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/migration-immutability.yaml b/.github/workflows/migration-immutability.yaml index d22063906..72c9cc28f 100644 --- a/.github/workflows/migration-immutability.yaml +++ b/.github/workflows/migration-immutability.yaml @@ -2,7 +2,7 @@ name: Migration Immutability on: pull_request: - branches: [main, release/v0.7.x] + branches: [main] paths: - "go/core/pkg/migrations/**" From e9aa20df5e214248231179e954e28a5a51737184 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Tue, 31 Mar 2026 13:29:10 -0700 Subject: [PATCH 07/16] Add new db migrations skill info and add force command to migrator Signed-off-by: Jeremy Alvis --- .claude/skills/kagent-dev/SKILL.md | 1 + .../references/database-migrations.md | 123 ++++++++++++++++++ go/core/cmd/migrate/main.go | 38 +++++- 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 .claude/skills/kagent-dev/references/database-migrations.md diff --git a/.claude/skills/kagent-dev/SKILL.md b/.claude/skills/kagent-dev/SKILL.md index 04ece8be7..9901b646e 100644 --- a/.claude/skills/kagent-dev/SKILL.md +++ b/.claude/skills/kagent-dev/SKILL.md @@ -349,3 +349,4 @@ Don't use Go template syntax (`{{ }}`) in doc comments — Helm will try to pars - `references/translator-guide.md` - Translator patterns, `deployments.go` and `adk_api_translator.go` - `references/e2e-debugging.md` - Comprehensive E2E debugging, local reproduction - `references/ci-failures.md` - CI failure patterns and fixes +- `references/database-migrations.md` - Migration authoring rules, multi-instance safety, GORM baseline, expand/contract pattern diff --git a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md new file mode 100644 index 000000000..f2f5d09a7 --- /dev/null +++ 
b/.claude/skills/kagent-dev/references/database-migrations.md @@ -0,0 +1,123 @@ +# Database Migrations Guide + +kagent uses [golang-migrate](https://github.com/golang-migrate/migrate) with embedded SQL files. Migrations run as a Kubernetes **init container** (`kagent-migrate`) before the controller starts. + +## Structure + +``` +go/core/pkg/migrations/ +├── migrations.go # Embeds the FS (go:embed) +├── core/ # Core schema (tracked in schema_migrations table) +│ ├── 000001_initial.up.sql / .down.sql +│ ├── 000002_add_session_source.up.sql / .down.sql +│ └── ... +└── vector/ # pgvector schema (tracked in vector_schema_migrations table) + ├── 000001_vector_support.up.sql / .down.sql + └── ... +``` + +The `kagent-migrate` binary (in `go/core/cmd/migrate/`) runs `up` by default. It manages two independent tracks — `core` and `vector` — and rolls back both if either fails. + +## Writing Migrations + +### Version compatibility policy + +kagent supports **n-1 minor version** compatibility. Users must not skip minor versions when upgrading. This gives us a defined window for schema cleanup: + +- **Version N**: stop using the old column/table in application code; the schema still contains it (backward compatible with N-1) +- **Version N+1**: drop the old column/table (or N+2 for additional safety if rollback risk is high) + +Never migrate data and remove the old structure in the same migration — if the migration fails mid-way, rollback is much harder. Always separate the two steps across versions. + +### Backward-compatible schema changes (expand/contract) + +During a rolling deploy, old pods (running the previous code version) will be reading and writing a schema that has already been upgraded by the new pod's init container. **Every migration must be backward-compatible with the n-1 minor version's code.** Locking serializes concurrent migration runs but does nothing to protect old pods still running against the new schema. + +| Change | Old code behavior | Safe? 
| +|--------|------------------|-------| +| Add nullable column | SELECT ignores it; INSERT omits it (goes NULL) | ✅ | +| Add column with `DEFAULT x` | INSERT omits it; DB fills default | ✅ | +| Add NOT NULL column **without** default | Old INSERT missing the column → error | ❌ | +| Add index | Invisible to application code | ✅ | +| Add foreign key | Old INSERT may fail constraint | ❌ | +| Drop/rename column old code references | Old SELECT/INSERT errors | ❌ | +| Change compatible type (e.g. `int` → `bigint`) | Usually fine | ⚠️ | + +**Expand/contract pattern for destructive changes:** +1. **Version N (Expand)**: add the new column/table (nullable or with default); old code still works +2. **Version N (Deploy)**: ship new code that reads from the new structure, writes to both +3. **Version N+1 (Contract)**: drop the old column/table in a follow-on migration + +Never drop a column or rename a column in the same release as the code change that stops using it. + +### Naming + +Files must follow `NNNNNN_description.up.sql` / `NNNNNN_description.down.sql` with zero-padded 6-digit sequence numbers. + +### Down migrations + +Every `.up.sql` must have a corresponding `.down.sql` that exactly reverses it. Down migrations are used by the `kagent-migrate down --steps N --track core` command for rollbacks, and by automatic rollback on migration failure. They must be **idempotent** — the two-track rollback logic (roll back core if vector fails) may call them more than once in failure scenarios. + +## Multi-Instance Safety + +### How the advisory lock works + +golang-migrate acquires a PostgreSQL **session-level** advisory lock (`pg_advisory_lock`) before running. + +### Init container concurrency + +If multiple pods start simultaneously (e.g., rolling deploy with replicas > 1): +1. One init container acquires the advisory lock and runs migrations. +2. Others block on `pg_advisory_lock`. +3. 
When the winner finishes and its connection closes, the next waiter acquires the lock, calls `Up()`, gets `ErrNoChange`, and exits immediately. + +This is safe. The only risk is if the winning init container crashes mid-migration (see Dirty State below). + +### Dirty state recovery + +If `kagent-migrate` crashes mid-migration (OOMKill, pod eviction), golang-migrate records the version as `dirty = true` in the tracking table. The next run (after the advisory lock releases) will detect dirty state and call `rollbackToVersion`, which: +1. Calls `mg.Force(version - 1)` to clear the dirty flag. +2. Runs the down migration to restore the previous clean state. +3. Re-runs the failed up migration. + +**Requirement**: down migrations must be idempotent and correctly reverse their up migration. A missing or broken down migration requires manual recovery — see the `force` subcommand below. + + +### Rollout strategy + +For additive, backward-compatible migrations a rolling update is safe: + +1. New pod starts → `kagent-migrate up` runs (advisory lock serializes concurrent runs) +2. New pod passes readiness probe → old pod terminates +3. Backward-compatible schema means old pods continue operating during the window + +For a migration that is **not** backward-compatible, restructure it using expand/contract. + +## Running Migrations Locally + +```bash +# Apply all pending migrations +POSTGRES_DATABASE_URL="postgres://..." kagent-migrate up + +# Check current version on each track +POSTGRES_DATABASE_URL="..." kagent-migrate version + +# Roll back 1 step on core track +POSTGRES_DATABASE_URL="..." kagent-migrate down --steps 1 --track core + +# With vector support +KAGENT_DATABASE_VECTOR_ENABLED=true POSTGRES_DATABASE_URL="..." 
kagent-migrate up +``` + +### Manual dirty-state recovery + +If a migration was partially applied (dirty state), use `force` to reset to the last clean version before running `down`: + +```bash +# Force the tracking table to a specific version (clears dirty flag) +POSTGRES_DATABASE_URL="..." kagent-migrate force --version <N> --track core +# Then re-run up, or roll back: +POSTGRES_DATABASE_URL="..." kagent-migrate down --steps 1 --track core +``` + +In a Kubernetes deployment, the init container runs automatically on every pod start. diff --git a/go/core/cmd/migrate/main.go b/go/core/cmd/migrate/main.go index 3d4717aea..ffd26ce71 100644 --- a/go/core/cmd/migrate/main.go +++ b/go/core/cmd/migrate/main.go @@ -11,6 +11,7 @@ // up Apply all pending migrations (default when no command is given) // down Roll back N migrations on a single track // version Print the current applied version and dirty flag for each track +// force Force the tracking table to a specific version (clears dirty flag) // // Required environment variable: // @@ -67,8 +68,10 @@ func main() { runDownCommand(url, migrations.FS, vectorEnabled, args) case "version": runVersionCommand(url, migrations.FS, vectorEnabled) + case "force": + runForceCommand(url, migrations.FS, vectorEnabled, args) default: - log.Fatalf("kagent-migrate: unknown command %q (valid: up, down, version)", cmd) + log.Fatalf("kagent-migrate: unknown command %q (valid: up, down, version, force)", cmd) } } @@ -226,6 +229,39 @@ func runVersionCommand(url string, migrationsFS fs.FS, vectorEnabled bool) { } } +func runForceCommand(url string, migrationsFS fs.FS, vectorEnabled bool, args []string) { + forceFlags := flag.NewFlagSet("force", flag.ExitOnError) + version := forceFlags.Int("version", -1, "version to force (required; use -1 to clear all migration history)") + track := forceFlags.String("track", "core", "migration track: core or vector") + if err := forceFlags.Parse(args); err != nil { + log.Fatalf("kagent-migrate: force: %v", err) + } + 
var dir, table string + switch *track { + case "core": + dir, table = "core", "schema_migrations" + case "vector": + if !vectorEnabled { + log.Fatalf("kagent-migrate: force: track %q requested but KAGENT_DATABASE_VECTOR_ENABLED is not true", *track) + } + dir, table = "vector", "vector_schema_migrations" + default: + log.Fatalf("kagent-migrate: force: unknown track %q (valid: core, vector)", *track) + } + + mg, err := newMigrate(url, migrationsFS, dir, table) + if err != nil { + log.Fatalf("kagent-migrate: force: %v", err) + } + defer closeMigrate(dir, mg) + + if err := mg.Force(*version); err != nil { + log.Fatalf("kagent-migrate: force %s to version %d: %v", *track, *version, err) + } + log.Printf("kagent-migrate: forced %s track to version %d (dirty flag cleared)", *track, *version) +} + func resolveURL() (string, error) { if file := os.Getenv("POSTGRES_DATABASE_URL_FILE"); file != "" { content, err := os.ReadFile(file) From e4bd217af6189bd919ce5c1f469c1b870bfae481 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Tue, 31 Mar 2026 14:41:07 -0700 Subject: [PATCH 08/16] Remove temporary upgrade test Signed-off-by: Jeremy Alvis --- go/core/internal/database/upgrade_test.go | 367 ---------------------- 1 file changed, 367 deletions(-) delete mode 100644 go/core/internal/database/upgrade_test.go diff --git a/go/core/internal/database/upgrade_test.go b/go/core/internal/database/upgrade_test.go deleted file mode 100644 index 8369c3c29..000000000 --- a/go/core/internal/database/upgrade_test.go +++ /dev/null @@ -1,367 +0,0 @@ -package database - -// TestUpgradeFromGORM validates that the golang-migrate migrations run cleanly -// against a database that was previously managed by GORM AutoMigrate (kagent -// v0.8.0), and that pre-existing data is accessible via the new sqlc client -// afterwards. v0.8.0 is the minimum required version before upgrading to this -// release. -// -// It simulates an existing v0.8.0 deployment by: -// 1. 
Creating the schema that GORM AutoMigrate would have produced (no migration -// tracking tables, no gen_random_uuid() default on memory.id). -// 2. Seeding representative rows, including soft-deleted CrewAI rows that GORM's -// Delete() hook would have left behind. -// 3. Running the new golang-migrate migrations. -// 4. Verifying that all pre-existing data is readable and that new writes work. - -import ( - "context" - "database/sql" - "testing" - "time" - - dbpkg "github.com/kagent-dev/kagent/go/api/database" - "github.com/kagent-dev/kagent/go/core/internal/dbtest" - "github.com/pgvector/pgvector-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// gormSchema reproduces the DDL that GORM AutoMigrate emitted for kagent v0.8.0. -// Key differences from the current migrations: -// - No schema_migrations / vector_schema_migrations tracking tables. -// - memory.id has no DEFAULT (GORM relied on the BeforeCreate hook). -// - created_at/updated_at are nullable: GORM sets them in Go code, not via a -// DB default or NOT NULL constraint. -// - version, write_idx, access_count are BIGINT: GORM maps Go `int` to bigint. 
-const gormSchema = ` -CREATE EXTENSION IF NOT EXISTS vector; - -CREATE TABLE IF NOT EXISTS agent ( - id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - type TEXT NOT NULL, - config JSON -); -CREATE INDEX IF NOT EXISTS idx_agent_deleted_at ON agent(deleted_at); - -CREATE TABLE IF NOT EXISTS session ( - id TEXT NOT NULL, - user_id TEXT NOT NULL, - name TEXT, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - agent_id TEXT, - source TEXT, - PRIMARY KEY (id, user_id) -); -CREATE INDEX IF NOT EXISTS idx_session_name ON session(name); -CREATE INDEX IF NOT EXISTS idx_session_agent_id ON session(agent_id); -CREATE INDEX IF NOT EXISTS idx_session_deleted_at ON session(deleted_at); -CREATE INDEX IF NOT EXISTS idx_session_source ON session(source); - -CREATE TABLE IF NOT EXISTS event ( - id TEXT NOT NULL, - user_id TEXT NOT NULL, - session_id TEXT, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - data TEXT NOT NULL, - PRIMARY KEY (id, user_id) -); -CREATE INDEX IF NOT EXISTS idx_event_session_id ON event(session_id); -CREATE INDEX IF NOT EXISTS idx_event_deleted_at ON event(deleted_at); - -CREATE TABLE IF NOT EXISTS task ( - id TEXT PRIMARY KEY, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - data TEXT NOT NULL, - session_id TEXT -); -CREATE INDEX IF NOT EXISTS idx_task_session_id ON task(session_id); -CREATE INDEX IF NOT EXISTS idx_task_deleted_at ON task(deleted_at); - -CREATE TABLE IF NOT EXISTS push_notification ( - id TEXT PRIMARY KEY, - task_id TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - data TEXT NOT NULL -); -CREATE INDEX IF NOT EXISTS idx_push_notification_task_id ON push_notification(task_id); -CREATE INDEX IF NOT EXISTS idx_push_notification_deleted_at ON push_notification(deleted_at); - -CREATE TABLE IF NOT EXISTS feedback ( - id BIGSERIAL PRIMARY KEY, - created_at 
TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - user_id TEXT NOT NULL, - message_id BIGINT, - is_positive BOOLEAN DEFAULT false, - feedback_text TEXT NOT NULL, - issue_type TEXT -); -CREATE INDEX IF NOT EXISTS idx_feedback_deleted_at ON feedback(deleted_at); -CREATE INDEX IF NOT EXISTS idx_feedback_user_id ON feedback(user_id); -CREATE INDEX IF NOT EXISTS idx_feedback_message_id ON feedback(message_id); - -CREATE TABLE IF NOT EXISTS tool ( - id TEXT NOT NULL, - server_name TEXT NOT NULL, - group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - description TEXT, - PRIMARY KEY (id, server_name, group_kind) -); -CREATE INDEX IF NOT EXISTS idx_tool_deleted_at ON tool(deleted_at); - -CREATE TABLE IF NOT EXISTS toolserver ( - name TEXT NOT NULL, - group_kind TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - description TEXT, - last_connected TIMESTAMPTZ, - PRIMARY KEY (name, group_kind) -); -CREATE INDEX IF NOT EXISTS idx_toolserver_deleted_at ON toolserver(deleted_at); - -CREATE TABLE IF NOT EXISTS lg_checkpoint ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - parent_checkpoint_id TEXT, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - metadata TEXT NOT NULL, - checkpoint TEXT NOT NULL, - checkpoint_type TEXT NOT NULL, - version BIGINT DEFAULT 1, - PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id) -); -CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); -CREATE INDEX IF NOT EXISTS idx_lgcp_list ON lg_checkpoint(created_at); -CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_deleted_at ON lg_checkpoint(deleted_at); - -CREATE TABLE IF NOT EXISTS lg_checkpoint_write ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, 
- write_idx BIGINT NOT NULL, - value TEXT NOT NULL, - value_type TEXT NOT NULL, - channel TEXT NOT NULL, - task_id TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - PRIMARY KEY (user_id, thread_id, checkpoint_ns, checkpoint_id, write_idx) -); -CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_write_deleted_at ON lg_checkpoint_write(deleted_at); - -CREATE TABLE IF NOT EXISTS crewai_agent_memory ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - memory_data TEXT NOT NULL, - PRIMARY KEY (user_id, thread_id) -); -CREATE INDEX IF NOT EXISTS idx_crewai_memory_list ON crewai_agent_memory(created_at); -CREATE INDEX IF NOT EXISTS idx_crewai_agent_memory_deleted_at ON crewai_agent_memory(deleted_at); - -CREATE TABLE IF NOT EXISTS crewai_flow_state ( - user_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - method_name TEXT NOT NULL, - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - deleted_at TIMESTAMPTZ, - state_data TEXT NOT NULL, - PRIMARY KEY (user_id, thread_id, method_name) -); -CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_list ON crewai_flow_state(created_at); -CREATE INDEX IF NOT EXISTS idx_crewai_flow_state_deleted_at ON crewai_flow_state(deleted_at); - -CREATE TABLE IF NOT EXISTS memory ( - id TEXT PRIMARY KEY, - agent_name TEXT, - user_id TEXT, - content TEXT, - embedding vector(768), - metadata TEXT, - created_at TIMESTAMPTZ, - expires_at TIMESTAMPTZ, - access_count BIGINT DEFAULT 0 -); -CREATE INDEX IF NOT EXISTS idx_memory_embedding_hnsw ON memory USING hnsw (embedding vector_cosine_ops); -` - -func TestUpgradeFromGORM(t *testing.T) { - if testing.Short() { - t.Skip("skipping upgrade test in short mode") - } - - ctx := context.Background() - connStr := dbtest.StartT(ctx, t) - - // ── Step 1: apply the GORM-era schema ──────────────────────────────────── - rawDB, err := sql.Open("pgx", connStr) - require.NoError(t, err) - 
t.Cleanup(func() { rawDB.Close() }) - - _, err = rawDB.ExecContext(ctx, gormSchema) - require.NoError(t, err, "GORM schema setup failed") - - // ── Step 2: seed pre-migration data ────────────────────────────────────── - now := time.Now().UTC().Truncate(time.Millisecond) - softDeleted := now.Add(-24 * time.Hour) - - seeds := []struct { - name string - query string - args []any - }{ - { - "agent", - `INSERT INTO agent (id, type, created_at, updated_at) VALUES ($1, $2, $3, $4)`, - []any{"agent-1", "autogen", now, now}, - }, - { - "session", - `INSERT INTO session (id, user_id, name, agent_id, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, - []any{"session-1", "user-1", "test session", "agent-1", now, now}, - }, - { - "event", - `INSERT INTO event (id, user_id, session_id, data, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, - []any{"event-1", "user-1", "session-1", `{"role":"user"}`, now, now}, - }, - { - "task", - `INSERT INTO task (id, session_id, data, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)`, - []any{"task-1", "session-1", `{"id":"task-1"}`, now, now}, - }, - { - "toolserver", - `INSERT INTO toolserver (name, group_kind, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)`, - []any{"server-1", "MCPServer.kagent.dev", "test server", now, now}, - }, - { - "tool", - `INSERT INTO tool (id, server_name, group_kind, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`, - []any{"tool-1", "server-1", "MCPServer.kagent.dev", "a tool", now, now}, - }, - // Soft-deleted CrewAI memory row — simulates GORM's Delete() behaviour. - // After migration the upsert must revive it (deleted_at = NULL). 
- { - "crewai_agent_memory (soft-deleted)", - `INSERT INTO crewai_agent_memory (user_id, thread_id, memory_data, created_at, updated_at, deleted_at) VALUES ($1, $2, $3, $4, $5, $6)`, - []any{"user-1", "thread-1", `{"task_description":"old task"}`, now, now, softDeleted}, - }, - // Soft-deleted CrewAI flow state row — same scenario. - { - "crewai_flow_state (soft-deleted)", - `INSERT INTO crewai_flow_state (user_id, thread_id, method_name, state_data, created_at, updated_at, deleted_at) VALUES ($1, $2, $3, $4, $5, $6, $7)`, - []any{"user-1", "thread-1", "kickoff", `{"status":"done"}`, now, now, softDeleted}, - }, - // Memory row with a manually supplied ID (old GORM BeforeCreate behaviour). - { - "memory", - `INSERT INTO memory (id, agent_name, user_id, content, embedding, metadata, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7)`, - []any{"mem-1", "agent-1", "user-1", "hello world", pgvector.NewVector(make([]float32, 768)), "{}", now}, - }, - } - - for _, s := range seeds { - _, err := rawDB.ExecContext(ctx, s.query, s.args...) 
- require.NoError(t, err, "seeding %s failed", s.name) - } - - // ── Step 3: run the new migrations ─────────────────────────────────────── - dbtest.MigrateT(t, connStr, true) - - // ── Step 4: connect via the new client ─────────────────────────────────── - db, err := Connect(ctx, &PostgresConfig{URL: connStr}) - require.NoError(t, err) - client := NewClient(db) - - // ── Step 5: verify pre-existing data is readable ───────────────────────── - - agent, err := client.GetAgent(ctx, "agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", agent.ID) - assert.Equal(t, "autogen", agent.Type) - - session, err := client.GetSession(ctx, "session-1", "user-1") - require.NoError(t, err) - assert.Equal(t, "session-1", session.ID) - assert.Equal(t, "agent-1", *session.AgentID) - - events, err := client.ListEventsForSession(ctx, "session-1", "user-1", dbpkg.QueryOptions{}) - require.NoError(t, err) - require.Len(t, events, 1) - assert.Equal(t, "event-1", events[0].ID) - - tasks, err := client.ListTasksForSession(ctx, "session-1") - require.NoError(t, err) - require.Len(t, tasks, 1) - - toolServer, err := client.GetToolServer(ctx, "server-1") - require.NoError(t, err) - assert.Equal(t, "server-1", toolServer.Name) - - tools, err := client.ListToolsForServer(ctx, "server-1", "MCPServer.kagent.dev") - require.NoError(t, err) - require.Len(t, tools, 1) - assert.Equal(t, "tool-1", tools[0].ID) - - // ── Step 6: verify soft-deleted CrewAI rows are revived by upsert ──────── - // Before upsert both rows are invisible (deleted_at IS NOT NULL). 
- results, err := client.SearchCrewAIMemoryByTask(ctx, "user-1", "thread-1", "old task", 10) - require.NoError(t, err) - assert.Empty(t, results, "soft-deleted memory should be invisible before upsert") - - err = client.StoreCrewAIMemory(ctx, &dbpkg.CrewAIAgentMemory{ - UserID: "user-1", - ThreadID: "thread-1", - MemoryData: `{"task_description":"old task"}`, - }) - require.NoError(t, err) - - results, err = client.SearchCrewAIMemoryByTask(ctx, "user-1", "thread-1", "old task", 10) - require.NoError(t, err) - assert.Len(t, results, 1, "upsert should revive soft-deleted memory row") - - // ── Step 7: verify new writes work (gen_random_uuid() default) ─────────── - embedding := pgvector.NewVector(make([]float32, 768)) - mem := &dbpkg.Memory{ - AgentName: "agent-1", - UserID: "user-1", - Content: "new memory content", - Embedding: embedding, - Metadata: "{}", - } - err = client.StoreAgentMemory(ctx, mem) - require.NoError(t, err) - assert.NotEmpty(t, mem.ID, "StoreAgentMemory should populate ID via gen_random_uuid()") - - memories, err := client.ListAgentMemories(ctx, "agent-1", "user-1") - require.NoError(t, err) - assert.Len(t, memories, 2, "should see the seeded memory row and the new one") -} From 7034b1c4d570af83eefb4d7ea9b90138809908f2 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Wed, 1 Apr 2026 09:56:54 -0700 Subject: [PATCH 09/16] Switch to using in-app migration with extension Signed-off-by: Jeremy Alvis --- .github/workflows/tag.yaml | 1 - Makefile | 10 +- go/adk/pkg/a2a/converter_test.go | 1 + go/adk/pkg/a2a/hitl_test.go | 3 + go/adk/pkg/config/config_loader_test.go | 2 + go/core/cmd/controller/main.go | 2 +- go/core/cmd/migrate/main.go | 347 ------------------ go/core/internal/database/connect.go | 10 + go/core/internal/dbtest/dbtest.go | 93 +---- go/core/pkg/app/app.go | 37 +- go/core/pkg/migrations/runner.go | 218 +++++++++++ .../migrations/runner_test.go} | 56 ++- .../templates/controller-deployment.yaml | 30 -- helm/kagent/values.yaml | 6 - 14 
files changed, 319 insertions(+), 497 deletions(-) delete mode 100644 go/core/cmd/migrate/main.go create mode 100644 go/core/pkg/migrations/runner.go rename go/core/{cmd/migrate/main_test.go => pkg/migrations/runner_test.go} (84%) diff --git a/.github/workflows/tag.yaml b/.github/workflows/tag.yaml index 443f618f4..3eccd3787 100644 --- a/.github/workflows/tag.yaml +++ b/.github/workflows/tag.yaml @@ -18,7 +18,6 @@ jobs: matrix: image: - controller - - migrate - ui - app - golang-adk diff --git a/Makefile b/Makefile index d6066b583..747fac545 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,6 @@ APP_IMAGE_NAME ?= app KAGENT_ADK_IMAGE_NAME ?= kagent-adk GOLANG_ADK_IMAGE_NAME ?= golang-adk SKILLS_INIT_IMAGE_NAME ?= skills-init -MIGRATE_IMAGE_NAME ?= migrate CONTROLLER_IMAGE_TAG ?= $(VERSION) UI_IMAGE_TAG ?= $(VERSION) @@ -48,15 +47,12 @@ APP_IMAGE_TAG ?= $(VERSION) KAGENT_ADK_IMAGE_TAG ?= $(VERSION) GOLANG_ADK_IMAGE_TAG ?= $(VERSION) SKILLS_INIT_IMAGE_TAG ?= $(VERSION) -MIGRATE_IMAGE_TAG ?= $(VERSION) - CONTROLLER_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(CONTROLLER_IMAGE_NAME):$(CONTROLLER_IMAGE_TAG) UI_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(UI_IMAGE_NAME):$(UI_IMAGE_TAG) APP_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(APP_IMAGE_NAME):$(APP_IMAGE_TAG) KAGENT_ADK_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(KAGENT_ADK_IMAGE_NAME):$(KAGENT_ADK_IMAGE_TAG) GOLANG_ADK_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(GOLANG_ADK_IMAGE_NAME):$(GOLANG_ADK_IMAGE_TAG) SKILLS_INIT_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(SKILLS_INIT_IMAGE_NAME):$(SKILLS_INIT_IMAGE_TAG) -MIGRATE_IMG ?= $(DOCKER_REGISTRY)/$(DOCKER_REPO)/$(MIGRATE_IMAGE_NAME):$(MIGRATE_IMAGE_TAG) #take from go/go.mod AWK ?= $(shell command -v gawk || command -v awk) @@ -222,7 +218,7 @@ prune-docker-images: docker images --filter dangling=true -q | xargs -r docker rmi || : .PHONY: build -build: buildx-create build-controller build-migrate build-ui build-app build-golang-adk build-skills-init +build: buildx-create 
build-controller build-ui build-app build-golang-adk build-skills-init @echo "Build completed successfully." @echo "Controller Image: $(CONTROLLER_IMG)" @echo "UI Image: $(UI_IMG)" @@ -270,10 +266,6 @@ controller-manifests: build-controller: buildx-create controller-manifests $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) --build-arg BUILD_PACKAGE=core/cmd/controller/main.go -t $(CONTROLLER_IMG) -f go/Dockerfile ./go -.PHONY: build-migrate -build-migrate: buildx-create - $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) --build-arg BUILD_PACKAGE=core/cmd/migrate/main.go -t $(MIGRATE_IMG) -f go/Dockerfile ./go - .PHONY: build-ui build-ui: buildx-create $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -t $(UI_IMG) -f ui/Dockerfile ./ui diff --git a/go/adk/pkg/a2a/converter_test.go b/go/adk/pkg/a2a/converter_test.go index 18aea2606..43df60098 100644 --- a/go/adk/pkg/a2a/converter_test.go +++ b/go/adk/pkg/a2a/converter_test.go @@ -125,6 +125,7 @@ func TestMessageToGenAIContent_TextPart(t *testing.T) { } if content == nil { t.Fatal("expected non-nil content") + return } if len(content.Parts) != 1 { t.Fatalf("expected 1 part, got %d", len(content.Parts)) diff --git a/go/adk/pkg/a2a/hitl_test.go b/go/adk/pkg/a2a/hitl_test.go index 2a0ee4fbf..1a7416bf1 100644 --- a/go/adk/pkg/a2a/hitl_test.go +++ b/go/adk/pkg/a2a/hitl_test.go @@ -481,6 +481,7 @@ func TestBuildResumeHITLMessage(t *testing.T) { resume := BuildResumeHITLMessage(storedTask, incoming) if resume == nil { t.Fatal("BuildResumeHITLMessage() returned nil") + return } if len(resume.Parts) != 1 { t.Fatalf("resume parts len = %d, want 1", len(resume.Parts)) @@ -488,6 +489,7 @@ func TestBuildResumeHITLMessage(t *testing.T) { dp := asDataPart(resume.Parts[0]) if dp == nil { t.Fatal("resume part is not a DataPart") + return } if dp.Data[PartKeyName] != "adk_request_confirmation" { t.Fatalf("resume FunctionResponse name = %#v", dp.Data[PartKeyName]) @@ 
-516,6 +518,7 @@ func TestProcessHitlDecision(t *testing.T) { dp := asDataPart(parts[0]) if dp == nil { t.Fatal("part is not DataPart") + return } if dp.Data[PartKeyName] != "adk_request_confirmation" { t.Errorf("name = %v", dp.Data[PartKeyName]) diff --git a/go/adk/pkg/config/config_loader_test.go b/go/adk/pkg/config/config_loader_test.go index c081f5b7f..101f13f50 100644 --- a/go/adk/pkg/config/config_loader_test.go +++ b/go/adk/pkg/config/config_loader_test.go @@ -45,6 +45,7 @@ func TestLoadAgentConfig(t *testing.T) { if config == nil { t.Fatal("LoadAgentConfig() returned nil config") + return } // Check that model was loaded @@ -92,6 +93,7 @@ func TestLoadAgentCard(t *testing.T) { if card == nil { t.Fatal("LoadAgentCard() returned nil card") + return } if card.Name != "test-agent" { diff --git a/go/core/cmd/controller/main.go b/go/core/cmd/controller/main.go index 7fea43458..561f54665 100644 --- a/go/core/cmd/controller/main.go +++ b/go/core/cmd/controller/main.go @@ -35,5 +35,5 @@ func main() { Authorizer: authorizer, AgentPlugins: nil, }, nil - }) + }, nil) } diff --git a/go/core/cmd/migrate/main.go b/go/core/cmd/migrate/main.go deleted file mode 100644 index ffd26ce71..000000000 --- a/go/core/cmd/migrate/main.go +++ /dev/null @@ -1,347 +0,0 @@ -// kagent-migrate runs Postgres schema migrations and exits. -// It is intended to run as a Kubernetes init container before the kagent -// controller starts, ensuring the schema is up to date before the app connects. 
-// -// Usage: -// -// kagent-migrate [command] -// -// Commands: -// -// up Apply all pending migrations (default when no command is given) -// down Roll back N migrations on a single track -// version Print the current applied version and dirty flag for each track -// force Force the tracking table to a specific version (clears dirty flag) -// -// Required environment variable: -// -// POSTGRES_DATABASE_URL — Postgres connection URL -// -// Optional environment variables: -// -// POSTGRES_DATABASE_URL_FILE — path to a file containing the URL (takes precedence) -// KAGENT_DATABASE_VECTOR_ENABLED — set to "true" to also run vector migrations -// -// Enterprise extension: replace this binary with enterprise-migrate, which imports -// go/core/pkg/migrations.FS directly via the OSS Go module dependency and adds its -// own migration passes alongside it at compile time. -package main - -import ( - "database/sql" - "errors" - "flag" - "fmt" - "io/fs" - "log" - "os" - "strings" - - "github.com/golang-migrate/migrate/v4" - migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" - "github.com/golang-migrate/migrate/v4/source/iofs" - _ "github.com/jackc/pgx/v5/stdlib" - "github.com/kagent-dev/kagent/go/core/pkg/migrations" -) - -func main() { - flag.Parse() - - url, err := resolveURL() - if err != nil { - log.Fatalf("kagent-migrate: %v", err) - } - - vectorEnabled := strings.EqualFold(os.Getenv("KAGENT_DATABASE_VECTOR_ENABLED"), "true") - - cmd := "up" - args := flag.Args() - if len(args) > 0 { - cmd = args[0] - args = args[1:] - } - - switch cmd { - case "up": - runUpCommand(url, migrations.FS, vectorEnabled) - case "down": - runDownCommand(url, migrations.FS, vectorEnabled, args) - case "version": - runVersionCommand(url, migrations.FS, vectorEnabled) - case "force": - runForceCommand(url, migrations.FS, vectorEnabled, args) - default: - log.Fatalf("kagent-migrate: unknown command %q (valid: up, down, version, force)", cmd) - } -} - -func runUpCommand(url 
string, migrationsFS fs.FS, vectorEnabled bool) { - corePrev, err := applyDir(url, migrationsFS, "core", "schema_migrations") - if err != nil { - log.Fatalf("kagent-migrate: core migrations: %v", err) - } - log.Println("kagent-migrate: core migrations applied") - - if vectorEnabled { - if _, err := applyDir(url, migrationsFS, "vector", "vector_schema_migrations"); err != nil { - // Vector failed (and already rolled itself back). Roll back core too - // since both tracks are treated as one unit. - log.Printf("kagent-migrate: rolling back core to version %d", corePrev) - rollbackDir(url, migrationsFS, "core", "schema_migrations", corePrev) - log.Fatalf("kagent-migrate: vector migrations: %v", err) - } - log.Println("kagent-migrate: vector migrations applied") - } - - log.Println("kagent-migrate: done") -} - -// applyDir runs Up for dir and rolls back on failure. It returns the pre-run -// version so the caller can roll back this track if a later track fails. -func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prevVersion uint, err error) { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return 0, err - } - defer closeMigrate(dir, mg) - - prevVersion, _, err = mg.Version() - if err != nil && !errors.Is(err, migrate.ErrNilVersion) { - return 0, fmt.Errorf("get pre-migration version for %s: %w", dir, err) - } - // prevVersion == 0 when ErrNilVersion (no migrations applied yet). 
- - if upErr := mg.Up(); upErr != nil { - if errors.Is(upErr, migrate.ErrNoChange) { - return prevVersion, nil - } - log.Printf("kagent-migrate: migration failed for %s, attempting rollback to version %d", dir, prevVersion) - if rbErr := rollbackToVersion(mg, dir, prevVersion); rbErr != nil { - log.Printf("kagent-migrate: rollback failed for %s: %v", dir, rbErr) - } else { - log.Printf("kagent-migrate: rolled back %s to version %d", dir, prevVersion) - } - return prevVersion, fmt.Errorf("run migrations for %s: %w", dir, upErr) - } - return prevVersion, nil -} - -// rollbackDir opens a fresh migrate instance and rolls dir back to targetVersion. -// Used to roll back a previously-succeeded track when a later track fails. -func rollbackDir(url string, migrationsFS fs.FS, dir, migrationsTable string, targetVersion uint) { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - log.Printf("kagent-migrate: rollback of %s failed (open): %v", dir, err) - return - } - defer closeMigrate(dir, mg) - if err := rollbackToVersion(mg, dir, targetVersion); err != nil { - log.Printf("kagent-migrate: rollback of %s failed: %v", dir, err) - } else { - log.Printf("kagent-migrate: rolled back %s to version %d", dir, targetVersion) - } -} - -// rollbackToVersion rolls the migration state back to targetVersion. -// It handles the dirty-state cleanup golang-migrate requires after a failed -// Up run before down steps can be applied. -func rollbackToVersion(mg *migrate.Migrate, dir string, targetVersion uint) error { - currentVersion, dirty, err := mg.Version() - if err != nil { - if errors.Is(err, migrate.ErrNilVersion) { - return nil // nothing was applied; nothing to roll back - } - return fmt.Errorf("get version after failure for %s: %w", dir, err) - } - - if dirty { - // The failed migration is recorded as dirty at currentVersion. - // Force to the last clean version so Steps can run. 
- cleanVersion := int(currentVersion) - 1 - forceTarget := cleanVersion - if forceTarget < 1 { - forceTarget = -1 // negative tells golang-migrate to remove the version record entirely - } - if err := mg.Force(forceTarget); err != nil { - return fmt.Errorf("clear dirty state for %s: %w", dir, err) - } - if forceTarget < 0 { - return nil // first migration failed and was cleared; nothing left to roll back - } - currentVersion = uint(cleanVersion) - } - - steps := int(currentVersion) - int(targetVersion) - if steps <= 0 { - return nil - } - if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("roll back %d step(s) for %s: %w", steps, dir, err) - } - return nil -} - -func runDownCommand(url string, migrationsFS fs.FS, vectorEnabled bool, args []string) { - downFlags := flag.NewFlagSet("down", flag.ExitOnError) - steps := downFlags.Int("steps", 0, "number of down migrations to run (required, must be > 0)") - track := downFlags.String("track", "core", "migration track to roll back: core or vector") - if err := downFlags.Parse(args); err != nil { - log.Fatalf("kagent-migrate: down: %v", err) - } - - if *steps <= 0 { - log.Fatalf("kagent-migrate: down: --steps must be a positive integer") - } - - var dir, table string - switch *track { - case "core": - dir, table = "core", "schema_migrations" - case "vector": - if !vectorEnabled { - log.Fatalf("kagent-migrate: down: track %q requested but KAGENT_DATABASE_VECTOR_ENABLED is not true", *track) - } - dir, table = "vector", "vector_schema_migrations" - default: - log.Fatalf("kagent-migrate: down: unknown track %q (valid: core, vector)", *track) - } - - if err := downDir(url, migrationsFS, dir, table, *steps); err != nil { - log.Fatalf("kagent-migrate: down %s (%d steps): %v", *track, *steps, err) - } - log.Printf("kagent-migrate: rolled back %d migration(s) on %s track", *steps, *track) -} - -func runVersionCommand(url string, migrationsFS fs.FS, vectorEnabled bool) { - tracks := 
[]struct{ dir, table string }{ - {"core", "schema_migrations"}, - } - if vectorEnabled { - tracks = append(tracks, struct{ dir, table string }{"vector", "vector_schema_migrations"}) - } - - for _, t := range tracks { - version, dirty, err := versionDir(url, migrationsFS, t.dir, t.table) - if err != nil { - log.Fatalf("kagent-migrate: version %s: %v", t.dir, err) - } - log.Printf("kagent-migrate: track=%-6s table=%-30s version=%d dirty=%v", t.dir, t.table, version, dirty) - } -} - -func runForceCommand(url string, migrationsFS fs.FS, vectorEnabled bool, args []string) { - forceFlags := flag.NewFlagSet("force", flag.ExitOnError) - version := forceFlags.Int("version", -1, "version to force (required; use -1 to clear all migration history)") - track := forceFlags.String("track", "core", "migration track: core or vector") - if err := forceFlags.Parse(args); err != nil { - log.Fatalf("kagent-migrate: force: %v", err) - } - - var dir, table string - switch *track { - case "core": - dir, table = "core", "schema_migrations" - case "vector": - if !vectorEnabled { - log.Fatalf("kagent-migrate: force: track %q requested but KAGENT_DATABASE_VECTOR_ENABLED is not true", *track) - } - dir, table = "vector", "vector_schema_migrations" - default: - log.Fatalf("kagent-migrate: force: unknown track %q (valid: core, vector)", *track) - } - - mg, err := newMigrate(url, migrationsFS, dir, table) - if err != nil { - log.Fatalf("kagent-migrate: force: %v", err) - } - defer closeMigrate(dir, mg) - - if err := mg.Force(*version); err != nil { - log.Fatalf("kagent-migrate: force %s to version %d: %v", *track, *version, err) - } - log.Printf("kagent-migrate: forced %s track to version %d (dirty flag cleared)", *track, *version) -} - -func resolveURL() (string, error) { - if file := os.Getenv("POSTGRES_DATABASE_URL_FILE"); file != "" { - content, err := os.ReadFile(file) - if err != nil { - return "", fmt.Errorf("reading URL file %s: %w", file, err) - } - url := 
strings.TrimSpace(string(content)) - if url == "" { - return "", fmt.Errorf("URL file %s is empty", file) - } - return url, nil - } - url := os.Getenv("POSTGRES_DATABASE_URL") - if url == "" { - return "", fmt.Errorf("POSTGRES_DATABASE_URL must be set") - } - return url, nil -} - -// newMigrate opens a database connection and constructs a migrate.Migrate for the given dir/table. -// The caller is responsible for calling closeMigrate on the returned instance. -func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*migrate.Migrate, error) { - db, err := sql.Open("pgx", url) - if err != nil { - return nil, fmt.Errorf("open database for %s: %w", dir, err) - } - - src, err := iofs.New(migrationsFS, dir) - if err != nil { - return nil, fmt.Errorf("load migration files from %s: %w", dir, err) - } - - driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ - MigrationsTable: migrationsTable, - }) - if err != nil { - return nil, fmt.Errorf("create migration driver for %s: %w", dir, err) - } - - mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) - if err != nil { - return nil, fmt.Errorf("create migrator for %s: %w", dir, err) - } - return mg, nil -} - -func downDir(url string, migrationsFS fs.FS, dir, migrationsTable string, steps int) error { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return err - } - defer closeMigrate(dir, mg) - - if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("roll back %d migration(s) for %s: %w", steps, dir, err) - } - return nil -} - -func versionDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (version uint, dirty bool, err error) { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return 0, false, err - } - defer closeMigrate(dir, mg) - - version, dirty, err = mg.Version() - if err != nil && !errors.Is(err, migrate.ErrNilVersion) { - return 0, false, 
fmt.Errorf("get version for %s: %w", dir, err) - } - return version, dirty, nil -} - -// closeMigrate closes mg, logging source and database close errors separately. -func closeMigrate(dir string, mg *migrate.Migrate) { - srcErr, dbErr := mg.Close() - if srcErr != nil { - log.Printf("warning: closing migration source for %s: %v", dir, srcErr) - } - if dbErr != nil { - log.Printf("warning: closing migration database for %s: %v", dir, dbErr) - } -} diff --git a/go/core/internal/database/connect.go b/go/core/internal/database/connect.go index 8fd3e2ee0..29fa1a9df 100644 --- a/go/core/internal/database/connect.go +++ b/go/core/internal/database/connect.go @@ -84,6 +84,16 @@ func retryDBConnection(ctx context.Context, url string, vectorEnabled bool) (*pg } } +// ResolveURL returns url, unless urlFile is non-empty in which case the URL is +// read from that file. Used by callers (e.g. the migration runner) that need +// the resolved connection string before a pool is created. +func ResolveURL(url, urlFile string) (string, error) { + if urlFile != "" { + return resolveURLFile(urlFile) + } + return url, nil +} + // resolveURLFile reads a database connection URL from a file and returns the // trimmed contents. Returns an error if the file cannot be read or is empty. 
func resolveURLFile(path string) (string, error) { diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index 71a448ed6..b5a38409c 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -3,16 +3,10 @@ package dbtest import ( "context" - "database/sql" - "errors" "fmt" "testing" "time" - "github.com/golang-migrate/migrate/v4" - migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" - "github.com/golang-migrate/migrate/v4/source/iofs" - _ "github.com/jackc/pgx/v5/stdlib" "github.com/kagent-dev/kagent/go/core/pkg/migrations" testcontainers "github.com/testcontainers/testcontainers-go" tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" @@ -70,15 +64,7 @@ func StartT(ctx context.Context, t *testing.T) string { // If vectorEnabled is true the vector pass is also applied. // Use MigrateT in tests that have a *testing.T; use Migrate in TestMain where no T is available. func Migrate(connStr string, vectorEnabled bool) error { - if err := runMigrationDir(connStr, "core", "schema_migrations"); err != nil { - return fmt.Errorf("core migrations: %w", err) - } - if vectorEnabled { - if err := runMigrationDir(connStr, "vector", "vector_schema_migrations"); err != nil { - return fmt.Errorf("vector migrations: %w", err) - } - } - return nil + return migrations.RunUp(context.Background(), connStr, migrations.FS, vectorEnabled) } // MigrateT runs the embedded OSS migrations against connStr and calls t.Fatal on error. @@ -90,84 +76,13 @@ func MigrateT(t *testing.T, connStr string, vectorEnabled bool) { } } -// MigrateDown runs the embedded OSS down-migrations against connStr and returns any error. +// MigrateDown rolls back all OSS migrations against connStr and returns any error. // If vectorEnabled is true the vector pass is also rolled back first. 
func MigrateDown(connStr string, vectorEnabled bool) error { if vectorEnabled { - if err := downMigrationDir(connStr, "vector", "vector_schema_migrations"); err != nil { + if err := migrations.RunDownAll(connStr, migrations.FS, "vector", "vector_schema_migrations"); err != nil { return fmt.Errorf("vector down migrations: %w", err) } } - return downMigrationDir(connStr, "core", "schema_migrations") -} - -func runMigrationDir(connStr, dir, migrationsTable string) error { - db, err := sql.Open("pgx", connStr) - if err != nil { - return fmt.Errorf("open db for %s: %w", dir, err) - } - - src, err := iofs.New(migrations.FS, dir) - if err != nil { - return fmt.Errorf("load migration files from %s: %w", dir, err) - } - - driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ - MigrationsTable: migrationsTable, - }) - if err != nil { - return fmt.Errorf("create migration driver for %s: %w", dir, err) - } - - mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) - if err != nil { - return fmt.Errorf("create migrator for %s: %w", dir, err) - } - defer closeMigrate(dir, mg) - - if err := mg.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("run migrations for %s: %w", dir, err) - } - return nil -} - -func downMigrationDir(connStr, dir, migrationsTable string) error { - db, err := sql.Open("pgx", connStr) - if err != nil { - return fmt.Errorf("open db for %s: %w", dir, err) - } - - src, err := iofs.New(migrations.FS, dir) - if err != nil { - return fmt.Errorf("load migration files from %s: %w", dir, err) - } - - driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ - MigrationsTable: migrationsTable, - }) - if err != nil { - return fmt.Errorf("create migration driver for %s: %w", dir, err) - } - - mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) - if err != nil { - return fmt.Errorf("create migrator for %s: %w", dir, err) - } - defer closeMigrate(dir, mg) - - if err := mg.Down(); err != nil && 
!errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("down migrations for %s: %w", dir, err) - } - return nil -} - -// closeMigrate closes mg, logging source and database close errors separately. -func closeMigrate(dir string, mg *migrate.Migrate) { - srcErr, dbErr := mg.Close() - if srcErr != nil { - fmt.Printf("warning: closing migration source for %s: %v\n", dir, srcErr) - } - if dbErr != nil { - fmt.Printf("warning: closing migration database for %s: %v\n", dir, dbErr) - } + return migrations.RunDownAll(connStr, migrations.FS, "core", "schema_migrations") } diff --git a/go/core/pkg/app/app.go b/go/core/pkg/app/app.go index 73994f88d..6ae0b8ec8 100644 --- a/go/core/pkg/app/app.go +++ b/go/core/pkg/app/app.go @@ -55,6 +55,7 @@ import ( dbpkg "github.com/kagent-dev/kagent/go/api/database" "github.com/kagent-dev/kagent/go/core/pkg/auth" + "github.com/kagent-dev/kagent/go/core/pkg/migrations" "github.com/kagent-dev/kagent/go/core/pkg/translator" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -262,7 +263,15 @@ type ExtensionConfig struct { type GetExtensionConfig func(bootstrap BootstrapConfig) (*ExtensionConfig, error) -func Start(getExtensionConfig GetExtensionConfig) { +// MigrationRunner applies database migrations given the resolved connection URL. +// vectorEnabled mirrors the --database-vector-enabled flag. +// Returning a non-nil error causes the app to exit. +// +// Pass nil to Start to use the default migration runner (migrations.RunUp with migrations.FS). +// An optional migration runner can be provided to take over the migration process. 
+type MigrationRunner func(ctx context.Context, url string, vectorEnabled bool) error + +func Start(getExtensionConfig GetExtensionConfig, migrationRunner MigrationRunner) { var tlsOpts []func(*tls.Config) var cfg Config @@ -410,10 +419,32 @@ func Start(getExtensionConfig GetExtensionConfig) { os.Exit(1) } + // Resolve the database URL once so both the migration runner and the pool + // connection use exactly the same value. + dbURL, err := database.ResolveURL(cfg.Database.Url, cfg.Database.UrlFile) + if err != nil { + setupLog.Error(err, "unable to resolve database URL") + os.Exit(1) + } + + // Use the built-in migration runner when none is provided. + if migrationRunner == nil { + migrationRunner = func(ctx context.Context, url string, vectorEnabled bool) error { + return migrations.RunUp(ctx, url, migrations.FS, vectorEnabled) + } + } + + // Run migrations before connecting; schema must exist before queries. + setupLog.Info("running database migrations") + if err := migrationRunner(ctx, dbURL, cfg.Database.VectorEnabled); err != nil { + setupLog.Error(err, "database migration failed") + os.Exit(1) + } + setupLog.Info("database migrations complete") + // Connect to database db, err := database.Connect(ctx, &database.PostgresConfig{ - URL: cfg.Database.Url, - URLFile: cfg.Database.UrlFile, + URL: dbURL, VectorEnabled: cfg.Database.VectorEnabled, }) if err != nil { diff --git a/go/core/pkg/migrations/runner.go b/go/core/pkg/migrations/runner.go new file mode 100644 index 000000000..ff19cd918 --- /dev/null +++ b/go/core/pkg/migrations/runner.go @@ -0,0 +1,218 @@ +package migrations + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "log" + + "github.com/golang-migrate/migrate/v4" + migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/jackc/pgx/v5/stdlib" +) + +// RunUp applies all pending migrations for the given FS. 
+// vectorEnabled controls whether the vector track is also applied. +// Returns an error if any track fails (and attempts rollback of previously applied tracks). +func RunUp(_ context.Context, url string, migrationsFS fs.FS, vectorEnabled bool) error { + corePrev, err := applyDir(url, migrationsFS, "core", "schema_migrations") + if err != nil { + return fmt.Errorf("core migrations: %w", err) + } + + if vectorEnabled { + if _, err := applyDir(url, migrationsFS, "vector", "vector_schema_migrations"); err != nil { + log.Printf("migrations: rolling back core to version %d", corePrev) + rollbackDir(url, migrationsFS, "core", "schema_migrations", corePrev) + return fmt.Errorf("vector migrations: %w", err) + } + } + + return nil +} + +// RunDown rolls back steps migrations on a single track. +func RunDown(url string, migrationsFS fs.FS, dir, migrationsTable string, steps int) error { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return err + } + defer closeMigrate(dir, mg) + + if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("roll back %d migration(s) for %s: %w", steps, dir, err) + } + return nil +} + +// RunDownAll rolls back all applied migrations on a single track. +func RunDownAll(url string, migrationsFS fs.FS, dir, migrationsTable string) error { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return err + } + defer closeMigrate(dir, mg) + + if err := mg.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("down migrations for %s: %w", dir, err) + } + return nil +} + +// RunVersion returns the current applied version and dirty flag for a single track. 
+func RunVersion(url string, migrationsFS fs.FS, dir, migrationsTable string) (version uint, dirty bool, err error) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return 0, false, err + } + defer closeMigrate(dir, mg) + + version, dirty, err = mg.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + return 0, false, fmt.Errorf("get version for %s: %w", dir, err) + } + return version, dirty, nil +} + +// RunForce forces the tracking table for a single track to version (clears the dirty flag). +// Pass version=-1 to remove the version record entirely. +func RunForce(url string, migrationsFS fs.FS, dir, migrationsTable string, version int) error { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return err + } + defer closeMigrate(dir, mg) + + if err := mg.Force(version); err != nil { + return fmt.Errorf("force %s to version %d: %w", dir, version, err) + } + return nil +} + +// applyDir runs Up for dir and rolls back on failure. It returns the pre-run +// version so the caller can roll back this track if a later track fails. +func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prevVersion uint, err error) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + return 0, err + } + defer closeMigrate(dir, mg) + + prevVersion, _, err = mg.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + return 0, fmt.Errorf("get pre-migration version for %s: %w", dir, err) + } + // prevVersion == 0 when ErrNilVersion (no migrations applied yet). 
+ + if upErr := mg.Up(); upErr != nil { + if errors.Is(upErr, migrate.ErrNoChange) { + return prevVersion, nil + } + log.Printf("migrations: migration failed for %s, attempting rollback to version %d", dir, prevVersion) + if rbErr := rollbackToVersion(mg, dir, prevVersion); rbErr != nil { + log.Printf("migrations: rollback failed for %s: %v", dir, rbErr) + } else { + log.Printf("migrations: rolled back %s to version %d", dir, prevVersion) + } + return prevVersion, fmt.Errorf("run migrations for %s: %w", dir, upErr) + } + return prevVersion, nil +} + +// rollbackDir opens a fresh migrate instance and rolls dir back to targetVersion. +// Used to roll back a previously-succeeded track when a later track fails. +func rollbackDir(url string, migrationsFS fs.FS, dir, migrationsTable string, targetVersion uint) { + mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) + if err != nil { + log.Printf("migrations: rollback of %s failed (open): %v", dir, err) + return + } + defer closeMigrate(dir, mg) + if err := rollbackToVersion(mg, dir, targetVersion); err != nil { + log.Printf("migrations: rollback of %s failed: %v", dir, err) + } else { + log.Printf("migrations: rolled back %s to version %d", dir, targetVersion) + } +} + +// rollbackToVersion rolls the migration state back to targetVersion. +// It handles the dirty-state cleanup golang-migrate requires after a failed +// Up run before down steps can be applied. +func rollbackToVersion(mg *migrate.Migrate, dir string, targetVersion uint) error { + currentVersion, dirty, err := mg.Version() + if err != nil { + if errors.Is(err, migrate.ErrNilVersion) { + return nil // nothing was applied; nothing to roll back + } + return fmt.Errorf("get version after failure for %s: %w", dir, err) + } + + if dirty { + // The failed migration is recorded as dirty at currentVersion. + // Force to the last clean version so Steps can run. 
+ cleanVersion := int(currentVersion) - 1 + forceTarget := cleanVersion + if forceTarget < 1 { + forceTarget = -1 // negative tells golang-migrate to remove the version record entirely + } + if err := mg.Force(forceTarget); err != nil { + return fmt.Errorf("clear dirty state for %s: %w", dir, err) + } + if forceTarget < 0 { + return nil // first migration failed and was cleared; nothing left to roll back + } + currentVersion = uint(cleanVersion) + } + + steps := int(currentVersion) - int(targetVersion) + if steps <= 0 { + return nil + } + if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return fmt.Errorf("roll back %d step(s) for %s: %w", steps, dir, err) + } + return nil +} + +// newMigrate opens a dedicated database connection and constructs a migrate.Migrate +// for the given dir/table. The caller must call closeMigrate when done. +// Uses sql.Open (pgx stdlib shim) — a single dedicated connection — not a pool, +// because the advisory lock is session-level and must not be shared. +func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*migrate.Migrate, error) { + db, err := sql.Open("pgx", url) + if err != nil { + return nil, fmt.Errorf("open database for %s: %w", dir, err) + } + + src, err := iofs.New(migrationsFS, dir) + if err != nil { + return nil, fmt.Errorf("load migration files from %s: %w", dir, err) + } + + driver, err := migratepgx.WithInstance(db, &migratepgx.Config{ + MigrationsTable: migrationsTable, + }) + if err != nil { + return nil, fmt.Errorf("create migration driver for %s: %w", dir, err) + } + + mg, err := migrate.NewWithInstance("iofs", src, "postgres", driver) + if err != nil { + return nil, fmt.Errorf("create migrator for %s: %w", dir, err) + } + return mg, nil +} + +// closeMigrate closes mg, logging source and database close errors separately. 
+func closeMigrate(dir string, mg *migrate.Migrate) { + srcErr, dbErr := mg.Close() + if srcErr != nil { + log.Printf("warning: closing migration source for %s: %v", dir, srcErr) + } + if dbErr != nil { + log.Printf("warning: closing migration database for %s: %v", dir, dbErr) + } +} diff --git a/go/core/cmd/migrate/main_test.go b/go/core/pkg/migrations/runner_test.go similarity index 84% rename from go/core/cmd/migrate/main_test.go rename to go/core/pkg/migrations/runner_test.go index 87055d65f..d5f5262e3 100644 --- a/go/core/cmd/migrate/main_test.go +++ b/go/core/pkg/migrations/runner_test.go @@ -1,4 +1,4 @@ -package main +package migrations import ( "context" @@ -7,9 +7,12 @@ import ( "maps" "testing" "testing/fstest" + "time" _ "github.com/jackc/pgx/v5/stdlib" - "github.com/kagent-dev/kagent/go/core/internal/dbtest" + testcontainers "github.com/testcontainers/testcontainers-go" + tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" ) // --- migration fixtures --- @@ -75,10 +78,41 @@ func trackVersion(t *testing.T, connStr, table string) uint { return v } +// startTestDB spins up a pgvector Postgres container and returns its connection +// string, registering cleanup with t. It does not run any migrations. +func startTestDB(t *testing.T) string { + t.Helper() + ctx := context.Background() + pgContainer, err := tcpostgres.Run(ctx, + "pgvector/pgvector:pg18-trixie", + tcpostgres.WithDatabase("kagent_test"), + tcpostgres.WithUsername("postgres"), + tcpostgres.WithPassword("kagent"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). 
+ WithStartupTimeout(60*time.Second), + ), + ) + if err != nil { + t.Fatalf("startTestDB: start container: %v", err) + } + t.Cleanup(func() { + if err := pgContainer.Terminate(ctx); err != nil { + t.Logf("warning: failed to terminate postgres container: %v", err) + } + }) + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + if err != nil { + t.Fatalf("startTestDB: connection string: %v", err) + } + return connStr +} + // --- applyDir tests --- func TestApplyDir_HappyPath(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) prev, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations") if err != nil { @@ -93,7 +127,7 @@ func TestApplyDir_HappyPath(t *testing.T) { } func TestApplyDir_NoOpWhenAlreadyAtLatest(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) if _, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations"); err != nil { t.Fatalf("first apply: %v", err) @@ -111,7 +145,7 @@ func TestApplyDir_NoOpWhenAlreadyAtLatest(t *testing.T) { } func TestApplyDir_RollsBackWhenFirstMigrationFails(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) if _, err := applyDir(connStr, failOnFirstCoreFS, "core", "schema_migrations"); err == nil { t.Fatal("expected error, got nil") @@ -122,7 +156,7 @@ func TestApplyDir_RollsBackWhenFirstMigrationFails(t *testing.T) { } func TestApplyDir_RollsBackWhenLaterMigrationFails(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) if _, err := applyDir(connStr, failOnSecondCoreFS, "core", "schema_migrations"); err == nil { t.Fatal("expected error, got nil") @@ -134,7 +168,7 @@ func TestApplyDir_RollsBackWhenLaterMigrationFails(t *testing.T) { } func TestApplyDir_RollsBackToExistingVersion(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) // Establish a baseline at version 1. 
if _, err := applyDir(connStr, oneCoreFS, "core", "schema_migrations"); err != nil { @@ -153,7 +187,7 @@ func TestApplyDir_RollsBackToExistingVersion(t *testing.T) { // --- rollbackDir tests --- func TestRollbackDir_RollsBackToTarget(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) if _, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations"); err != nil { t.Fatalf("setup: %v", err) @@ -167,7 +201,7 @@ func TestRollbackDir_RollsBackToTarget(t *testing.T) { } func TestRollbackDir_PartialRollback(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) if _, err := applyDir(connStr, goodCoreFS, "core", "schema_migrations"); err != nil { t.Fatalf("setup: %v", err) @@ -187,7 +221,7 @@ func TestRollbackDir_PartialRollback(t *testing.T) { // core has no new migrations (ErrNoChange) and vector fails. Core should not // be downgraded by the cross-track rollback. func TestCrossTrackRollback_CoreUnchangedWhenVectorFails(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) combined := mergeFS(goodCoreFS, failVectorFS) @@ -218,7 +252,7 @@ func TestCrossTrackRollback_CoreUnchangedWhenVectorFails(t *testing.T) { } func TestCrossTrackRollback_CoreRolledBackWhenVectorFails(t *testing.T) { - connStr := dbtest.StartT(context.Background(), t) + connStr := startTestDB(t) combined := mergeFS(goodCoreFS, failVectorFS) diff --git a/helm/kagent/templates/controller-deployment.yaml b/helm/kagent/templates/controller-deployment.yaml index eace760ea..3d264913a 100644 --- a/helm/kagent/templates/controller-deployment.yaml +++ b/helm/kagent/templates/controller-deployment.yaml @@ -44,36 +44,6 @@ spec: tolerations: {{- toYaml . 
| nindent 8 }} {{- end }} - initContainers: - - name: migrate - image: "{{ .Values.controller.migrate.image.registry | default .Values.registry }}/{{ .Values.controller.migrate.image.repository }}:{{ coalesce .Values.tag .Values.controller.migrate.image.tag .Chart.Version }}" - imagePullPolicy: {{ .Values.controller.migrate.image.pullPolicy | default .Values.imagePullPolicy }} - env: - {{- if .Values.database.postgres.urlFile }} - - name: POSTGRES_DATABASE_URL_FILE - value: {{ .Values.database.postgres.urlFile | quote }} - {{- else if .Values.database.postgres.url }} - - name: POSTGRES_DATABASE_URL - value: {{ .Values.database.postgres.url | quote }} - {{- else if .Values.database.postgres.bundled.enabled }} - - name: POSTGRES_PASSWORD - valueFrom: - secretKeyRef: - name: {{ include "kagent.passwordSecretName" . }} - key: POSTGRES_PASSWORD - - name: POSTGRES_DATABASE_URL - value: {{ printf "postgres://kagent:$(POSTGRES_PASSWORD)@%s.%s.svc.cluster.local:5432/kagent?sslmode=disable" (include "kagent.postgresqlServiceName" .) (include "kagent.namespace" .) | quote }} - {{- else }} - {{ fail "No database connection configured. Set database.postgres.url, database.postgres.urlFile, or enable database.postgres.bundled." }} - {{- end }} - - name: KAGENT_DATABASE_VECTOR_ENABLED - value: {{ .Values.database.postgres.vectorEnabled | default false | quote }} - {{- if gt (len .Values.controller.volumeMounts) 0 }} - volumeMounts: - {{- with .Values.controller.volumeMounts }} - {{- toYaml . 
| nindent 12 }} - {{- end }} - {{- end }} containers: - name: controller image: "{{ .Values.controller.image.registry | default .Values.registry }}/{{ .Values.controller.image.repository }}:{{ coalesce .Values.tag .Values.controller.image.tag .Chart.Version }}" diff --git a/helm/kagent/values.yaml b/helm/kagent/values.yaml index 053a93e49..be972108a 100644 --- a/helm/kagent/values.yaml +++ b/helm/kagent/values.yaml @@ -167,12 +167,6 @@ controller: repository: kagent-dev/kagent/controller tag: "" # Will default to global, then Chart version pullPolicy: "" - migrate: - image: - registry: "" - repository: kagent-dev/kagent/migrate - tag: "" # Will default to global, then Chart version - pullPolicy: "" resources: requests: cpu: 100m From bb1281d5c33a2bc27ae76cab4e48821b6236797c Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Wed, 1 Apr 2026 15:13:42 -0700 Subject: [PATCH 10/16] Updates based on review and sqlc generate Signed-off-by: Jeremy Alvis --- .claude/skills/kagent-dev/SKILL.md | 9 +- .../kagent-dev/references/ci-failures.md | 2 + .../references/database-migrations.md | 82 +++++++++---------- .github/workflows/sqlc-generate-check.yaml | 38 +++++++++ docs/architecture/README.md | 4 +- go/README.md | 2 +- go/api/database/models.go | 6 +- go/core/internal/database/client_postgres.go | 62 ++++++++------ go/core/internal/database/gen/feedback.sql.go | 21 +---- .../internal/database/gen/langgraph.sql.go | 4 +- go/core/internal/database/gen/memory.sql.go | 7 +- go/core/internal/database/gen/models.go | 52 ++++++------ go/core/internal/database/gen/querier.go | 3 +- .../internal/database/queries/feedback.sql | 5 +- go/core/internal/database/queries/memory.sql | 1 + go/core/internal/dbtest/dbtest.go | 2 +- .../httpserver/handlers/checkpoints.go | 4 +- go/core/pkg/app/app.go | 11 ++- go/core/pkg/migrations/runner.go | 3 +- 19 files changed, 180 insertions(+), 138 deletions(-) create mode 100644 .github/workflows/sqlc-generate-check.yaml diff --git 
a/.claude/skills/kagent-dev/SKILL.md b/.claude/skills/kagent-dev/SKILL.md index 9901b646e..cdf2b5f07 100644 --- a/.claude/skills/kagent-dev/SKILL.md +++ b/.claude/skills/kagent-dev/SKILL.md @@ -18,6 +18,9 @@ make helm-install # Builds images and deploys to Kind make controller-manifests # generate + copy CRDs to helm (recommended) make -C go generate # DeepCopy methods only +# sqlc (after editing go/core/internal/database/queries/*.sql) +cd go/core/internal/database && sqlc generate # regenerate gen/ — commit both + # Build & test make -C go test # Unit tests (includes golden file checks) make -C go e2e # E2E tests (needs KAGENT_URL) @@ -43,7 +46,7 @@ kagent/ │ ├── api/ # Shared types module │ │ ├── v1alpha2/ # Current CRD types (agent_types.go, etc.) │ │ ├── adk/ # ADK config types (types.go) — flows to Python runtime -│ │ ├── database/ # GORM models +│ │ ├── database/ # database models │ │ ├── httpapi/ # HTTP API types │ │ └── config/crd/bases/ # Generated CRD YAML │ ├── core/ # Infrastructure module @@ -231,7 +234,7 @@ curl -v $KAGENT_URL/healthz # Controller reach **Reproducing locally (without cluster):** Follow `go/core/test/e2e/README.md` — extract agent config, start mock LLM server, run agent with `kagent-adk test`. Much faster iteration than full cluster. -**CI-specific:** E2E runs in matrix (`sqlite` + `postgres`). If only one database variant fails, it's likely database-related. If both fail, it's infrastructure. Most common CI-only failure: mock LLM unreachability because `KAGENT_LOCAL_HOST` detection fails on Linux. +**CI-specific:** Most common CI-only failure: mock LLM unreachability because `KAGENT_LOCAL_HOST` detection fails on Linux. See `references/e2e-debugging.md` for comprehensive debugging techniques. 
@@ -349,4 +352,4 @@ Don't use Go template syntax (`{{ }}`) in doc comments — Helm will try to pars - `references/translator-guide.md` - Translator patterns, `deployments.go` and `adk_api_translator.go` - `references/e2e-debugging.md` - Comprehensive E2E debugging, local reproduction - `references/ci-failures.md` - CI failure patterns and fixes -- `references/database-migrations.md` - Migration authoring rules, multi-instance safety, GORM baseline, expand/contract pattern +- `references/database-migrations.md` - Migration authoring rules, sqlc workflow, multi-instance safety, expand/contract pattern diff --git a/.claude/skills/kagent-dev/references/ci-failures.md b/.claude/skills/kagent-dev/references/ci-failures.md index 2cebb996d..367c7e2dc 100644 --- a/.claude/skills/kagent-dev/references/ci-failures.md +++ b/.claude/skills/kagent-dev/references/ci-failures.md @@ -7,6 +7,7 @@ Common GitHub Actions CI failures and how to fix them. | Failure | Likely Cause | Quick Fix | |---------|--------------|-----------| | manifests-check | CRD manifests out of date | `make -C go generate && cp go/api/config/crd/bases/*.yaml helm/kagent-crds/templates/` | +| sqlc-generate-check | `gen/` out of sync with queries | `cd go/core/internal/database && sqlc generate`, commit `gen/` | | go-lint depguard | Forbidden package used | Replace with allowed alternative (e.g., `slices.Sort` not `sort.Strings`) | | test-e2e timeout | Agent not starting or KAGENT_URL wrong | Check pod status, verify KAGENT_URL setup in CI | | golden files mismatch | Translator output changed | `UPDATE_GOLDEN=true make -C go test` and commit | @@ -520,6 +521,7 @@ make init-git-hooks Before submitting PR: - [ ] Ran `make -C go generate` after CRD changes +- [ ] Ran `cd go/core/internal/database && sqlc generate` after query changes, committed `gen/` - [ ] Ran `make lint` and fixed issues - [ ] Ran `make -C go test` and all pass - [ ] Regenerated golden files if translator changed diff --git 
a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md index f2f5d09a7..f38df3839 100644 --- a/.claude/skills/kagent-dev/references/database-migrations.md +++ b/.claude/skills/kagent-dev/references/database-migrations.md @@ -1,12 +1,13 @@ # Database Migrations Guide -kagent uses [golang-migrate](https://github.com/golang-migrate/migrate) with embedded SQL files. Migrations run as a Kubernetes **init container** (`kagent-migrate`) before the controller starts. +kagent uses [golang-migrate](https://github.com/golang-migrate/migrate) with embedded SQL files and [sqlc](https://sqlc.dev/) for type-safe query generation. Migrations run **in-app at startup** — the controller applies them before accepting traffic. ## Structure ``` go/core/pkg/migrations/ -├── migrations.go # Embeds the FS (go:embed) +├── migrations.go # Embeds the FS (go:embed); exports FS for downstream consumers +├── runner.go # RunUp / RunDown / RunDownAll / RunVersion / RunForce ├── core/ # Core schema (tracked in schema_migrations table) │ ├── 000001_initial.up.sql / .down.sql │ ├── 000002_add_session_source.up.sql / .down.sql @@ -14,9 +15,38 @@ go/core/pkg/migrations/ └── vector/ # pgvector schema (tracked in vector_schema_migrations table) ├── 000001_vector_support.up.sql / .down.sql └── ... + +go/core/internal/database/ +├── queries/ # Hand-written SQL queries (source of truth) +│ ├── sessions.sql +│ ├── memory.sql +│ └── ... +├── gen/ # sqlc-generated Go code — DO NOT edit manually +│ ├── db.go +│ ├── models.go +│ └── *.sql.go +└── sqlc.yaml # sqlc configuration ``` -The `kagent-migrate` binary (in `go/core/cmd/migrate/`) runs `up` by default. It manages two independent tracks — `core` and `vector` — and rolls back both if either fails. +Migrations manage two independent tracks — `core` and `vector` — and roll back both if either fails. The `--database-vector-enabled` flag (default `true`) controls whether the vector track runs. 
+ +## sqlc Workflow + +When you add or change a SQL query: + +1. Edit (or add) a `.sql` file under `go/core/internal/database/queries/` +2. Regenerate: + ```bash + cd go/core/internal/database && sqlc generate + ``` +3. Commit both the query file and the updated `gen/` files together. + +A CI check (`.github/workflows/sqlc-generate-check.yaml`) fails the PR if `gen/` is out of sync with the queries. Never edit `gen/` by hand. + +**sqlc annotations used:** +- `:one` — returns a single row +- `:many` — returns a slice +- `:exec` — returns only error (use for INSERT/UPDATE/DELETE that don't need the result) ## Writing Migrations @@ -56,68 +86,38 @@ Files must follow `NNNNNN_description.up.sql` / `NNNNNN_description.down.sql` wi ### Down migrations -Every `.up.sql` must have a corresponding `.down.sql` that exactly reverses it. Down migrations are used by the `kagent-migrate down --steps N --track core` command for rollbacks, and by automatic rollback on migration failure. They must be **idempotent** — the two-track rollback logic (roll back core if vector fails) may call them more than once in failure scenarios. +Every `.up.sql` must have a corresponding `.down.sql` that exactly reverses it. Down migrations are used for rollbacks and by automatic rollback on migration failure. They must be **idempotent** — the two-track rollback logic (roll back core if vector fails) may call them more than once in failure scenarios. ## Multi-Instance Safety ### How the advisory lock works -golang-migrate acquires a PostgreSQL **session-level** advisory lock (`pg_advisory_lock`) before running. +The migration runner acquires a PostgreSQL **session-level** advisory lock (`pg_advisory_lock`) before running. -### Init container concurrency +### Rolling deploy concurrency If multiple pods start simultaneously (e.g., rolling deploy with replicas > 1): -1. One init container acquires the advisory lock and runs migrations. +1. One controller acquires the advisory lock and runs migrations. 2. 
Others block on `pg_advisory_lock`. 3. When the winner finishes and its connection closes, the next waiter acquires the lock, calls `Up()`, gets `ErrNoChange`, and exits immediately. -This is safe. The only risk is if the winning init container crashes mid-migration (see Dirty State below). +This is safe. The only risk is if the winning controller crashes mid-migration (see Dirty State below). ### Dirty state recovery -If `kagent-migrate` crashes mid-migration (OOMKill, pod eviction), golang-migrate records the version as `dirty = true` in the tracking table. The next run (after the advisory lock releases) will detect dirty state and call `rollbackToVersion`, which: +If the controller crashes mid-migration, the migration runner records the version as `dirty = true` in the tracking table. The next startup detects dirty state and calls `rollbackToVersion`, which: 1. Calls `mg.Force(version - 1)` to clear the dirty flag. 2. Runs the down migration to restore the previous clean state. 3. Re-runs the failed up migration. -**Requirement**: down migrations must be idempotent and correctly reverse their up migration. A missing or broken down migration requires manual recovery — see the `force` subcommand below. - +**Requirement**: down migrations must be idempotent and correctly reverse their up migration. A missing or broken down migration requires manual recovery using `RunForce`. ### Rollout strategy For additive, backward-compatible migrations a rolling update is safe: -1. New pod starts → `kagent-migrate up` runs (advisory lock serializes concurrent runs) +1. New pod starts → migration runner applies pending migrations (advisory lock serializes concurrent runs) 2. New pod passes readiness probe → old pod terminates 3. Backward-compatible schema means old pods continue operating during the window For a migration that is **not** backward-compatible, restructure it using expand/contract. 
- -## Running Migrations Locally - -```bash -# Apply all pending migrations -POSTGRES_DATABASE_URL="postgres://..." kagent-migrate up - -# Check current version on each track -POSTGRES_DATABASE_URL="..." kagent-migrate version - -# Roll back 1 step on core track -POSTGRES_DATABASE_URL="..." kagent-migrate down --steps 1 --track core - -# With vector support -KAGENT_DATABASE_VECTOR_ENABLED=true POSTGRES_DATABASE_URL="..." kagent-migrate up -``` - -### Manual dirty-state recovery - -If a migration was partially applied (dirty state), use `force` to reset to the last clean version before running `down`: - -```bash -# Force the tracking table to a specific version (clears dirty flag) -POSTGRES_DATABASE_URL="..." kagent-migrate force --track core -# Then re-run up, or roll back: -POSTGRES_DATABASE_URL="..." kagent-migrate down --steps 1 --track core -``` - -In a Kubernetes deployment, the init container runs automatically on every pod start. diff --git a/.github/workflows/sqlc-generate-check.yaml b/.github/workflows/sqlc-generate-check.yaml new file mode 100644 index 000000000..8629d465e --- /dev/null +++ b/.github/workflows/sqlc-generate-check.yaml @@ -0,0 +1,38 @@ +name: sqlc Generate Check + +on: + pull_request: + branches: [main] + paths: + - "go/core/internal/database/queries/**" + - "go/core/internal/database/sqlc.yaml" + - "go/core/pkg/migrations/**" + +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + cache: true + cache-dependency-path: go/go.sum + + - name: Install sqlc + run: go install github.com/sqlc-dev/sqlc/cmd/sqlc@v1.30.0 + + - name: Run sqlc generate + working-directory: go/core/internal/database + run: sqlc generate + + - name: Fail if generated files differ + run: | + if ! git diff --quiet go/core/internal/database/gen/; then + echo "ERROR: sqlc generate produced changes. Run sqlc generate locally and commit the result." 
+ echo "" + git diff go/core/internal/database/gen/ + exit 1 + fi + echo "OK: generated files are up to date." diff --git a/docs/architecture/README.md b/docs/architecture/README.md index a311e11c1..f2e87fb88 100644 --- a/docs/architecture/README.md +++ b/docs/architecture/README.md @@ -149,7 +149,7 @@ The controller uses SQLite (default) or PostgreSQL for persistent state that sup **Why a separate DB?** The Kubernetes API is not designed for high-frequency read patterns like listing conversations or searching tools. The DB provides fast lookups for the HTTP API and UI, while the CRDs remain the source of truth for agent configuration. **Key files:** -- `go/api/database/models.go` — GORM models +- `go/api/database/models.go` — database models - `go/core/internal/database/client.go` — Database client implementation - `go/core/internal/database/service.go` — Business logic with atomic upserts @@ -398,7 +398,7 @@ go/ ├── go.work ├── api/ # github.com/kagent-dev/kagent/go/api │ ├── v1alpha2/ # CRD type definitions -│ ├── database/ # GORM database models +│ ├── database/ # database models │ ├── httpapi/ # HTTP API request/response types │ ├── client/ # REST client SDK for the HTTP API │ └── config/crd/ # Generated CRD manifests diff --git a/go/README.md b/go/README.md index 35f07a263..b679cf39c 100644 --- a/go/README.md +++ b/go/README.md @@ -31,7 +31,7 @@ go/ │ ├── v1alpha1/ # Legacy CRD types │ ├── v1alpha2/ # Current CRD types │ ├── adk/ # ADK config & model types -│ ├── database/ # GORM model structs & Client interface +│ ├── database/ # database model structs & Client interface │ ├── httpapi/ # HTTP API request/response types │ ├── client/ # REST HTTP client SDK │ ├── utils/ # Shared utility functions diff --git a/go/api/database/models.go b/go/api/database/models.go index 5520b0039..23648accd 100644 --- a/go/api/database/models.go +++ b/go/api/database/models.go @@ -166,7 +166,7 @@ type LangGraphCheckpoint struct { Metadata string `json:"metadata"` Checkpoint 
string `json:"checkpoint"` CheckpointType string `json:"checkpoint_type"` - Version int32 `json:"version"` + Version int64 `json:"version"` } type LangGraphCheckpointWrite struct { @@ -174,7 +174,7 @@ type LangGraphCheckpointWrite struct { ThreadID string `json:"thread_id"` CheckpointNS string `json:"checkpoint_ns"` CheckpointID string `json:"checkpoint_id"` - WriteIdx int32 `json:"write_idx"` + WriteIdx int64 `json:"write_idx"` Value string `json:"value"` ValueType string `json:"value_type"` Channel string `json:"channel"` @@ -212,7 +212,7 @@ type Memory struct { Metadata string `json:"metadata"` CreatedAt time.Time `json:"created_at"` ExpiresAt *time.Time `json:"expires_at,omitempty"` - AccessCount int32 `json:"access_count"` + AccessCount int64 `json:"access_count"` } // AgentMemorySearchResult is the result of a vector similarity search over Memory. diff --git a/go/core/internal/database/client_postgres.go b/go/core/internal/database/client_postgres.go index a80a7f002..c5aaaa10d 100644 --- a/go/core/internal/database/client_postgres.go +++ b/go/core/internal/database/client_postgres.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -271,7 +272,7 @@ func (c *postgresClient) DeletePushNotification(ctx context.Context, taskID stri func (c *postgresClient) StoreFeedback(ctx context.Context, feedback *dbpkg.Feedback) error { isPositive := feedback.IsPositive - _, err := c.q.InsertFeedback(ctx, dbgen.InsertFeedbackParams{ + err := c.q.InsertFeedback(ctx, dbgen.InsertFeedbackParams{ UserID: feedback.UserID, MessageID: feedback.MessageID, IsPositive: &isPositive, @@ -611,9 +612,9 @@ func (c *postgresClient) SearchAgentMemory(ctx context.Context, agentName, userI Content: derefStr(r.Content), Embedding: r.Embedding, Metadata: derefStr(r.Metadata), - CreatedAt: r.CreatedAt, + CreatedAt: derefTime(r.CreatedAt), ExpiresAt: r.ExpiresAt, - AccessCount: derefInt32(r.AccessCount), + AccessCount: 
derefInt64(r.AccessCount), }, Score: score, } @@ -685,8 +686,8 @@ func (c *postgresClient) PruneExpiredMemories(ctx context.Context) error { func toAgent(r dbgen.Agent) *dbpkg.Agent { return &dbpkg.Agent{ ID: r.ID, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Type: r.Type, Config: r.Config, @@ -698,8 +699,8 @@ func toSession(r dbgen.Session) *dbpkg.Session { ID: r.ID, UserID: r.UserID, Name: r.Name, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, AgentID: r.AgentID, } @@ -715,8 +716,8 @@ func toEvent(r dbgen.Event) *dbpkg.Event { ID: r.ID, UserID: r.UserID, SessionID: derefStr(r.SessionID), - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Data: r.Data, } @@ -725,8 +726,8 @@ func toEvent(r dbgen.Event) *dbpkg.Event { func toTask(r dbgen.Task) *dbpkg.Task { return &dbpkg.Task{ ID: r.ID, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Data: r.Data, SessionID: derefStr(r.SessionID), @@ -752,8 +753,8 @@ func toTool(r dbgen.Tool) *dbpkg.Tool { ID: r.ID, ServerName: r.ServerName, GroupKind: r.GroupKind, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Description: derefStr(r.Description), } @@ -763,8 +764,8 @@ func toToolServer(r dbgen.Toolserver) *dbpkg.ToolServer { return &dbpkg.ToolServer{ Name: r.Name, GroupKind: r.GroupKind, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Description: derefStr(r.Description), LastConnected: r.LastConnected, @@ -778,13 +779,13 @@ 
func toCheckpoint(r dbgen.LgCheckpoint) *dbpkg.LangGraphCheckpoint { CheckpointNS: r.CheckpointNs, CheckpointID: r.CheckpointID, ParentCheckpointID: r.ParentCheckpointID, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, Metadata: r.Metadata, Checkpoint: r.Checkpoint, CheckpointType: r.CheckpointType, - Version: derefInt32(r.Version), + Version: derefInt64(r.Version), } } @@ -799,8 +800,8 @@ func toCheckpointWrite(r dbgen.LgCheckpointWrite) *dbpkg.LangGraphCheckpointWrit ValueType: r.ValueType, Channel: r.Channel, TaskID: r.TaskID, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, } } @@ -809,8 +810,8 @@ func toCrewAIMemory(r dbgen.CrewaiAgentMemory) *dbpkg.CrewAIAgentMemory { return &dbpkg.CrewAIAgentMemory{ UserID: r.UserID, ThreadID: r.ThreadID, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, MemoryData: r.MemoryData, } @@ -821,8 +822,8 @@ func toCrewAIFlowState(r dbgen.CrewaiFlowState) *dbpkg.CrewAIFlowState { UserID: r.UserID, ThreadID: r.ThreadID, MethodName: r.MethodName, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, + CreatedAt: derefTime(r.CreatedAt), + UpdatedAt: derefTime(r.UpdatedAt), DeletedAt: r.DeletedAt, StateData: r.StateData, } @@ -836,9 +837,9 @@ func toMemory(r dbgen.Memory) *dbpkg.Memory { Content: derefStr(r.Content), Embedding: r.Embedding, Metadata: derefStr(r.Metadata), - CreatedAt: r.CreatedAt, + CreatedAt: derefTime(r.CreatedAt), ExpiresAt: r.ExpiresAt, - AccessCount: derefInt32(r.AccessCount), + AccessCount: derefInt64(r.AccessCount), } } @@ -858,13 +859,20 @@ func derefStr(s *string) string { return "" } -func derefInt32(n *int32) int32 { +func derefInt64(n *int64) int64 { if n != nil { return *n } return 0 } +func derefTime(t 
*time.Time) time.Time { + if t != nil { + return *t + } + return time.Time{} +} + func derefBool(b *bool) bool { if b != nil { return *b diff --git a/go/core/internal/database/gen/feedback.sql.go b/go/core/internal/database/gen/feedback.sql.go index ed3df4ae3..34f4c293e 100644 --- a/go/core/internal/database/gen/feedback.sql.go +++ b/go/core/internal/database/gen/feedback.sql.go @@ -11,10 +11,9 @@ import ( "github.com/kagent-dev/kagent/go/api/database" ) -const insertFeedback = `-- name: InsertFeedback :one +const insertFeedback = `-- name: InsertFeedback :exec INSERT INTO feedback (user_id, message_id, is_positive, feedback_text, issue_type, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) -RETURNING id, created_at, updated_at, deleted_at, user_id, message_id, is_positive, feedback_text, issue_type ` type InsertFeedbackParams struct { @@ -25,27 +24,15 @@ type InsertFeedbackParams struct { IssueType *database.FeedbackIssueType } -func (q *Queries) InsertFeedback(ctx context.Context, arg InsertFeedbackParams) (Feedback, error) { - row := q.db.QueryRow(ctx, insertFeedback, +func (q *Queries) InsertFeedback(ctx context.Context, arg InsertFeedbackParams) error { + _, err := q.db.Exec(ctx, insertFeedback, arg.UserID, arg.MessageID, arg.IsPositive, arg.FeedbackText, arg.IssueType, ) - var i Feedback - err := row.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.DeletedAt, - &i.UserID, - &i.MessageID, - &i.IsPositive, - &i.FeedbackText, - &i.IssueType, - ) - return i, err + return err } const listFeedback = `-- name: ListFeedback :many diff --git a/go/core/internal/database/gen/langgraph.sql.go b/go/core/internal/database/gen/langgraph.sql.go index 505db2a56..4453595f8 100644 --- a/go/core/internal/database/gen/langgraph.sql.go +++ b/go/core/internal/database/gen/langgraph.sql.go @@ -253,7 +253,7 @@ type UpsertCheckpointParams struct { Metadata string Checkpoint string CheckpointType string - Version *int32 + Version *int64 } func (q *Queries) 
UpsertCheckpoint(ctx context.Context, arg UpsertCheckpointParams) error { @@ -289,7 +289,7 @@ type UpsertCheckpointWriteParams struct { ThreadID string CheckpointNs string CheckpointID string - WriteIdx int32 + WriteIdx int64 Value string ValueType string Channel string diff --git a/go/core/internal/database/gen/memory.sql.go b/go/core/internal/database/gen/memory.sql.go index 2ada25c70..8e967cc34 100644 --- a/go/core/internal/database/gen/memory.sql.go +++ b/go/core/internal/database/gen/memory.sql.go @@ -70,7 +70,7 @@ type InsertMemoryParams struct { Embedding pgvector_go.Vector Metadata *string ExpiresAt *time.Time - AccessCount *int32 + AccessCount *int64 } func (q *Queries) InsertMemory(ctx context.Context, arg InsertMemoryParams) (string, error) { @@ -152,12 +152,13 @@ type SearchAgentMemoryRow struct { Content *string Embedding pgvector_go.Vector Metadata *string - CreatedAt time.Time + CreatedAt *time.Time ExpiresAt *time.Time - AccessCount *int32 + AccessCount *int64 Score interface{} } +// COALESCE guards against NULL embeddings (score=0 rather than NULL); rows are still ordered last by the ORDER BY clause. 
func (q *Queries) SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryParams) ([]SearchAgentMemoryRow, error) { rows, err := q.db.Query(ctx, searchAgentMemory, arg.Embedding, diff --git a/go/core/internal/database/gen/models.go b/go/core/internal/database/gen/models.go index 4b26661da..b3cd22eca 100644 --- a/go/core/internal/database/gen/models.go +++ b/go/core/internal/database/gen/models.go @@ -14,8 +14,8 @@ import ( type Agent struct { ID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Type string Config *adk.AgentConfig @@ -24,8 +24,8 @@ type Agent struct { type CrewaiAgentMemory struct { UserID string ThreadID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time MemoryData string } @@ -34,8 +34,8 @@ type CrewaiFlowState struct { UserID string ThreadID string MethodName string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time StateData string } @@ -44,8 +44,8 @@ type Event struct { ID string UserID string SessionID *string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Data string } @@ -68,13 +68,13 @@ type LgCheckpoint struct { CheckpointNs string CheckpointID string ParentCheckpointID *string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Metadata string Checkpoint string CheckpointType string - Version *int32 + Version *int64 } type LgCheckpointWrite struct { @@ -82,13 +82,13 @@ type LgCheckpointWrite struct { ThreadID string CheckpointNs string CheckpointID string - WriteIdx int32 + WriteIdx int64 Value string ValueType string Channel string TaskID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time } @@ -99,16 +99,16 @@ type Memory struct { Content *string Embedding 
pgvector_go.Vector Metadata *string - CreatedAt time.Time + CreatedAt *time.Time ExpiresAt *time.Time - AccessCount *int32 + AccessCount *int64 } type PushNotification struct { ID string TaskID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Data string } @@ -117,8 +117,8 @@ type Session struct { ID string UserID string Name *string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time AgentID *string Source *string @@ -126,8 +126,8 @@ type Session struct { type Task struct { ID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Data string SessionID *string @@ -137,8 +137,8 @@ type Tool struct { ID string ServerName string GroupKind string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Description *string } @@ -146,8 +146,8 @@ type Tool struct { type Toolserver struct { Name string GroupKind string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt *time.Time + UpdatedAt *time.Time DeletedAt *time.Time Description *string LastConnected *time.Time diff --git a/go/core/internal/database/gen/querier.go b/go/core/internal/database/gen/querier.go index 281d22edf..ae785ddba 100644 --- a/go/core/internal/database/gen/querier.go +++ b/go/core/internal/database/gen/querier.go @@ -24,7 +24,7 @@ type Querier interface { HardDeleteCrewAIMemory(ctx context.Context, arg HardDeleteCrewAIMemoryParams) error IncrementMemoryAccessCount(ctx context.Context, dollar_1 []string) error InsertEvent(ctx context.Context, arg InsertEventParams) error - InsertFeedback(ctx context.Context, arg InsertFeedbackParams) (Feedback, error) + InsertFeedback(ctx context.Context, arg InsertFeedbackParams) error InsertMemory(ctx context.Context, arg InsertMemoryParams) (string, error) ListAgentMemories(ctx context.Context, arg 
ListAgentMemoriesParams) ([]Memory, error) ListAgents(ctx context.Context) ([]Agent, error) @@ -45,6 +45,7 @@ type Querier interface { ListToolServers(ctx context.Context) ([]Toolserver, error) ListTools(ctx context.Context) ([]Tool, error) ListToolsForServer(ctx context.Context, arg ListToolsForServerParams) ([]Tool, error) + // COALESCE guards against NULL embeddings (score=0 rather than NULL); rows are still ordered last by the ORDER BY clause. SearchAgentMemory(ctx context.Context, arg SearchAgentMemoryParams) ([]SearchAgentMemoryRow, error) SearchCrewAIMemoryByTask(ctx context.Context, arg SearchCrewAIMemoryByTaskParams) ([]CrewaiAgentMemory, error) SearchCrewAIMemoryByTaskLimit(ctx context.Context, arg SearchCrewAIMemoryByTaskLimitParams) ([]CrewaiAgentMemory, error) diff --git a/go/core/internal/database/queries/feedback.sql b/go/core/internal/database/queries/feedback.sql index e5f9a48b2..6d5b700fe 100644 --- a/go/core/internal/database/queries/feedback.sql +++ b/go/core/internal/database/queries/feedback.sql @@ -1,7 +1,6 @@ --- name: InsertFeedback :one +-- name: InsertFeedback :exec INSERT INTO feedback (user_id, message_id, is_positive, feedback_text, issue_type, created_at, updated_at) -VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) -RETURNING *; +VALUES ($1, $2, $3, $4, $5, NOW(), NOW()); -- name: ListFeedback :many SELECT * FROM feedback diff --git a/go/core/internal/database/queries/memory.sql b/go/core/internal/database/queries/memory.sql index 4cb88edb6..0a0767086 100644 --- a/go/core/internal/database/queries/memory.sql +++ b/go/core/internal/database/queries/memory.sql @@ -4,6 +4,7 @@ VALUES ($1, $2, $3, $4, $5, NOW(), $6, $7) RETURNING id; -- name: SearchAgentMemory :many +-- COALESCE guards against NULL embeddings (score=0 rather than NULL); rows are still ordered last by the ORDER BY clause. 
SELECT *, COALESCE(1 - (embedding <=> $1), 0) AS score FROM memory WHERE agent_name = $2 AND user_id = $3 diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index b5a38409c..a5573d7ea 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -64,7 +64,7 @@ func StartT(ctx context.Context, t *testing.T) string { // If vectorEnabled is true the vector pass is also applied. // Use MigrateT in tests that have a *testing.T; use Migrate in TestMain where no T is available. func Migrate(connStr string, vectorEnabled bool) error { - return migrations.RunUp(context.Background(), connStr, migrations.FS, vectorEnabled) + return migrations.RunUp(connStr, migrations.FS, vectorEnabled) } // MigrateT runs the embedded OSS migrations against connStr and calls t.Fatal on error. diff --git a/go/core/internal/httpserver/handlers/checkpoints.go b/go/core/internal/httpserver/handlers/checkpoints.go index 9664b4727..d3e4655cf 100644 --- a/go/core/internal/httpserver/handlers/checkpoints.go +++ b/go/core/internal/httpserver/handlers/checkpoints.go @@ -110,7 +110,7 @@ func (h *CheckpointsHandler) HandlePutCheckpoint(w ErrorResponseWriter, r *http. 
ParentCheckpointID: req.ParentCheckpointID, Metadata: req.Metadata, Checkpoint: req.Checkpoint, - Version: int32(req.Version), + Version: int64(req.Version), CheckpointType: req.Type, } // Store checkpoint and writes atomically @@ -232,7 +232,7 @@ func (h *CheckpointsHandler) HandlePutWrites(w ErrorResponseWriter, r *http.Requ ThreadID: req.ThreadID, CheckpointNS: req.CheckpointNS, CheckpointID: req.CheckpointID, - WriteIdx: int32(writeReq.Idx), + WriteIdx: int64(writeReq.Idx), Value: writeReq.Value, ValueType: writeReq.Type, Channel: writeReq.Channel, diff --git a/go/core/pkg/app/app.go b/go/core/pkg/app/app.go index 6ae0b8ec8..0d8776576 100644 --- a/go/core/pkg/app/app.go +++ b/go/core/pkg/app/app.go @@ -264,11 +264,14 @@ type ExtensionConfig struct { type GetExtensionConfig func(bootstrap BootstrapConfig) (*ExtensionConfig, error) // MigrationRunner applies database migrations given the resolved connection URL. -// vectorEnabled mirrors the --database-vector-enabled flag. +// vectorEnabled mirrors the --database-vector-enabled flag; custom runners can use it +// to conditionally apply vector-specific migrations. // Returning a non-nil error causes the app to exit. // // Pass nil to Start to use the default migration runner (migrations.RunUp with migrations.FS). -// An optional migration runner can be provided to take over the migration process. +// Provide a custom runner to take over the migration process entirely — for example, +// to run additional enterprise migrations alongside or instead of the built-in ones. +// Custom runners that want to include the built-in migrations can call migrations.RunUp directly. type MigrationRunner func(ctx context.Context, url string, vectorEnabled bool) error func Start(getExtensionConfig GetExtensionConfig, migrationRunner MigrationRunner) { @@ -429,8 +432,8 @@ func Start(getExtensionConfig GetExtensionConfig, migrationRunner MigrationRunne // Use the built-in migration runner when none is provided. 
if migrationRunner == nil { - migrationRunner = func(ctx context.Context, url string, vectorEnabled bool) error { - return migrations.RunUp(ctx, url, migrations.FS, vectorEnabled) + migrationRunner = func(_ context.Context, url string, vectorEnabled bool) error { + return migrations.RunUp(url, migrations.FS, vectorEnabled) } } diff --git a/go/core/pkg/migrations/runner.go b/go/core/pkg/migrations/runner.go index ff19cd918..e062ffa95 100644 --- a/go/core/pkg/migrations/runner.go +++ b/go/core/pkg/migrations/runner.go @@ -1,7 +1,6 @@ package migrations import ( - "context" "database/sql" "errors" "fmt" @@ -17,7 +16,7 @@ import ( // RunUp applies all pending migrations for the given FS. // vectorEnabled controls whether the vector track is also applied. // Returns an error if any track fails (and attempts rollback of previously applied tracks). -func RunUp(_ context.Context, url string, migrationsFS fs.FS, vectorEnabled bool) error { +func RunUp(url string, migrationsFS fs.FS, vectorEnabled bool) error { corePrev, err := applyDir(url, migrationsFS, "core", "schema_migrations") if err != nil { return fmt.Errorf("core migrations: %w", err) From 6ddc959a9044b23ea8da8d61cca67616386a6830 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Thu, 2 Apr 2026 10:18:13 -0700 Subject: [PATCH 11/16] Add additional checks and update skills for migration safety Signed-off-by: Jeremy Alvis --- .../references/database-migrations.md | 61 +++- go/core/pkg/migrations/cross_track_test.go | 296 ++++++++++++++++++ go/core/pkg/migrations/runner_test.go | 52 +++ 3 files changed, 397 insertions(+), 12 deletions(-) create mode 100644 go/core/pkg/migrations/cross_track_test.go diff --git a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md index f38df3839..fd7ce5288 100644 --- a/.claude/skills/kagent-dev/references/database-migrations.md +++ b/.claude/skills/kagent-dev/references/database-migrations.md @@ -50,18 +50,27 @@ A CI 
check (`.github/workflows/sqlc-generate-check.yaml`) fails the PR if `gen/` ## Writing Migrations -### Version compatibility policy +### Additive-only policy -kagent supports **n-1 minor version** compatibility. Users must not skip minor versions when upgrading. This gives us a defined window for schema cleanup: +The schema is **additive-only**. Columns and tables are deprecated in application code but never removed from the database. This is not just a kagent convention — it is the only reliable guarantee that can be made to downstream consumers who deploy on their own schedule and may have FK constraints pointing at kagent-owned tables. -- **Version N**: stop using the old column/table in application code; the schema still contains it (backward compatible with N-1) -- **Version N+1**: drop the old column/table (or N+2 for additional safety if rollback risk is high) +This mirrors how mature projects handle the same problem: Salesforce platform-enforces additive-only for managed packages after GA; Stripe never removes fields from a versioned API response; GitLab requires multi-phase explicit FK teardown before any column can be contracted. -Never migrate data and remove the old structure in the same migration — if the migration fails mid-way, rollback is much harder. Always separate the two steps across versions. +**Why contraction is unsafe with multiple tracks and downstream consumers:** -### Backward-compatible schema changes (expand/contract) +The two-track design (core → vector) and downstream consumers (who may add their own migration track on top of core/vector) create a class of failure that has no clean runtime fix: -During a rolling deploy, old pods (running the previous code version) will be reading and writing a schema that has already been upgraded by the new pod's init container. 
**Every migration must be backward-compatible with the n-1 minor version's code.** Locking serializes concurrent migration runs but does nothing to protect old pods still running against the new schema. +1. **Fresh install**: all core migrations run to completion — including any contraction — before vector or downstream migrations run. A later track's migration referencing a contracted column fails because it no longer exists. +2. **Existing database**: Postgres CASCADE silently drops dependent indexes or constraints created by a later track when core contracts the column. Migration tracking shows the later track at its old version, but the actual schema no longer matches. +3. **Downstream at unknown version**: downstream consumers may have deployed weeks behind. A core contraction breaks their upgrade path with no warning. + +No migration tool (Flyway, Liquibase, Atlas, golang-migrate) automatically detects or prevents this. It must be enforced by policy. + +**Contracting is not allowed.** If a column or table is no longer needed, stop using it in application code and leave it in the database. + +### Backward-compatible schema changes + +During a rolling deploy, old pods will be reading and writing a schema that has already been upgraded. **Every migration must be backward-compatible with the previous version's code.** | Change | Old code behavior | Safe? | |--------|------------------|-------| @@ -73,12 +82,26 @@ During a rolling deploy, old pods (running the previous code version) will be re | Drop/rename column old code references | Old SELECT/INSERT errors | ❌ | | Change compatible type (e.g. `int` → `bigint`) | Usually fine | ⚠️ | -**Expand/contract pattern for destructive changes:** +**Expand pattern for schema changes:** 1. **Version N (Expand)**: add the new column/table (nullable or with default); old code still works -2. **Version N (Deploy)**: ship new code that reads from the new structure, writes to both -3. 
**Version N+1 (Contract)**: drop the old column/table in a follow-on migration +2. **Version N (Deploy)**: ship new code that uses the new structure +3. Old column/table stays in the database indefinitely — stop using it in code, do not drop it + +### Idempotency and cross-track safety + +All DDL statements must use `IF EXISTS` / `IF NOT EXISTS` guards: + +```sql +-- Up +CREATE TABLE IF NOT EXISTS foo (...); +ALTER TABLE foo ADD COLUMN IF NOT EXISTS bar TEXT; -Never drop a column or rename a column in the same release as the code change that stops using it. +-- Down +DROP TABLE IF EXISTS foo; +ALTER TABLE foo DROP COLUMN IF EXISTS bar; +``` + +This is especially important because the `core` and `vector` tracks are applied sequentially and rolled back independently. If a `core` migration that a `vector` migration depends on is rolled back (e.g., vector fails and triggers a core rollback), the vector track may later attempt to reference a table or column that no longer exists. Guards prevent those cross-track dependencies from causing hard errors during rollback. ### Naming @@ -120,4 +143,18 @@ For additive, backward-compatible migrations a rolling update is safe: 2. New pod passes readiness probe → old pod terminates 3. Backward-compatible schema means old pods continue operating during the window -For a migration that is **not** backward-compatible, restructure it using expand/contract. +For a migration that is **not** backward-compatible, restructure it using the expand pattern (add new column/table, migrate code, leave old column in place indefinitely). + +## Static Analysis Enforcement + +The policies above are enforced by static analysis tests in `go/core/pkg/migrations/cross_track_test.go`. These run against the embedded SQL files — no database required. 
+ +| Test | What it enforces | +|------|-----------------| +| `TestNoContractingDDL` | Up migrations must not contain `DROP TABLE`, `DROP COLUMN`, `RENAME TABLE`, or `RENAME COLUMN` | +| `TestNoCrossTrackDDL` | No track may `ALTER TABLE` or `CREATE INDEX ON` a table owned by another track | +| `TestMigrationGuards` | Up migrations must use `IF NOT EXISTS` on all `CREATE`/`ADD COLUMN`; down migrations must use `IF EXISTS` on all `DROP` statements | + +**Adding a new track**: add the track directory name to the `tracks` slice in each test so the new track is covered by the same checks. + +These tests catch policy violations at PR time without needing a running database. They complement the integration tests in `runner_test.go`, which verify the runner's rollback and concurrency behavior against a real Postgres instance. diff --git a/go/core/pkg/migrations/cross_track_test.go b/go/core/pkg/migrations/cross_track_test.go new file mode 100644 index 000000000..5b588fff3 --- /dev/null +++ b/go/core/pkg/migrations/cross_track_test.go @@ -0,0 +1,296 @@ +package migrations_test + +import ( + "fmt" + "io/fs" + "regexp" + "strings" + "testing" + + "github.com/kagent-dev/kagent/go/core/pkg/migrations" +) + +// Cross-track DDL rules: +// +// - Each track owns the tables it creates (via CREATE TABLE). +// - A track must not ALTER TABLE or CREATE INDEX ON a table owned by another track. +// +// This is a static analysis check against the embedded migration files. It runs +// against real SQL — no database required. + +var ( + createTableRe = regexp.MustCompile(`(?i)CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)`) + alterTableRe = regexp.MustCompile(`(?i)ALTER\s+TABLE\s+(?:IF\s+EXISTS\s+)?(\w+)`) + createIndexRe = regexp.MustCompile(`(?i)CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:\w+\s+)?ON\s+(\w+)`) +) + +// ownedTables returns the set of table names created by up migrations in fsys. 
+func ownedTables(fsys fs.FS) (map[string]string, error) { + tables := make(map[string]string) // table name → file that created it + err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() || !strings.HasSuffix(path, ".up.sql") { + return err + } + data, err := fs.ReadFile(fsys, path) + if err != nil { + return err + } + for _, m := range createTableRe.FindAllSubmatch(data, -1) { + name := strings.ToLower(string(m[1])) + tables[name] = path + } + return nil + }) + return tables, err +} + +type violation struct { + file string + statement string + table string + ownedBy string +} + +// crossTrackViolations returns any up-migration DDL in fsys that modifies a +// table owned by another track. +func crossTrackViolations(fsys fs.FS, foreignTables map[string]string) ([]violation, error) { + var violations []violation + err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() || !strings.HasSuffix(path, ".up.sql") { + return err + } + data, err := fs.ReadFile(fsys, path) + if err != nil { + return err + } + content := string(data) + + check := func(matches [][]string) { + for _, m := range matches { + table := strings.ToLower(m[1]) + if owner, ok := foreignTables[table]; ok { + violations = append(violations, violation{ + file: path, + statement: m[0], + table: table, + ownedBy: owner, + }) + } + } + } + check(alterTableRe.FindAllStringSubmatch(content, -1)) + check(createIndexRe.FindAllStringSubmatch(content, -1)) + return nil + }) + return violations, err +} + +// contractingPatterns lists DDL that shrinks or renames the schema. +// These are forbidden in up migrations — schema changes must be additive-only. +// Down migrations are intentionally excluded from this check. +var contractingPatterns = []struct { + name string + re *regexp.Regexp +}{ + // Removing a table entirely. 
+ {"DROP TABLE", regexp.MustCompile(`(?i)\bDROP\s+TABLE\b`)}, + // Removing a column with the explicit COLUMN keyword. + {"DROP COLUMN", regexp.MustCompile(`(?i)\bDROP\s+COLUMN\b`)}, + // Renaming a table changes the name old code references. + {"RENAME TABLE", regexp.MustCompile(`(?i)\bALTER\s+TABLE\b[^;]+\bRENAME\s+TO\b`)}, + // Renaming a column breaks any code or query that still uses the old name. + {"RENAME COLUMN", regexp.MustCompile(`(?i)\bRENAME\s+COLUMN\b`)}, +} + +// alterDropRe captures the word immediately after ALTER TABLE ... DROP [IF EXISTS]. +// The COLUMN keyword is optional in Postgres, so "ALTER TABLE foo DROP bar" is +// a valid column removal. We capture the first word after DROP and check whether +// it is a known non-contracting variant (COLUMN, CONSTRAINT, DEFAULT, NOT). +var alterDropRe = regexp.MustCompile(`(?i)\bALTER\s+TABLE\s+\S+\s+DROP\s+(?:IF\s+EXISTS\s+)?(\w+)`) + +// safeDropKeywords are words that can legitimately follow DROP in an ALTER TABLE +// without removing a column. +var safeDropKeywords = map[string]bool{ + "column": true, // already caught by the DROP COLUMN pattern above + "constraint": true, // removes a constraint, not a column + "default": true, // removes a column default, not the column itself + "not": true, // ALTER TABLE t ALTER COLUMN c DROP NOT NULL +} + +// TestNoContractingDDL enforces the additive-only schema policy: up migrations +// must never remove or rename tables or columns. Down migrations are excluded +// because they exist specifically to reverse schema changes. 
+func TestNoContractingDDL(t *testing.T) { + tracks := []string{"core", "vector"} + + for _, track := range tracks { + sub, err := fs.Sub(migrations.FS, track) + if err != nil { + t.Fatalf("fs.Sub(%q): %v", track, err) + } + + err = fs.WalkDir(sub, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() || !strings.HasSuffix(path, ".up.sql") { + return err + } + data, err := fs.ReadFile(sub, path) + if err != nil { + return err + } + content := string(data) + + for _, p := range contractingPatterns { + if m := p.re.FindString(content); m != "" { + t.Errorf( + "contracting DDL in %s/%s: %q matches %q — up migrations must be additive-only", + track, path, m, p.name, + ) + } + } + + // Check for bare column drops: ALTER TABLE foo DROP bar (no COLUMN keyword). + // RE2 has no negative lookahead, so we capture the word and filter here. + for _, m := range alterDropRe.FindAllStringSubmatch(content, -1) { + if !safeDropKeywords[strings.ToLower(m[1])] { + t.Errorf( + "contracting DDL in %s/%s: %q — bare DROP without COLUMN keyword is a column removal; up migrations must be additive-only", + track, path, m[0], + ) + } + } + + return nil + }) + if err != nil { + t.Fatalf("WalkDir(%q): %v", track, err) + } + } +} + +// guardCheck describes a DDL statement that requires an idempotency guard. +// re captures the first significant word after the keyword; if that word is not +// "if" (case-insensitive) the guard is absent. +type guardCheck struct { + name string + re *regexp.Regexp +} + +// upGuardChecks are statements in up migrations that must use IF NOT EXISTS. 
+var upGuardChecks = []guardCheck{ + {"CREATE TABLE", regexp.MustCompile(`(?i)\bCREATE\s+TABLE\s+(\w+)`)}, + {"CREATE INDEX", regexp.MustCompile(`(?i)\bCREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(\w+)`)}, + {"CREATE EXTENSION", regexp.MustCompile(`(?i)\bCREATE\s+EXTENSION\s+(\w+)`)}, + {"ADD COLUMN", regexp.MustCompile(`(?i)\bADD\s+COLUMN\s+(\w+)`)}, +} + +// downGuardChecks are statements in down migrations that must use IF EXISTS. +var downGuardChecks = []guardCheck{ + {"DROP TABLE", regexp.MustCompile(`(?i)\bDROP\s+TABLE\s+(\w+)`)}, + {"DROP INDEX", regexp.MustCompile(`(?i)\bDROP\s+INDEX\s+(\w+)`)}, + {"DROP EXTENSION", regexp.MustCompile(`(?i)\bDROP\s+EXTENSION\s+(\w+)`)}, + {"DROP COLUMN", regexp.MustCompile(`(?i)\bDROP\s+COLUMN\s+(\w+)`)}, +} + +// TestMigrationGuards enforces idempotency guards across all migration files: +// - Up migrations: CREATE TABLE/INDEX/EXTENSION and ADD COLUMN must use IF NOT EXISTS. +// - Down migrations: DROP TABLE/INDEX/EXTENSION/COLUMN must use IF EXISTS. +// +// This ensures migrations are safe to re-run and that the two-track rollback +// logic can call down migrations more than once without errors. 
+func TestMigrationGuards(t *testing.T) { + tracks := []string{"core", "vector"} + + for _, track := range tracks { + sub, err := fs.Sub(migrations.FS, track) + if err != nil { + t.Fatalf("fs.Sub(%q): %v", track, err) + } + + err = fs.WalkDir(sub, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return err + } + + var checks []guardCheck + switch { + case strings.HasSuffix(path, ".up.sql"): + checks = upGuardChecks + case strings.HasSuffix(path, ".down.sql"): + checks = downGuardChecks + default: + return nil + } + + data, err := fs.ReadFile(sub, path) + if err != nil { + return err + } + content := string(data) + + for _, c := range checks { + for _, m := range c.re.FindAllStringSubmatch(content, -1) { + if !strings.EqualFold(m[1], "if") { + t.Errorf( + "missing guard in %s/%s: %q — %s requires IF NOT EXISTS / IF EXISTS", + track, path, m[0], c.name, + ) + } + } + } + return nil + }) + if err != nil { + t.Fatalf("WalkDir(%q): %v", track, err) + } + } +} + +// TestNoCrossTrackDDL verifies that no migration track modifies tables owned +// by another track. Each track must only ALTER or index its own tables. +func TestNoCrossTrackDDL(t *testing.T) { + tracks := []string{"core", "vector"} + + // Build the ownership map for each track. + owned := make(map[string]map[string]string, len(tracks)) + for _, track := range tracks { + sub, err := fs.Sub(migrations.FS, track) + if err != nil { + t.Fatalf("fs.Sub(%q): %v", track, err) + } + tables, err := ownedTables(sub) + if err != nil { + t.Fatalf("ownedTables(%q): %v", track, err) + } + owned[track] = tables + } + + // For each track, check its migrations don't touch tables owned by others. + for _, track := range tracks { + sub, err := fs.Sub(migrations.FS, track) + if err != nil { + t.Fatalf("fs.Sub(%q): %v", track, err) + } + + // Collect all tables owned by *other* tracks. 
+ foreign := make(map[string]string) + for otherTrack, tables := range owned { + if otherTrack == track { + continue + } + for table, file := range tables { + foreign[table] = fmt.Sprintf("%s/%s", otherTrack, file) + } + } + + violations, err := crossTrackViolations(sub, foreign) + if err != nil { + t.Fatalf("crossTrackViolations(%q): %v", track, err) + } + for _, v := range violations { + t.Errorf( + "cross-track DDL violation: %s/%s contains %q targeting table %q (owned by %s)", + track, v.file, v.statement, v.table, v.ownedBy, + ) + } + } +} diff --git a/go/core/pkg/migrations/runner_test.go b/go/core/pkg/migrations/runner_test.go index d5f5262e3..2db4ad996 100644 --- a/go/core/pkg/migrations/runner_test.go +++ b/go/core/pkg/migrations/runner_test.go @@ -51,6 +51,23 @@ var failVectorFS = fstest.MapFS{ "vector/000001_bad.down.sql": {Data: []byte(`SELECT 1;`)}, } +// expandCoreFS creates shared_data with two columns. Used to test cross-track +// rollback scenarios where the vector track depends on this table. +var expandCoreFS = fstest.MapFS{ + "core/000001_create_shared.up.sql": {Data: []byte(`CREATE TABLE IF NOT EXISTS shared_data (id SERIAL PRIMARY KEY, col_a TEXT);`)}, + "core/000001_create_shared.down.sql": {Data: []byte(`DROP TABLE IF EXISTS shared_data;`)}, + "core/000002_add_col_b.up.sql": {Data: []byte(`ALTER TABLE shared_data ADD COLUMN IF NOT EXISTS col_b TEXT;`)}, + "core/000002_add_col_b.down.sql": {Data: []byte(`ALTER TABLE shared_data DROP COLUMN IF EXISTS col_b;`)}, +} + +// failVectorWithDependencyFS is a vector migration that partially succeeds +// (adds a column to shared_data) then fails. Its down migration uses IF EXISTS +// so rollback is safe even if the column was never added. 
+var failVectorWithDependencyFS = fstest.MapFS{ + "vector/000001_bad_depends_on_core.up.sql": {Data: []byte(`ALTER TABLE shared_data ADD COLUMN IF NOT EXISTS vec_col VECTOR(3); ALTER TABLE no_such_table ADD COLUMN x TEXT;`)}, + "vector/000001_bad_depends_on_core.down.sql": {Data: []byte(`ALTER TABLE shared_data DROP COLUMN IF EXISTS vec_col;`)}, +} + // mergeFS combines multiple MapFS values into one. func mergeFS(fsMaps ...fstest.MapFS) fstest.MapFS { out := fstest.MapFS{} @@ -279,3 +296,38 @@ func TestCrossTrackRollback_CoreRolledBackWhenVectorFails(t *testing.T) { t.Errorf("core version after cross-track rollback = %d, want %d", got, corePrev) } } + +// TestCrossTrackRollback_IfExistsGuardsSafeOnVectorFailure verifies that when a +// vector migration fails and triggers a core cross-track rollback, the IF EXISTS +// guards in both down migrations prevent errors even though the vector migration +// only partially applied and shared_data is being dropped by core's rollback. +func TestCrossTrackRollback_IfExistsGuardsSafeOnVectorFailure(t *testing.T) { + connStr := startTestDB(t) + + combined := mergeFS(expandCoreFS, failVectorWithDependencyFS) + + // Core succeeds (shared_data created with col_a and col_b). + corePrev, err := applyDir(connStr, combined, "core", "schema_migrations") + if err != nil { + t.Fatalf("core apply: %v", err) + } + if got := trackVersion(t, connStr, "schema_migrations"); got != 2 { + t.Fatalf("core version = %d, want 2", got) + } + + // Vector fails; its down migration (DROP COLUMN IF EXISTS vec_col) must not + // error even though the column was never added. + if _, err := applyDir(connStr, combined, "vector", "vector_schema_migrations"); err == nil { + t.Fatal("expected vector error, got nil") + } + if got := trackVersion(t, connStr, "vector_schema_migrations"); got != 0 { + t.Errorf("vector version after self-rollback = %d, want 0", got) + } + + // Cross-track rollback: core rolls back to its pre-run version. 
+ // Core's down migration (DROP TABLE IF EXISTS shared_data) must succeed. + rollbackDir(connStr, combined, "core", "schema_migrations", corePrev) + if got := trackVersion(t, connStr, "schema_migrations"); got != corePrev { + t.Errorf("core version after cross-track rollback = %d, want %d", got, corePrev) + } +} From 5d083558db715d93847399855b14dd3fcd5312fe Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 3 Apr 2026 13:47:00 -0700 Subject: [PATCH 12/16] Updates to simplify Signed-off-by: Jeremy Alvis --- .../references/database-migrations.md | 44 +++++----- go/core/pkg/migrations/cross_track_test.go | 82 ------------------- 2 files changed, 20 insertions(+), 106 deletions(-) diff --git a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md index fd7ce5288..9478ff859 100644 --- a/.claude/skills/kagent-dev/references/database-migrations.md +++ b/.claude/skills/kagent-dev/references/database-migrations.md @@ -50,24 +50,6 @@ A CI check (`.github/workflows/sqlc-generate-check.yaml`) fails the PR if `gen/` ## Writing Migrations -### Additive-only policy - -The schema is **additive-only**. Columns and tables are deprecated in application code but never removed from the database. This is not just a kagent convention — it is the only reliable guarantee that can be made to downstream consumers who deploy on their own schedule and may have FK constraints pointing at kagent-owned tables. - -This mirrors how mature projects handle the same problem: Salesforce platform-enforces additive-only for managed packages after GA; Stripe never removes fields from a versioned API response; GitLab requires multi-phase explicit FK teardown before any column can be contracted. 
- -**Why contraction is unsafe with multiple tracks and downstream consumers:** - -The two-track design (core → vector) and downstream consumers (who may add their own migration track on top of core/vector) create a class of failure that has no clean runtime fix: - -1. **Fresh install**: all core migrations run to completion — including any contraction — before vector or downstream migrations run. A later track's migration referencing a contracted column fails because it no longer exists. -2. **Existing database**: Postgres CASCADE silently drops dependent indexes or constraints created by a later track when core contracts the column. Migration tracking shows the later track at its old version, but the actual schema no longer matches. -3. **Downstream at unknown version**: downstream consumers may have deployed weeks behind. A core contraction breaks their upgrade path with no warning. - -No migration tool (Flyway, Liquibase, Atlas, golang-migrate) automatically detects or prevents this. It must be enforced by policy. - -**Contracting is not allowed.** If a column or table is no longer needed, stop using it in application code and leave it in the database. - ### Backward-compatible schema changes During a rolling deploy, old pods will be reading and writing a schema that has already been upgraded. **Every migration must be backward-compatible with the previous version's code.** @@ -82,10 +64,10 @@ During a rolling deploy, old pods will be reading and writing a schema that has | Drop/rename column old code references | Old SELECT/INSERT errors | ❌ | | Change compatible type (e.g. `int` → `bigint`) | Usually fine | ⚠️ | -**Expand pattern for schema changes:** +**Expand-then-contract pattern for schema changes:** 1. **Version N (Expand)**: add the new column/table (nullable or with default); old code still works 2. **Version N (Deploy)**: ship new code that uses the new structure -3. 
Old column/table stays in the database indefinitely — stop using it in code, do not drop it +3. **Version N+1 (Contract)**: drop the old column/table once version N is fully deployed and no pods run version N-1 ### Idempotency and cross-track safety @@ -101,7 +83,7 @@ DROP TABLE IF EXISTS foo; ALTER TABLE foo DROP COLUMN IF EXISTS bar; ``` -This is especially important because the `core` and `vector` tracks are applied sequentially and rolled back independently. If a `core` migration that a `vector` migration depends on is rolled back (e.g., vector fails and triggers a core rollback), the vector track may later attempt to reference a table or column that no longer exists. Guards prevent those cross-track dependencies from causing hard errors during rollback. +Guards provide defense-in-depth for crash recovery and dirty-state cleanup, where a partially-applied migration may be re-run or rolled back. ### Naming @@ -137,13 +119,13 @@ If the controller crashes mid-migration, the migration runner records the versio ### Rollout strategy -For additive, backward-compatible migrations a rolling update is safe: +For backward-compatible migrations a rolling update is safe: 1. New pod starts → migration runner applies pending migrations (advisory lock serializes concurrent runs) 2. New pod passes readiness probe → old pod terminates 3. Backward-compatible schema means old pods continue operating during the window -For a migration that is **not** backward-compatible, restructure it using the expand pattern (add new column/table, migrate code, leave old column in place indefinitely). +For a migration that is **not** backward-compatible, restructure it using the expand-then-contract pattern (add new column/table in version N, ship code that uses it, drop the old column in version N+1). 
## Static Analysis Enforcement @@ -151,10 +133,24 @@ The policies above are enforced by static analysis tests in `go/core/pkg/migrati | Test | What it enforces | |------|-----------------| -| `TestNoContractingDDL` | Up migrations must not contain `DROP TABLE`, `DROP COLUMN`, `RENAME TABLE`, or `RENAME COLUMN` | | `TestNoCrossTrackDDL` | No track may `ALTER TABLE` or `CREATE INDEX ON` a table owned by another track | | `TestMigrationGuards` | Up migrations must use `IF NOT EXISTS` on all `CREATE`/`ADD COLUMN`; down migrations must use `IF EXISTS` on all `DROP` statements | **Adding a new track**: add the track directory name to the `tracks` slice in each test so the new track is covered by the same checks. These tests catch policy violations at PR time without needing a running database. They complement the integration tests in `runner_test.go`, which verify the runner's rollback and concurrency behavior against a real Postgres instance. + +## Downstream Extension Model + +The migration layer is designed for downstream consumers to extend with their own migrations alongside OSS. The extension points are: + +1. **SQL files as the contract.** The migration files in `go/core/pkg/migrations/core/` and `vector/` are the stable interface. Downstream consumers sync these files into their own repos and build their own migration runners. Don't move or reorganize migration file paths without considering downstream impact. + +2. **`MigrationRunner` DI callback.** Downstream consumers pass a custom `MigrationRunner` to `app.Start` to take full ownership of the migration process — running OSS migrations alongside their own in whatever order they need. The signature `func(ctx context.Context, url string, vectorEnabled bool) error` is stable. + +3. **Vector track stays separate.** The vector track is conditionally applied and has its own tracking table. Downstream extensions should not modify vector-owned tables (enforced by `TestNoCrossTrackDDL`). 
+ +### What this means for OSS development + +- **Migration immutability is cross-repo.** Once a migration file is merged and tagged, downstream consumers may have synced it. Modifying it breaks their tracking table state. +- **The `MigrationRunner` DI signature is stable.** Changes to this type are breaking for downstream consumers. diff --git a/go/core/pkg/migrations/cross_track_test.go b/go/core/pkg/migrations/cross_track_test.go index 5b588fff3..66575873c 100644 --- a/go/core/pkg/migrations/cross_track_test.go +++ b/go/core/pkg/migrations/cross_track_test.go @@ -85,88 +85,6 @@ func crossTrackViolations(fsys fs.FS, foreignTables map[string]string) ([]violat return violations, err } -// contractingPatterns lists DDL that shrinks or renames the schema. -// These are forbidden in up migrations — schema changes must be additive-only. -// Down migrations are intentionally excluded from this check. -var contractingPatterns = []struct { - name string - re *regexp.Regexp -}{ - // Removing a table entirely. - {"DROP TABLE", regexp.MustCompile(`(?i)\bDROP\s+TABLE\b`)}, - // Removing a column with the explicit COLUMN keyword. - {"DROP COLUMN", regexp.MustCompile(`(?i)\bDROP\s+COLUMN\b`)}, - // Renaming a table changes the name old code references. - {"RENAME TABLE", regexp.MustCompile(`(?i)\bALTER\s+TABLE\b[^;]+\bRENAME\s+TO\b`)}, - // Renaming a column breaks any code or query that still uses the old name. - {"RENAME COLUMN", regexp.MustCompile(`(?i)\bRENAME\s+COLUMN\b`)}, -} - -// alterDropRe captures the word immediately after ALTER TABLE ... DROP [IF EXISTS]. -// The COLUMN keyword is optional in Postgres, so "ALTER TABLE foo DROP bar" is -// a valid column removal. We capture the first word after DROP and check whether -// it is a known non-contracting variant (COLUMN, CONSTRAINT, DEFAULT, NOT). 
-var alterDropRe = regexp.MustCompile(`(?i)\bALTER\s+TABLE\s+\S+\s+DROP\s+(?:IF\s+EXISTS\s+)?(\w+)`) - -// safeDropKeywords are words that can legitimately follow DROP in an ALTER TABLE -// without removing a column. -var safeDropKeywords = map[string]bool{ - "column": true, // already caught by the DROP COLUMN pattern above - "constraint": true, // removes a constraint, not a column - "default": true, // removes a column default, not the column itself - "not": true, // ALTER TABLE t ALTER COLUMN c DROP NOT NULL -} - -// TestNoContractingDDL enforces the additive-only schema policy: up migrations -// must never remove or rename tables or columns. Down migrations are excluded -// because they exist specifically to reverse schema changes. -func TestNoContractingDDL(t *testing.T) { - tracks := []string{"core", "vector"} - - for _, track := range tracks { - sub, err := fs.Sub(migrations.FS, track) - if err != nil { - t.Fatalf("fs.Sub(%q): %v", track, err) - } - - err = fs.WalkDir(sub, ".", func(path string, d fs.DirEntry, err error) error { - if err != nil || d.IsDir() || !strings.HasSuffix(path, ".up.sql") { - return err - } - data, err := fs.ReadFile(sub, path) - if err != nil { - return err - } - content := string(data) - - for _, p := range contractingPatterns { - if m := p.re.FindString(content); m != "" { - t.Errorf( - "contracting DDL in %s/%s: %q matches %q — up migrations must be additive-only", - track, path, m, p.name, - ) - } - } - - // Check for bare column drops: ALTER TABLE foo DROP bar (no COLUMN keyword). - // RE2 has no negative lookahead, so we capture the word and filter here. 
- for _, m := range alterDropRe.FindAllStringSubmatch(content, -1) { - if !safeDropKeywords[strings.ToLower(m[1])] { - t.Errorf( - "contracting DDL in %s/%s: %q — bare DROP without COLUMN keyword is a column removal; up migrations must be additive-only", - track, path, m[0], - ) - } - } - - return nil - }) - if err != nil { - t.Fatalf("WalkDir(%q): %v", track, err) - } - } -} - // guardCheck describes a DDL statement that requires an idempotency guard. // re captures the first significant word after the keyword; if that word is not // "if" (case-insensitive) the guard is absent. From 217a91b3dfe4d035ef7f6aea1e33ae1d2a1b05b6 Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 3 Apr 2026 13:51:12 -0700 Subject: [PATCH 13/16] Remove unused code Signed-off-by: Jeremy Alvis --- .../references/database-migrations.md | 4 +- go/core/internal/dbtest/dbtest.go | 10 ---- go/core/pkg/migrations/migrations.go | 5 +- go/core/pkg/migrations/runner.go | 58 ------------------- 4 files changed, 4 insertions(+), 73 deletions(-) diff --git a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md index 9478ff859..9f34b9eb2 100644 --- a/.claude/skills/kagent-dev/references/database-migrations.md +++ b/.claude/skills/kagent-dev/references/database-migrations.md @@ -7,7 +7,7 @@ kagent uses [golang-migrate](https://github.com/golang-migrate/migrate) with emb ``` go/core/pkg/migrations/ ├── migrations.go # Embeds the FS (go:embed); exports FS for downstream consumers -├── runner.go # RunUp / RunDown / RunDownAll / RunVersion / RunForce +├── runner.go # RunUp (applies pending migrations at startup) ├── core/ # Core schema (tracked in schema_migrations table) │ ├── 000001_initial.up.sql / .down.sql │ ├── 000002_add_session_source.up.sql / .down.sql @@ -115,7 +115,7 @@ If the controller crashes mid-migration, the migration runner records the versio 2. Runs the down migration to restore the previous clean state. 3. 
Re-runs the failed up migration. -**Requirement**: down migrations must be idempotent and correctly reverse their up migration. A missing or broken down migration requires manual recovery using `RunForce`. +**Requirement**: down migrations must be idempotent and correctly reverse their up migration. A missing or broken down migration requires manual recovery. ### Rollout strategy diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index a5573d7ea..43355db1a 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -76,13 +76,3 @@ func MigrateT(t *testing.T, connStr string, vectorEnabled bool) { } } -// MigrateDown rolls back all OSS migrations against connStr and returns any error. -// If vectorEnabled is true the vector pass is also rolled back first. -func MigrateDown(connStr string, vectorEnabled bool) error { - if vectorEnabled { - if err := migrations.RunDownAll(connStr, migrations.FS, "vector", "vector_schema_migrations"); err != nil { - return fmt.Errorf("vector down migrations: %w", err) - } - } - return migrations.RunDownAll(connStr, migrations.FS, "core", "schema_migrations") -} diff --git a/go/core/pkg/migrations/migrations.go b/go/core/pkg/migrations/migrations.go index 48746a7f6..50c3f22d5 100644 --- a/go/core/pkg/migrations/migrations.go +++ b/go/core/pkg/migrations/migrations.go @@ -1,6 +1,5 @@ -// Package migrations exports the embedded SQL migration files for the kagent OSS -// database schema. Enterprise builds import this FS to bundle OSS migrations -// alongside enterprise-specific ones at build time. +// Package migrations embeds the SQL migration files for the kagent database schema +// and provides the runner that applies them at startup. 
package migrations import "embed" diff --git a/go/core/pkg/migrations/runner.go b/go/core/pkg/migrations/runner.go index e062ffa95..208f1d804 100644 --- a/go/core/pkg/migrations/runner.go +++ b/go/core/pkg/migrations/runner.go @@ -33,64 +33,6 @@ func RunUp(url string, migrationsFS fs.FS, vectorEnabled bool) error { return nil } -// RunDown rolls back steps migrations on a single track. -func RunDown(url string, migrationsFS fs.FS, dir, migrationsTable string, steps int) error { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return err - } - defer closeMigrate(dir, mg) - - if err := mg.Steps(-steps); err != nil && !errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("roll back %d migration(s) for %s: %w", steps, dir, err) - } - return nil -} - -// RunDownAll rolls back all applied migrations on a single track. -func RunDownAll(url string, migrationsFS fs.FS, dir, migrationsTable string) error { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return err - } - defer closeMigrate(dir, mg) - - if err := mg.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) { - return fmt.Errorf("down migrations for %s: %w", dir, err) - } - return nil -} - -// RunVersion returns the current applied version and dirty flag for a single track. -func RunVersion(url string, migrationsFS fs.FS, dir, migrationsTable string) (version uint, dirty bool, err error) { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return 0, false, err - } - defer closeMigrate(dir, mg) - - version, dirty, err = mg.Version() - if err != nil && !errors.Is(err, migrate.ErrNilVersion) { - return 0, false, fmt.Errorf("get version for %s: %w", dir, err) - } - return version, dirty, nil -} - -// RunForce forces the tracking table for a single track to version (clears the dirty flag). -// Pass version=-1 to remove the version record entirely. 
-func RunForce(url string, migrationsFS fs.FS, dir, migrationsTable string, version int) error { - mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) - if err != nil { - return err - } - defer closeMigrate(dir, mg) - - if err := mg.Force(version); err != nil { - return fmt.Errorf("force %s to version %d: %w", dir, version, err) - } - return nil -} - // applyDir runs Up for dir and rolls back on failure. It returns the pre-run // version so the caller can roll back this track if a later track fails. func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prevVersion uint, err error) { From 46f9fb7c508bae847bbca60686a2cf09c6b5e8dc Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 3 Apr 2026 14:09:48 -0700 Subject: [PATCH 14/16] Updates based on review Signed-off-by: Jeremy Alvis --- go/core/internal/database/client_postgres.go | 17 ++++---------- go/core/internal/database/gen/feedback.sql.go | 2 +- .../internal/database/gen/langgraph.sql.go | 2 +- go/core/internal/database/gen/models.go | 4 ++-- go/core/internal/dbtest/dbtest.go | 1 - .../pkg/migrations/core/000001_initial.up.sql | 4 ++-- go/core/pkg/migrations/runner.go | 22 ++++++++++--------- 7 files changed, 22 insertions(+), 30 deletions(-) diff --git a/go/core/internal/database/client_postgres.go b/go/core/internal/database/client_postgres.go index c5aaaa10d..68952ddbc 100644 --- a/go/core/internal/database/client_postgres.go +++ b/go/core/internal/database/client_postgres.go @@ -271,11 +271,10 @@ func (c *postgresClient) DeletePushNotification(ctx context.Context, taskID stri // ── Feedback ────────────────────────────────────────────────────────────────── func (c *postgresClient) StoreFeedback(ctx context.Context, feedback *dbpkg.Feedback) error { - isPositive := feedback.IsPositive err := c.q.InsertFeedback(ctx, dbgen.InsertFeedbackParams{ UserID: feedback.UserID, MessageID: feedback.MessageID, - IsPositive: &isPositive, + IsPositive: feedback.IsPositive, FeedbackText: 
feedback.FeedbackText, IssueType: feedback.IssueType, }) @@ -393,7 +392,6 @@ func (c *postgresClient) DeleteToolServer(ctx context.Context, serverName, group // ── LangGraph Checkpoints ───────────────────────────────────────────────────── func (c *postgresClient) StoreCheckpoint(ctx context.Context, cp *dbpkg.LangGraphCheckpoint) error { - version := cp.Version return c.q.UpsertCheckpoint(ctx, dbgen.UpsertCheckpointParams{ UserID: cp.UserID, ThreadID: cp.ThreadID, @@ -403,7 +401,7 @@ func (c *postgresClient) StoreCheckpoint(ctx context.Context, cp *dbpkg.LangGrap Metadata: cp.Metadata, Checkpoint: cp.Checkpoint, CheckpointType: cp.CheckpointType, - Version: &version, + Version: cp.Version, }) } @@ -742,7 +740,7 @@ func toFeedback(r dbgen.Feedback) *dbpkg.Feedback { DeletedAt: r.DeletedAt, UserID: r.UserID, MessageID: r.MessageID, - IsPositive: derefBool(r.IsPositive), + IsPositive: r.IsPositive, FeedbackText: r.FeedbackText, IssueType: r.IssueType, } @@ -785,7 +783,7 @@ func toCheckpoint(r dbgen.LgCheckpoint) *dbpkg.LangGraphCheckpoint { Metadata: r.Metadata, Checkpoint: r.Checkpoint, CheckpointType: r.CheckpointType, - Version: derefInt64(r.Version), + Version: r.Version, } } @@ -872,10 +870,3 @@ func derefTime(t *time.Time) time.Time { } return time.Time{} } - -func derefBool(b *bool) bool { - if b != nil { - return *b - } - return false -} diff --git a/go/core/internal/database/gen/feedback.sql.go b/go/core/internal/database/gen/feedback.sql.go index 34f4c293e..93ef91cf4 100644 --- a/go/core/internal/database/gen/feedback.sql.go +++ b/go/core/internal/database/gen/feedback.sql.go @@ -19,7 +19,7 @@ VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) type InsertFeedbackParams struct { UserID string MessageID *int64 - IsPositive *bool + IsPositive bool FeedbackText string IssueType *database.FeedbackIssueType } diff --git a/go/core/internal/database/gen/langgraph.sql.go b/go/core/internal/database/gen/langgraph.sql.go index 4453595f8..eb483f138 100644 --- 
a/go/core/internal/database/gen/langgraph.sql.go +++ b/go/core/internal/database/gen/langgraph.sql.go @@ -253,7 +253,7 @@ type UpsertCheckpointParams struct { Metadata string Checkpoint string CheckpointType string - Version *int64 + Version int64 } func (q *Queries) UpsertCheckpoint(ctx context.Context, arg UpsertCheckpointParams) error { diff --git a/go/core/internal/database/gen/models.go b/go/core/internal/database/gen/models.go index b3cd22eca..92a5585d2 100644 --- a/go/core/internal/database/gen/models.go +++ b/go/core/internal/database/gen/models.go @@ -57,7 +57,7 @@ type Feedback struct { DeletedAt *time.Time UserID string MessageID *int64 - IsPositive *bool + IsPositive bool FeedbackText string IssueType *database.FeedbackIssueType } @@ -74,7 +74,7 @@ type LgCheckpoint struct { Metadata string Checkpoint string CheckpointType string - Version *int64 + Version int64 } type LgCheckpointWrite struct { diff --git a/go/core/internal/dbtest/dbtest.go b/go/core/internal/dbtest/dbtest.go index 43355db1a..3202c39fb 100644 --- a/go/core/internal/dbtest/dbtest.go +++ b/go/core/internal/dbtest/dbtest.go @@ -75,4 +75,3 @@ func MigrateT(t *testing.T, connStr string, vectorEnabled bool) { t.Fatalf("dbtest.MigrateT: %v", err) } } - diff --git a/go/core/pkg/migrations/core/000001_initial.up.sql b/go/core/pkg/migrations/core/000001_initial.up.sql index 3b17c3f1c..a7893e688 100644 --- a/go/core/pkg/migrations/core/000001_initial.up.sql +++ b/go/core/pkg/migrations/core/000001_initial.up.sql @@ -74,7 +74,7 @@ CREATE TABLE IF NOT EXISTS feedback ( deleted_at TIMESTAMPTZ, user_id TEXT NOT NULL, message_id BIGINT, - is_positive BOOLEAN DEFAULT false, + is_positive BOOLEAN NOT NULL DEFAULT false, feedback_text TEXT NOT NULL, issue_type TEXT ); @@ -118,7 +118,7 @@ CREATE TABLE IF NOT EXISTS lg_checkpoint ( metadata TEXT NOT NULL, checkpoint TEXT NOT NULL, checkpoint_type TEXT NOT NULL, - version BIGINT DEFAULT 1, + version BIGINT NOT NULL DEFAULT 1, PRIMARY KEY (user_id, 
thread_id, checkpoint_ns, checkpoint_id) ); CREATE INDEX IF NOT EXISTS idx_lg_checkpoint_parent_checkpoint_id ON lg_checkpoint(parent_checkpoint_id); diff --git a/go/core/pkg/migrations/runner.go b/go/core/pkg/migrations/runner.go index 208f1d804..b450cc882 100644 --- a/go/core/pkg/migrations/runner.go +++ b/go/core/pkg/migrations/runner.go @@ -5,14 +5,16 @@ import ( "errors" "fmt" "io/fs" - "log" "github.com/golang-migrate/migrate/v4" migratepgx "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/iofs" _ "github.com/jackc/pgx/v5/stdlib" + ctrl "sigs.k8s.io/controller-runtime" ) +var log = ctrl.Log.WithName("migrations") + // RunUp applies all pending migrations for the given FS. // vectorEnabled controls whether the vector track is also applied. // Returns an error if any track fails (and attempts rollback of previously applied tracks). @@ -24,7 +26,7 @@ func RunUp(url string, migrationsFS fs.FS, vectorEnabled bool) error { if vectorEnabled { if _, err := applyDir(url, migrationsFS, "vector", "vector_schema_migrations"); err != nil { - log.Printf("migrations: rolling back core to version %d", corePrev) + log.Info("rolling back core after vector failure", "targetVersion", corePrev) rollbackDir(url, migrationsFS, "core", "schema_migrations", corePrev) return fmt.Errorf("vector migrations: %w", err) } @@ -52,11 +54,11 @@ func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prev if errors.Is(upErr, migrate.ErrNoChange) { return prevVersion, nil } - log.Printf("migrations: migration failed for %s, attempting rollback to version %d", dir, prevVersion) + log.Info("migration failed, attempting rollback", "track", dir, "targetVersion", prevVersion) if rbErr := rollbackToVersion(mg, dir, prevVersion); rbErr != nil { - log.Printf("migrations: rollback failed for %s: %v", dir, rbErr) + log.Error(rbErr, "rollback failed", "track", dir) } else { - log.Printf("migrations: rolled back %s to version %d", dir, 
prevVersion) + log.Info("rollback complete", "track", dir, "version", prevVersion) } return prevVersion, fmt.Errorf("run migrations for %s: %w", dir, upErr) } @@ -68,14 +70,14 @@ func applyDir(url string, migrationsFS fs.FS, dir, migrationsTable string) (prev func rollbackDir(url string, migrationsFS fs.FS, dir, migrationsTable string, targetVersion uint) { mg, err := newMigrate(url, migrationsFS, dir, migrationsTable) if err != nil { - log.Printf("migrations: rollback of %s failed (open): %v", dir, err) + log.Error(err, "rollback failed (open)", "track", dir) return } defer closeMigrate(dir, mg) if err := rollbackToVersion(mg, dir, targetVersion); err != nil { - log.Printf("migrations: rollback of %s failed: %v", dir, err) + log.Error(err, "rollback failed", "track", dir) } else { - log.Printf("migrations: rolled back %s to version %d", dir, targetVersion) + log.Info("rollback complete", "track", dir, "version", targetVersion) } } @@ -151,9 +153,9 @@ func newMigrate(url string, migrationsFS fs.FS, dir, migrationsTable string) (*m func closeMigrate(dir string, mg *migrate.Migrate) { srcErr, dbErr := mg.Close() if srcErr != nil { - log.Printf("warning: closing migration source for %s: %v", dir, srcErr) + log.Error(srcErr, "closing migration source", "track", dir) } if dbErr != nil { - log.Printf("warning: closing migration database for %s: %v", dir, dbErr) + log.Error(dbErr, "closing migration database", "track", dir) } } From ffcafc0f2dfbcd51c84d668aa1292073c7463eff Mon Sep 17 00:00:00 2001 From: Jeremy Alvis Date: Fri, 3 Apr 2026 14:13:26 -0700 Subject: [PATCH 15/16] Small update Signed-off-by: Jeremy Alvis --- .claude/skills/kagent-dev/references/database-migrations.md | 1 - 1 file changed, 1 deletion(-) diff --git a/.claude/skills/kagent-dev/references/database-migrations.md b/.claude/skills/kagent-dev/references/database-migrations.md index 9f34b9eb2..34de1c1c0 100644 --- a/.claude/skills/kagent-dev/references/database-migrations.md +++ 
b/.claude/skills/kagent-dev/references/database-migrations.md
@@ -153,4 +153,3 @@ The migration layer is designed for downstream consumers to extend with their ow
 ### What this means for OSS development
 
 - **Migration immutability is cross-repo.** Once a migration file is merged and tagged, downstream consumers may have synced it. Modifying it breaks their tracking table state.
-- **The `MigrationRunner` DI signature is stable.** Changes to this type are breaking for downstream consumers.

From f54084a3f9ef82a9dde4ff641a8ab93d8aada149 Mon Sep 17 00:00:00 2001
From: Jeremy Alvis
Date: Fri, 3 Apr 2026 14:47:27 -0700
Subject: [PATCH 16/16] Split out a few column modifications to bring gorm in
 line with clean install

Signed-off-by: Jeremy Alvis
---
 .../pkg/migrations/core/000003_not_null_defaults.down.sql | 2 ++
 .../pkg/migrations/core/000003_not_null_defaults.up.sql | 8 ++++++++
 2 files changed, 10 insertions(+)
 create mode 100644 go/core/pkg/migrations/core/000003_not_null_defaults.down.sql
 create mode 100644 go/core/pkg/migrations/core/000003_not_null_defaults.up.sql

diff --git a/go/core/pkg/migrations/core/000003_not_null_defaults.down.sql b/go/core/pkg/migrations/core/000003_not_null_defaults.down.sql
new file mode 100644
index 000000000..d4db72970
--- /dev/null
+++ b/go/core/pkg/migrations/core/000003_not_null_defaults.down.sql
@@ -0,0 +1,2 @@
+ALTER TABLE lg_checkpoint ALTER COLUMN version DROP NOT NULL;
+ALTER TABLE feedback ALTER COLUMN is_positive DROP NOT NULL;
diff --git a/go/core/pkg/migrations/core/000003_not_null_defaults.up.sql b/go/core/pkg/migrations/core/000003_not_null_defaults.up.sql
new file mode 100644
index 000000000..bd78b4b5b
--- /dev/null
+++ b/go/core/pkg/migrations/core/000003_not_null_defaults.up.sql
@@ -0,0 +1,8 @@
+-- Backfill any NULLs (none expected, but safe) then add NOT NULL constraints.
+-- Clean installs get NOT NULL from 000001; databases created by the previous gorm AutoMigrate schema (where 000001's CREATE TABLE IF NOT EXISTS is a no-op) may lack it.
+ +UPDATE feedback SET is_positive = false WHERE is_positive IS NULL; +ALTER TABLE feedback ALTER COLUMN is_positive SET NOT NULL; + +UPDATE lg_checkpoint SET version = 1 WHERE version IS NULL; +ALTER TABLE lg_checkpoint ALTER COLUMN version SET NOT NULL;