Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ testdata/tern.conf
/tern
/tmp

.vscode
.idea/*
dist
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
Expand Down
32 changes: 21 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ type Config struct {
}

var cliOptions struct {
destinationVersion string
currentVersion string
migrationsPath string
configPaths []string
editNewMigration bool
outputFile string // used for gengen or print-migrations
destinationVersion string
currentVersion string
migrationsPath string
configPaths []string
editNewMigration bool
outputFile string // used for gengen or print-migrations
cockroachDbCompatible bool

connString string
host string
Expand Down Expand Up @@ -187,6 +188,7 @@ The word "last":
Run: Migrate,
}
cmdMigrate.Flags().StringVarP(&cliOptions.destinationVersion, "destination", "d", "last", "destination migration version")
cmdMigrate.Flags().BoolVar(&cliOptions.cockroachDbCompatible, "cockroachdb", false, "CockroachDB compatibility flag avoiding advisory locks (default is false)")
addConfigFlagsToCommand(cmdMigrate)

cmdCode := &cobra.Command{
Expand Down Expand Up @@ -507,7 +509,9 @@ func Migrate(cmd *cobra.Command, args []string) {
config, conn := loadConfigAndConnectToDB(ctx)
defer conn.Close(ctx)

migrator, err := migrate.NewMigrator(ctx, conn, config.VersionTable)
migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible }

migrator, err := migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -612,7 +616,9 @@ func Gengen(cmd *cobra.Command, args []string) {
os.Exit(1)
}

migrator, err := migrate.NewMigrator(context.Background(), nil, config.VersionTable)
migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible }

migrator, err := migrate.NewMigratorEx(context.Background(), nil, config.VersionTable, &migOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -863,7 +869,9 @@ func Status(cmd *cobra.Command, args []string) {
config, conn := loadConfigAndConnectToDB(ctx)
defer conn.Close(ctx)

migrator, err := migrate.NewMigrator(ctx, conn, config.VersionTable)
migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible }

migrator, err := migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -1307,7 +1315,8 @@ func PrintMigrations(cmd *cobra.Command, args []string) {
fmt.Fprintf(os.Stderr, "Error connecting to database:\n %v\n", err)
os.Exit(1)
}
migrator, err = migrate.NewMigrator(ctx, conn, config.VersionTable)
migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible }
migrator, err = migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err)
os.Exit(1)
Expand All @@ -1321,7 +1330,8 @@ func PrintMigrations(cmd *cobra.Command, args []string) {
}
currentVersion = int32(n)

migrator, err = migrate.NewMigrator(ctx, nil, config.VersionTable)
migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible }
migrator, err = migrate.NewMigratorEx(ctx, nil, config.VersionTable, &migOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err)
os.Exit(1)
Expand Down
171 changes: 147 additions & 24 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/fs"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"text/template"
Expand All @@ -21,6 +22,7 @@ import (
var (
migrationPattern = regexp.MustCompile(`\A(\d+)_.+\.sql\z`)
disableTxPattern = regexp.MustCompile(`(?m)^---- tern: disable-tx ----$`)
disableRowLocks = regexp.MustCompile(`(?m)^---- tern: disable-row-locks ----$`)
)

const (
Expand Down Expand Up @@ -131,34 +133,61 @@ func (m *Migration) irreversible() bool {
type MigratorOptions struct {
// DisableTx causes the Migrator not to run migrations in a transaction.
DisableTx bool
CockroachDbCompatible bool
}

type Migrator struct {
conn *pgx.Conn

// optionally, used for locking mechanisms
// instead of advisory locks on the primary conn
lockingConn *pgx.Conn
// optionally, Tx for the locking mechanism
lockingTx pgx.Tx

versionTable string
options *MigratorOptions
Migrations []*Migration
OnStart func(int32, string, string, string) // OnStart is called when a migration is run with the sequence, name, direction, and SQL
Data map[string]interface{} // Data available to use in migrations
}

// NewMigrator initializes a new Migrator. It is highly recommended that versionTable be schema qualified.
func NewMigrator(ctx context.Context, conn *pgx.Conn, versionTable string) (m *Migrator, err error) {
return NewMigratorEx(ctx, conn, versionTable, &MigratorOptions{})
}

// NewMigratorEx initializes a new Migrator. It is highly recommended that versionTable be schema qualified.
// NewMigrator initializes a new Migrator. It is highly recommended that versionTable be schema qualified.
func NewMigratorEx(ctx context.Context, conn *pgx.Conn, versionTable string, opts *MigratorOptions) (m *Migrator, err error) {
m = &Migrator{conn: conn, versionTable: versionTable, options: opts}

m.Migrations = make([]*Migration, 0)
m.Data = make(map[string]interface{})

if opts.CockroachDbCompatible {
m.lockingConn, err = pgx.ConnectConfig(ctx, conn.Config().Copy())
if err != nil {
// try anyways and leave lockingconn nil? either way there's a failure
// For now, be explicit and block so it's clear
return
}

// Migrator is the owner of this connection. Instead of requiring the user of Migrator
// to close (a usage change) let go manage the runtime
// Once go compat moves to 1.24 can replace with AddCleanup https://pkg.go.dev/runtime@master#AddCleanup
runtime.SetFinalizer(m.lockingConn, func(c *pgx.Conn) {
if err := c.Close(ctx); err != nil {
fmt.Println("trying to close lockingConn:", err.Error())
}
})
}

// This is a bit of a kludge for the gengen command. A migrator without a conn is normally not allowed. However, the
// gengen command doesn't call any of the methods that require a conn. Potentially, we could refactor Migrator to
// split out the migration loading and parsing from the actual migration execution.
if conn != nil {
err = m.ensureSchemaVersionTableExists(ctx)
}
m.Migrations = make([]*Migration, 0)
m.Data = make(map[string]interface{})

return
}

Expand Down Expand Up @@ -325,9 +354,6 @@ func (m *Migrator) AppendMigration(name, upSQL, downSQL string) {
// Migrate runs pending migrations
// It calls m.OnStart when it begins a migration
func (m *Migrator) Migrate(ctx context.Context) error {
if err := m.validate(); err != nil {
return err
}
return m.MigrateTo(ctx, m.highestSequenceNum())
}

Expand All @@ -347,6 +373,56 @@ func (m *Migrator) validate() error {
return nil
}

func (m *Migrator) acquireLock(ctx context.Context) error {
if m.lockingConn != nil {
return m.acquireCustomLock(ctx)
}

return acquireAdvisoryLock(ctx, m.conn)
}

func (m *Migrator) releaseLock(ctx context.Context) error {
if m.lockingConn != nil {
return m.releaseCustomLock(ctx)
}

return releaseAdvisoryLock(ctx, m.conn)
}

var ErrLockNonRecursive = errors.New("lock is nonrecursive")

// CockroachDB Compatible Locking Mechanism
func (m *Migrator) acquireCustomLock(ctx context.Context) (err error) {
query := fmt.Sprintf("select * from %s_lock for update nowait", m.versionTable)

if m.lockingTx != nil {
return ErrLockNonRecursive
}

m.lockingTx, err = m.lockingConn.Begin(ctx)
if err != nil {
return
}

if _, err := m.lockingTx.Exec(ctx, query); err != nil {
return err
}

return nil
}

// CockroachDB Compatible Locking Mechanism
func (m *Migrator) releaseCustomLock(ctx context.Context) error {
err := m.lockingTx.Commit(ctx)
if err != nil {
return err
}

m.lockingTx = nil

return nil
}

// Lock to ensure multiple migrations cannot occur simultaneously
const lockNum = int64(9628173550095224) // arbitrary random number

Expand All @@ -366,12 +442,12 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro
return err
}

err = acquireAdvisoryLock(ctx, m.conn)
err = m.acquireLock(ctx)
if err != nil {
return err
}
defer func() {
unlockErr := releaseAdvisoryLock(ctx, m.conn)
unlockErr := m.releaseLock(ctx)
if err == nil && unlockErr != nil {
err = unlockErr
}
Expand Down Expand Up @@ -456,7 +532,7 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro
m.conn.Exec(ctx, "reset all")

// Add one to the version
_, err = m.conn.Exec(ctx, "update "+m.versionTable+" set version=$1", sequence)
_, err = m.conn.Exec(ctx, "update "+m.versionTable+" set version=$1 where version >= 0", sequence)
if err != nil {
return err
}
Expand All @@ -475,17 +551,30 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro
}

func (m *Migrator) GetCurrentVersion(ctx context.Context) (v int32, err error) {
err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v)
return v, err
query := "select version from "+m.versionTable+" where version >= 0"

if m.lockingTx != nil {
err = m.lockingTx.QueryRow(ctx, query).Scan(&v)
} else {
err = m.conn.QueryRow(ctx, query).Scan(&v)
}

return
}

func (m *Migrator) ensureSchemaVersionTableExists(ctx context.Context) (err error) {
err = acquireAdvisoryLock(ctx, m.conn)
if m.lockingConn != nil {
// solve the bootstrap problem needing the table
// to lock and needing a lock to create the table
return m.createIfNotExistsVersionTable(ctx)
}

err = m.acquireLock(ctx)
if err != nil {
return err
}
defer func() {
unlockErr := releaseAdvisoryLock(ctx, m.conn)
unlockErr := m.releaseLock(ctx)
if err == nil && unlockErr != nil {
err = unlockErr
}
Expand All @@ -495,13 +584,34 @@ func (m *Migrator) ensureSchemaVersionTableExists(ctx context.Context) (err erro
return err
}

_, err = m.conn.Exec(ctx, fmt.Sprintf(`
create table if not exists %s(version int4 not null);
return m.createIfNotExistsVersionTable(ctx)
}

// Not Thread Safe / Lock Safe
func (m *Migrator) createIfNotExistsVersionTable(ctx context.Context) error {
_, err := m.conn.Exec(ctx, fmt.Sprintf(`
create table if not exists %s(version int4 not null primary key);

with initial(version) as (values (0))
insert into %s(version)
select * from initial
where 0=(select count(*) from %s);
`, m.versionTable, m.versionTable, m.versionTable))
if err != nil {
return err
}

if m.options.CockroachDbCompatible {
_, err = m.conn.Exec(ctx, fmt.Sprintf(`
create table if not exists %s_lock(lock boolean not null primary key default true);

with initial(lock) as (values (true))
insert into %s_lock(lock)
select * from initial
where 0=(select count(*) from %s)
`, m.versionTable, m.versionTable, m.versionTable))
}

insert into %s(version)
select 0
where 0=(select count(*) from %s);
`, m.versionTable, m.versionTable, m.versionTable))
return err
}

Expand Down Expand Up @@ -545,13 +655,26 @@ func (m *Migrator) doSQLMigration(ctx context.Context, migration *Migration, dir
}
// Execute the migration
for _, statement := range sqlStatements {
if _, err := m.conn.Exec(ctx, statement); err != nil {
if err, ok := err.(*pgconn.PgError); ok {
return MigrationPgError{MigrationName: migration.Name, Sql: statement, PgError: err}
}
if err := m.sqlExecMigration(ctx, migration, statement); err != nil {
return err
}
}

return nil
}

func (m *Migrator) sqlExecMigration(ctx context.Context, migration *Migration, statement string) error {
if disableRowLocks.MatchString(statement) && m.lockingTx != nil {
m.releaseLock(ctx)
defer m.acquireLock(ctx)
}

if _, err := m.conn.Exec(ctx, statement); err != nil {
if err, ok := err.(*pgconn.PgError); ok {
return MigrationPgError{MigrationName: migration.Name, Sql: statement, PgError: err}
}
return err
}

return nil
}
Loading