diff --git a/table.go b/table.go new file mode 100644 index 0000000..a98a31b --- /dev/null +++ b/table.go @@ -0,0 +1,404 @@ +package pgkit + +import ( + "context" + "errors" + "fmt" + "iter" + "slices" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" +) + +// ID is a comparable type used for record IDs. +type ID comparable + +// Records must be a pointer with the methods defined on the pointer. +type Record[T any, I ID] interface { + *T // Enforce T is a pointer. + GetID() I + Validate() error +} + +// Table provides basic CRUD operations for database records. +type Table[T any, P Record[T, I], I ID] struct { + *DB + Name string + IDColumn string + Paginator Paginator[P] +} + +// helpers for setting timestamp fields +type ( + hasSetCreatedAt interface { + SetCreatedAt(time.Time) + } + hasSetUpdatedAt interface { + SetUpdatedAt(time.Time) + } + hasSetDeletedAt interface { + SetDeletedAt(time.Time) + } +) + +// Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record. +func (t *Table[T, P, I]) Save(ctx context.Context, records ...P) error { + switch len(records) { + case 0: + return nil + case 1: + return t.saveOne(ctx, records[0]) + default: + return t.saveAll(ctx, records) + } +} + +func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error { + if record == nil { + return fmt.Errorf("record is nil") + } + + if err := record.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(record).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(time.Now().UTC()) + } + + // Insert + var zero I + if record.GetID() == zero { + q := t.SQL. + InsertRecord(record). + Into(t.Name). + Suffix("RETURNING *") + + if err := t.Query.GetOne(ctx, q, record); err != nil { + return fmt.Errorf("save: insert record: %w", err) + } + + return nil + } + + // Update + q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("save: update record: %w", err) + } + + return nil +} + +const chunkSize = 1000 + +func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error { + now := time.Now().UTC() + + insertRecords := make([]P, 0) + insertIndices := make([]int, 0) // keep track of original indices, so we can update the records with IDs in passed slice + + updateQueries := make(Queries, 0) + + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(r).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + + var zero I + if r.GetID() == zero { + if row, ok := any(r).(hasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + + insertRecords = append(insertRecords, r) + insertIndices = append(insertIndices, i) // remember index + } else { + updateQueries.Add(t.SQL. + UpdateRecord(r, sq.Eq{"id": r.GetID()}, t.Name). + SuffixExpr(sq.Expr(" RETURNING *")), + ) + } + } + + // Handle inserts in chunks, has to be done manually, slices.Chunk does not return index :/ + for start := 0; start < len(insertRecords); start += chunkSize { + end := start + chunkSize + if end > len(insertRecords) { + end = len(insertRecords) + } + + chunk := insertRecords[start:end] + q := t.SQL. + InsertRecords(chunk). + Into(t.Name). + SuffixExpr(sq.Expr(" RETURNING *")) + + if err := t.Query.GetAll(ctx, q, &chunk); err != nil { + return fmt.Errorf("insert records: %w", err) + } + + // update original slice + for i, rr := range chunk { + records[insertIndices[start+i]] = rr + } + } + + if len(updateQueries) > 0 { + for chunk := range slices.Chunk(updateQueries, chunkSize) { + if _, err := t.Query.BatchExec(ctx, chunk); err != nil { + return fmt.Errorf("update records: %w", err) + } + } + } + + return nil +} + +// getListQuery builds a base select query for listing records. +func (t *Table[T, P, I]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder { + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } + + q := t.SQL. + Select("*"). + From(t.Name). + Where(where). + OrderBy(orderBy...) + return q +} + +// Get returns the first record matching the condition. +func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (P, error) { + record := new(T) + + q := t.getListQuery(where, orderBy).Limit(1) + + if err := t.Query.GetOne(ctx, q, record); err != nil { + return nil, fmt.Errorf("get record: %w", err) + } + + return record, nil +} + +// List returns all records matching the condition. +func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]P, error) { + q := t.getListQuery(where, orderBy) + var records []P + if err := t.Query.GetAll(ctx, q, &records); err != nil { + return nil, err + } + + return records, nil +} + +// ListPaged returns paginated records matching the condition. +func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *Page) ([]P, *Page, error) { + q := t.SQL.Select("*").From(t.Name).Where(where) + + result, q := t.Paginator.PrepareQuery(q, page) + if err := t.Query.GetAll(ctx, q, &result); err != nil { + return nil, nil, err + } + result = t.Paginator.PrepareResult(result, page) + return result, page, nil +} + +// Iter returns an iterator for records matching the condition. +func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[P, error], error) { + q := t.getListQuery(where, orderBy) + rows, err := t.Query.QueryRows(ctx, q) + if err != nil { + return nil, fmt.Errorf("query rows: %w", err) + } + + return func(yield func(P, error) bool) { + defer rows.Close() + for rows.Next() { + var record T + if err := t.Query.Scan.ScanOne(&record, rows); err != nil { + if !errors.Is(err, pgx.ErrNoRows) { + yield(nil, err) + } + return + } + if !yield(&record, nil) { + return + } + } + }, nil +} + +// GetByID returns a record by its ID. +func (t *Table[T, P, I]) GetByID(ctx context.Context, id I) (P, error) { + return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) +} + +// ListByIDs returns records by their IDs. +func (t *Table[T, P, I]) ListByIDs(ctx context.Context, ids []I) ([]P, error) { + return t.List(ctx, sq.Eq{t.IDColumn: ids}, nil) +} + +// Count returns the number of matching records. +func (t *Table[T, P, I]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) { + var count uint64 + q := t.SQL. + Select("COUNT(1)"). + From(t.Name). + Where(where) + + if err := t.Query.GetOne(ctx, q, &count); err != nil { + return count, fmt.Errorf("count: %w", err) + } + + return count, nil +} + +// DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists. +func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error { + record, err := t.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("delete: %w", err) + } + + // Soft delete. + if row, ok := any(record).(hasSetDeletedAt); ok { + row.SetDeletedAt(time.Now().UTC()) + if err := t.Save(ctx, record); err != nil { + return fmt.Errorf("soft delete: %w", err) + } + return nil + } + + // Hard delete for tables without timestamps. + return t.HardDeleteByID(ctx, id) +} + +// HardDeleteByID permanently deletes a record by ID. +func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error { + q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("hard delete: %w", err) + } + return nil +} + +// WithPaginator returns a table instance with the given paginator. +func (t *Table[T, P, I]) WithPaginator(opts ...PaginatorOption) *Table[T, P, I] { + return &Table[T, P, I]{ + DB: t.DB, + Name: t.Name, + IDColumn: t.IDColumn, + Paginator: NewPaginator[P](opts...), + } +} + +// WithTx returns a table instance bound to the given transaction. +func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { + return &Table[T, P, I]{ + DB: &DB{ + Conn: t.DB.Conn, + SQL: t.DB.SQL, + Query: t.DB.TxQuery(tx), + }, + Name: t.Name, + IDColumn: t.IDColumn, + Paginator: t.Paginator, + } +} + +// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// within a database transaction for safe concurrent processing. The record is processed exactly +// once across multiple workers. The record is automatically updated after updateFn() completes. +// +// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status +// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() +// to update status to "completed" or "failed". +// +// Returns ErrNoRows if no matching records are available for locking. +func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error { + var noRows bool + + err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) { + if len(records) > 0 { + updateFn(records[0]) + } else { + noRows = true + } + }) + if err != nil { + return err //nolint:wrapcheck + } + + if noRows { + return ErrNoRows + } + + return nil +} + +// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// within a database transaction for safe concurrent processing. Each record is processed exactly +// once across multiple workers. Records are automatically updated after updateFn() completes. +// +// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status +// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() +// to update status to "completed" or "failed". +func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { + // Check if we're already in a transaction + if t.DB.Query.Tx != nil { + if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { + return fmt.Errorf("lock for update (with tx): %w", err) + } + } + + return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { + if err := t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn); err != nil { + return fmt.Errorf("lock for update (new tx): %w", err) + } + return nil + }) +} + +func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } + + q := t.SQL. + Select("*"). + From(t.Name). + Where(where). + OrderBy(orderBy...). + Limit(limit). + Suffix("FOR UPDATE SKIP LOCKED") + + txQuery := t.DB.TxQuery(pgTx) + + var records []P + if err := txQuery.GetAll(ctx, q, &records); err != nil { + return fmt.Errorf("select for update skip locked: %w", err) + } + + updateFn(records) + + for _, record := range records { + q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := txQuery.Exec(ctx, q); err != nil { + return fmt.Errorf("update record: %w", err) + } + } + + return nil +} diff --git a/tests/database_test.go b/tests/database_test.go new file mode 100644 index 0000000..daefa63 --- /dev/null +++ b/tests/database_test.go @@ -0,0 +1,46 @@ +package pgkit_test + +import ( + "context" + + "github.com/goware/pgkit/v2" + "github.com/jackc/pgx/v5" +) + +type Database struct { + *pgkit.DB + + Accounts *accountsTable + Articles *articlesTable + Reviews *reviewsTable +} + +func initDB(db *pgkit.DB) *Database { + return &Database{ + DB: db, + Accounts: &accountsTable{Table: &pgkit.Table[Account, *Account, int64]{DB: db, Name: "accounts", IDColumn: "id"}}, + Articles: &articlesTable{Table: &pgkit.Table[Article, *Article, uint64]{DB: db, Name: "articles", IDColumn: "id"}}, + Reviews: &reviewsTable{Table: &pgkit.Table[Review, *Review, uint64]{DB: db, Name: "reviews", IDColumn: "id"}}, + } +} + +func (db *Database) BeginTx(ctx context.Context, fn func(tx *Database) error) error { + return pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + tx := db.WithTx(pgTx) + return fn(tx) + }) +} + +func (db *Database) WithTx(tx pgx.Tx) *Database { + pgkitDB := &pgkit.DB{ + Conn: db.Conn, + SQL: db.SQL, + Query: db.TxQuery(tx), + } + + return initDB(pgkitDB) +} + +func (db *Database) Close() { + db.DB.Conn.Close() +} diff --git a/tests/helpers_test.go b/tests/helpers_test.go index f25ced6..cde9604 100644 --- a/tests/helpers_test.go +++ b/tests/helpers_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) +func truncateAllTables(t *testing.T) { + truncateTable(t, "accounts") + truncateTable(t, "reviews") + truncateTable(t, "logs") + truncateTable(t, "stats") + truncateTable(t, "articles") +} + func truncateTable(t *testing.T, tableName string) { _, err := DB.Conn.Exec(context.Background(), fmt.Sprintf(`TRUNCATE TABLE %q CASCADE`, tableName)) assert.NoError(t, err) diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index 281f387..c6c8530 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -307,8 +307,16 @@ func TestRecordsWithJSONB(t *testing.T) { func TestRecordsWithJSONStruct(t *testing.T) { truncateTable(t, "articles") + account := &Account{ + Name: "TestRecordsWithJSONStruct", + } + err := DB.Query.QueryRow(context.Background(), DB.SQL.InsertRecord(account).Suffix(`RETURNING "id"`)).Scan(&account.ID) + assert.NoError(t, err) + assert.True(t, account.ID > 0) + article := &Article{ - Author: "Gary", + AccountID: account.ID, + Author: "Gary", Content: Content{ Title: "How to cook pizza", Body: "flour+water+salt+yeast+cheese", @@ -319,7 +327,7 @@ func TestRecordsWithJSONStruct(t *testing.T) { cols, _, err := pgkit.Map(article) assert.NoError(t, err) sort.Strings(cols) - assert.Equal(t, []string{"alias", "author", "content"}, cols) + assert.Equal(t, []string{"account_id", "alias", "author", "content", "deleted_at"}, cols) // Insert record q1 := DB.SQL.InsertRecord(article, "articles") diff --git a/tests/schema_test.go b/tests/schema_test.go index 3c40bf5..dd2c3ee 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -1,6 +1,7 @@ package pgkit_test import ( + "fmt" "time" "github.com/goware/pgkit/v2/dbtype" @@ -11,19 +12,85 @@ type Account struct { Name string `db:"name"` Disabled bool `db:"disabled"` CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT } -func (a *Account) DBTableName() string { - return "accounts" +func (a *Account) DBTableName() string { return "accounts" } +func (a *Account) GetID() int64 { return a.ID } +func (a *Account) SetUpdatedAt(t time.Time) { a.UpdatedAt = t } + +func (a *Account) Validate() error { + if a.Name == "" { + return fmt.Errorf("name is required") + } + + return nil +} + +type Article struct { + ID uint64 `db:"id,omitempty"` + Author string `db:"author"` + Alias *string `db:"alias"` + Content Content `db:"content"` // using JSONB postgres datatype + AccountID int64 `db:"account_id"` + CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + DeletedAt *time.Time `db:"deleted_at"` +} + +func (a *Article) GetID() uint64 { return a.ID } +func (a *Article) SetUpdatedAt(t time.Time) { a.UpdatedAt = t } +func (a *Article) SetDeletedAt(t time.Time) { a.DeletedAt = &t } + +func (a *Article) Validate() error { + if a.Author == "" { + return fmt.Errorf("author is required") + } + + return nil +} + +type Content struct { + Title string `json:"title"` + Body string `json:"body"` + Views int64 `json:"views"` } type Review struct { - ID int64 `db:"id,omitempty"` - Name string `db:"name"` - Comments string `db:"comments"` - CreatedAt time.Time `db:"created_at"` // if unset, will store Go zero-value + ID uint64 `db:"id,omitempty"` + Comment string `db:"comment"` + Status ReviewStatus `db:"status"` + Sentiment int64 `db:"sentiment"` + AccountID int64 `db:"account_id"` + ArticleID uint64 `db:"article_id"` + ProcessedAt *time.Time `db:"processed_at"` + CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + DeletedAt *time.Time `db:"deleted_at"` } +func (r *Review) GetID() uint64 { return r.ID } +func (r *Review) SetUpdatedAt(t time.Time) { r.UpdatedAt = t } +func (r *Review) SetDeletedAt(t time.Time) { r.DeletedAt = &t } + +func (r *Review) Validate() error { + if len(r.Comment) < 3 { + return fmt.Errorf("comment too short") + } + + return nil +} + +type ReviewStatus int64 + +const ( + ReviewStatusPending ReviewStatus = iota + ReviewStatusProcessing + ReviewStatusApproved + ReviewStatusRejected + ReviewStatusFailed +) + type Log struct { ID int64 `db:"id,omitempty"` Message string `db:"message"` @@ -38,16 +105,3 @@ type Stat struct { Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype } - -type Article struct { - ID int64 `db:"id,omitempty"` - Author string `db:"author"` - Alias *string `db:"alias"` - Content Content `db:"content"` // using JSONB postgres datatype -} - -type Content struct { - Title string `json:"title"` - Body string `json:"body"` - Views int64 `json:"views"` -} diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 0000000..bf3aa70 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,391 @@ +package pgkit_test + +import ( + "fmt" + "slices" + "sync" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestTable(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + t.Run("Simple CRUD", func(t *testing.T) { + account := &Account{ + Name: "Save Account", + } + + // Create. + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create failed") + require.NotZero(t, account.ID, "ID should be set") + require.NotZero(t, account.CreatedAt, "CreatedAt should be set") + require.NotZero(t, account.UpdatedAt, "UpdatedAt should be set") + + // Check count. + count, err := db.Accounts.Count(ctx, nil) + require.NoError(t, err, "FindAll failed") + require.Equal(t, uint64(1), count, "Expected 1 account") + + // Read from DB & check for equality. + accountCheck, err := db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err, "FindByID failed") + require.Equal(t, account.ID, accountCheck.ID, "account ID should match") + require.Equal(t, account.Name, accountCheck.Name, "account name should match") + + // Update. + account.Name = "Updated account" + err = db.Accounts.Save(ctx, account) + require.NoError(t, err, "Save failed") + + // Read from DB & check for equality again. + accountCheck, err = db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err, "FindByID failed") + require.Equal(t, account.ID, accountCheck.ID, "account ID should match") + require.Equal(t, account.Name, accountCheck.Name, "account name should match") + + // Check count again. + count, err = db.Accounts.Count(ctx, nil) + require.NoError(t, err, "FindAll failed") + require.Equal(t, uint64(1), count, "Expected 1 account") + + // Iterate all accounts. + iter, err := db.Accounts.Iter(ctx, nil, nil) + require.NoError(t, err, "Iter failed") + var accounts []Account + for account, err := range iter { + require.NoError(t, err, "Iter error") + accounts = append(accounts, *account) + } + }) + + t.Run("Save multiple", func(t *testing.T) { + t.Parallel() + // Create account. + account := &Account{Name: "Save Multiple Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + articles := []*Article{ + {Author: "FirstNew", AccountID: account.ID}, + {Author: "SecondNew", AccountID: account.ID}, + {ID: 10001, Author: "FirstOld", AccountID: account.ID}, + {ID: 10002, Author: "SecondOld", AccountID: account.ID}, + } + err = db.Articles.Save(ctx, articles...) + require.NoError(t, err, "Save articles") + require.NotZero(t, articles[0].ID, "ID should be set") + require.NotZero(t, articles[1].ID, "ID should be set") + require.Equal(t, uint64(10001), articles[2].ID, "ID should be same") + require.Equal(t, uint64(10002), articles[3].ID, "ID should be same") + // test update for multiple records + updateArticles := []*Article{ + articles[0], + articles[1], + } + updateArticles[0].Author = "Updated Author Name 1" + updateArticles[1].Author = "Updated Author Name 2" + err = db.Articles.Save(ctx, updateArticles...) + require.NoError(t, err, "Save articles") + updateArticle0, err := db.Articles.GetByID(ctx, articles[0].ID) + require.NoError(t, err, "Get By ID") + require.Equal(t, updateArticles[0].Author, updateArticle0.Author, "Author should be same") + updateArticle1, err := db.Articles.GetByID(ctx, articles[1].ID) + require.NoError(t, err, "Get By ID") + require.Equal(t, updateArticles[1].Author, updateArticle1.Author, "Author should be same") + }) + + t.Run("Complex Transaction", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + err := db.BeginTx(ctx, func(tx *Database) error { + // Create account. + account := &Account{Name: "Complex Transaction Account"} + err := tx.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + + articles := []*Article{ + {Author: "First", AccountID: account.ID}, + {Author: "Second", AccountID: account.ID}, + {Author: "Third", AccountID: account.ID}, + } + + // Save articles (3x insert). + err = tx.Articles.Save(ctx, articles...) + require.NoError(t, err, "Save failed") + + for _, article := range articles { + require.NotZero(t, article.ID, "ID should be set") + require.NotZero(t, article.CreatedAt, "CreatedAt should be set") + require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set") + } + + firstArticle := articles[0] + + // Save articles (3x update, 1x insert). + articles = append(articles, &Article{Author: "Fourth", AccountID: account.ID}) + err = tx.Articles.Save(ctx, articles...) + require.NoError(t, err, "Save failed") + + for _, article := range articles { + require.NotZero(t, article.ID, "ID should be set") + require.NotZero(t, article.CreatedAt, "CreatedAt should be set") + require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set") + } + require.Equal(t, firstArticle.ID, articles[0].ID, "First article ID should be the same") + + // Verify we can load all articles with .GetById() + for _, article := range articles { + articleCheck, err := tx.Articles.GetByID(ctx, article.ID) + require.NoError(t, err, "GetByID failed") + require.Equal(t, article.ID, articleCheck.ID, "Article ID should match") + require.Equal(t, article.Author, articleCheck.Author, "Article Author should match") + require.Equal(t, article.AccountID, articleCheck.AccountID, "Article AccountID should match") + require.Equal(t, article.CreatedAt, articleCheck.CreatedAt, "Article CreatedAt should match") + // require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match") + // require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .Save() aboe updates the timestamp. + require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match") + } + + // Verify we can load all articles with .ListByIDs() + articleIDs := make([]uint64, len(articles)) + for _, article := range articles { + articleIDs = append(articleIDs, article.ID) + } + articlesCheck, err := tx.Articles.ListByIDs(ctx, articleIDs) + require.NoError(t, err, "ListByIDs failed") + require.Equal(t, len(articles), len(articlesCheck), "Number of articles should match") + for i := range articlesCheck { + require.Equal(t, articles[i].ID, articlesCheck[i].ID, "Article ID should match") + require.Equal(t, articles[i].Author, articlesCheck[i].Author, "Article Author should match") + require.Equal(t, articles[i].AccountID, articlesCheck[i].AccountID, "Article AccountID should match") + require.Equal(t, articles[i].CreatedAt, articlesCheck[i].CreatedAt, "Article CreatedAt should match") + // require.Equal(t, articles[i].UpdatedAt, articlesCheck[i].UpdatedAt, "Article UpdatedAt should match") + require.Equal(t, articles[i].DeletedAt, articlesCheck[i].DeletedAt, "Article DeletedAt should match") + } + + // Soft-delete first article. + err = tx.Articles.DeleteByID(ctx, firstArticle.ID) + require.NoError(t, err, "DeleteByID failed") + + // Check if article is soft-deleted. + article, err := tx.Articles.GetByID(ctx, firstArticle.ID) + require.NoError(t, err, "GetByID failed") + require.Equal(t, firstArticle.ID, article.ID, "DeletedAt should be set") + require.NotNil(t, article.DeletedAt, "DeletedAt should be set") + + // Hard-delete first article. + err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID) + require.NoError(t, err, "HardDeleteByID failed") + + // Check if article is hard-deleted. + article, err = tx.Articles.GetByID(ctx, firstArticle.ID) + require.Error(t, err, "article was not hard-deleted") + require.Nil(t, article, "article is not nil") + + return nil + }) + require.NoError(t, err, "SaveTx transaction failed") + }) + + t.Run("ListPaged", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "ListPaged Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + // Create 15 articles. + for i := range 15 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("Author %02d", i), + }) + require.NoError(t, err) + } + + // Default paginator (page size 10). + page := pgkit.NewPage(0, 1) + results, retPage, err := db.Articles.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 10) + require.True(t, retPage.More, "should have more pages") + + // Second page. + page2 := pgkit.NewPage(0, 2) + results2, retPage2, err := db.Articles.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page2) + require.NoError(t, err) + require.Len(t, results2, 5) + require.False(t, retPage2.More, "should not have more pages") + + // No overlap between pages. + for _, r1 := range results { + for _, r2 := range results2 { + require.NotEqual(t, r1.ID, r2.ID, "pages should not overlap") + } + } + }) + + t.Run("WithPaginator", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "WithPaginator Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for i := range 10 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("PagAuthor %02d", i), + }) + require.NoError(t, err) + } + + // Use a custom paginator with page size 3. + pagedTable := db.Articles.Table.WithPaginator(pgkit.WithDefaultSize(3), pgkit.WithMaxSize(5)) + + page := pgkit.NewPage(0, 1) + results, retPage, err := pagedTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 3, "should return 3 records with custom paginator") + require.True(t, retPage.More) + + // Request size larger than max should be capped. + bigPage := pgkit.NewPage(100, 1) + results, _, err = pagedTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, bigPage) + require.NoError(t, err) + require.Len(t, results, 5, "should be capped at max size 5") + }) + + t.Run("WithTx preserves Paginator", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "WithTx Paginator Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for i := range 5 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("TxPag %02d", i), + }) + require.NoError(t, err) + } + + pagedTable := db.Articles.Table.WithPaginator(pgkit.WithDefaultSize(2)) + + err = pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + txTable := pagedTable.WithTx(pgTx) + page := pgkit.NewPage(0, 1) + results, retPage, err := txTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 2, "paginator should be preserved through WithTx") + require.True(t, retPage.More) + return nil + }) + require.NoError(t, err) + }) + + t.Run("WithTx keeps IDColumn", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "WithTx IDColumn Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "create account failed") + + article := &Article{AccountID: account.ID, Author: "WithTx author"} + err = db.Articles.Save(ctx, article) + require.NoError(t, err, "create article failed") + + err = pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + txTable := db.Articles.Table.WithTx(pgTx) + if err := txTable.HardDeleteByID(ctx, article.ID); err != nil { + return err + } + + _, err := txTable.GetByID(ctx, article.ID) + require.Error(t, err, "article should be deleted inside tx") + + return nil + }) + require.NoError(t, err, "WithTx HardDeleteByID failed") + + _, err = db.Articles.GetByID(ctx, article.ID) + require.Error(t, err, "article should be deleted") + }) +} + +func TestLockForUpdates(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + worker := &Worker{DB: db} + + t.Run("TestLockForUpdates", func(t *testing.T) { + // Create account. + account := &Account{Name: "LockForUpdates Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + + // Create article. + article := &Article{AccountID: account.ID, Author: "Author", Content: Content{Title: "Title", Body: "Body"}} + err = db.Articles.Save(ctx, article) + require.NoError(t, err, "Create article failed") + + // Create 1000 reviews. + reviews := make([]*Review, 100) + for i := range 100 { + reviews[i] = &Review{ + Comment: fmt.Sprintf("Test comment %d", i), + AccountID: account.ID, + ArticleID: article.ID, + Status: ReviewStatusPending, + } + } + err = db.Reviews.Save(ctx, reviews...) + require.NoError(t, err, "create review") + + var ids [][]uint64 = make([][]uint64, 10) + var wg sync.WaitGroup + + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + + reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) + require.NoError(t, err, "dequeue reviews") + + for i, review := range reviews { + go worker.ProcessReview(ctx, review) + + ids[i] = append(ids[i], review.ID) + } + }() + } + wg.Wait() + + // Ensure that all reviews were picked up for processing exactly once. + uniqueIDs := slices.Concat(ids...) + slices.Sort(uniqueIDs) + uniqueIDs = slices.Compact(uniqueIDs) + require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") + + // Wait for all reviews to be processed asynchronously. + worker.Wait() + + // Double check there's no reviews stuck in "processing" status. + count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusProcessing}) + require.NoError(t, err, "count reviews") + require.Zero(t, count, "there should be no reviews stuck in 'processing' status") + }) +} diff --git a/tests/tables_test.go b/tests/tables_test.go new file mode 100644 index 0000000..d734da0 --- /dev/null +++ b/tests/tables_test.go @@ -0,0 +1,47 @@ +package pgkit_test + +import ( + "context" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" +) + +type accountsTable struct { + *pgkit.Table[Account, *Account, int64] +} + +type articlesTable struct { + *pgkit.Table[Article, *Article, uint64] +} + +type reviewsTable struct { + *pgkit.Table[Review, *Review, uint64] +} + +func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { + var dequeued []*Review + where := sq.Eq{ + "status": ReviewStatusPending, + "deleted_at": nil, + } + orderBy := []string{ + "created_at ASC", + } + + err := t.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { + now := time.Now().UTC() + for _, review := range reviews { + review.Status = ReviewStatusProcessing + review.ProcessedAt = &now + } + dequeued = reviews + }) + if err != nil { + return nil, fmt.Errorf("lock for updates: %w", err) + } + + return dequeued, nil +} diff --git a/tests/testdata/pgkit_test_db.sql b/tests/testdata/pgkit_test_db.sql index a55dbf8..406d655 100644 --- a/tests/testdata/pgkit_test_db.sql +++ b/tests/testdata/pgkit_test_db.sql @@ -3,15 +3,32 @@ CREATE TABLE accounts ( name VARCHAR(255), disabled BOOLEAN, new_column_not_in_code BOOLEAN, -- test for backward-compatible migrations, see https://github.com/goware/pgkit/issues/13 - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL +); + +CREATE TABLE articles ( + id SERIAL PRIMARY KEY, + author VARCHAR(80) NOT NULL, + alias VARCHAR(80), + content JSONB, + account_id INTEGER NOT NULL REFERENCES accounts(id), + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP WITHOUT TIME ZONE NULL ); CREATE TABLE reviews ( id SERIAL PRIMARY KEY, - -- article_id integer, - name VARCHAR(80), - comments TEXT, - created_at TIMESTAMP WITHOUT TIME ZONE + article_id INTEGER REFERENCES articles(id), + account_id INTEGER NOT NULL REFERENCES accounts(id), + comment TEXT, + status SMALLINT, + sentiment SMALLINT, + processed_at TIMESTAMP WITHOUT TIME ZONE NULL, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP WITHOUT TIME ZONE NULL ); CREATE TABLE logs ( @@ -27,10 +44,3 @@ CREATE TABLE stats ( big_num NUMERIC(78,0) NOT NULL, -- representing a big.Int runtime type rating NUMERIC(78,0) NULL -- representing a nullable big.Int runtime type ); - -CREATE TABLE articles ( - id SERIAL PRIMARY KEY, - author VARCHAR(80) NOT NULL, - alias VARCHAR(80), - content JSONB -); diff --git a/tests/worker_test.go b/tests/worker_test.go new file mode 100644 index 0000000..711a3af --- /dev/null +++ b/tests/worker_test.go @@ -0,0 +1,64 @@ +package pgkit_test + +import ( + "context" + "fmt" + "log" + "math/rand" + "sync" + "time" + + sq "github.com/Masterminds/squirrel" +) + +type Worker struct { + DB *Database + + wg sync.WaitGroup +} + +func (w *Worker) Wait() { + w.wg.Wait() +} + +func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) { + w.wg.Add(1) + defer w.wg.Done() + + defer func() { + // Always update review status to "approved", "rejected" or "failed". + noCtx := context.Background() + err = w.DB.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { + now := time.Now().UTC() + update.ProcessedAt = &now + if err != nil { + update.Status = ReviewStatusFailed + return + } + update.Status = review.Status + }) + if err != nil { + log.Printf("failed to save review: %v", err) + } + }() + + // Simulate long-running work. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(1 * time.Second): + } + + // Simulate external API call to an LLM. + if rand.Intn(2) == 0 { + return fmt.Errorf("failed to process review: ") + } + + review.Status = ReviewStatusApproved + if rand.Intn(2) == 0 { + review.Status = ReviewStatusRejected + } + now := time.Now().UTC() + review.ProcessedAt = &now + return nil +}