Extract generic migration logic into datastores/sql/migrate package

Move database-agnostic migration logic (file discovery, version tracking,
bootstrap detection) into a shared migrate package behind a Dialect
interface, leaving postgres as a thin wrapper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-20 20:10:11 -05:00
parent d403e18d25
commit aceea44c90
3 changed files with 216 additions and 179 deletions

View File

@@ -0,0 +1,197 @@
package migrate
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// Dialect abstracts the few SQL differences between database engines.
type Dialect interface {
// Placeholder returns the bind parameter for position n (1-indexed).
// Postgres: "$1", "$2". MySQL/SQLite: "?".
Placeholder(n int) string
// TableExistsQuery returns SQL that checks whether a table exists.
// The query must accept a single bind parameter (the table name) and
// return a single boolean column.
TableExistsQuery() string
}
// Options configures the migration runner.
type Options struct {
// MigrationsDir is the directory containing *.up.sql migration files.
// If empty, it defaults to "migrations" relative to the working directory.
MigrationsDir string
// BootstrapTable is the name of a table whose existence signals that the
// database already has a schema applied before the schema_migrations
// tracking table was introduced. When non-empty and this table exists but
// schema_migrations does not, all discovered migrations are recorded as
// already applied (bootstrapped). When empty, no bootstrapping is performed.
BootstrapTable string
}
// Run executes forward migrations against db using the given dialect.
// It discovers *.up.sql files in the configured directory, tracks applied
// versions in a schema_migrations table, and applies any pending migrations
// in filename-sorted order.
func Run(ctx context.Context, db *sql.DB, dialect Dialect, opts *Options) error {
if opts == nil {
opts = &Options{}
}
migrationsTableExists := tableExists(ctx, db, dialect, "schema_migrations")
if err := ensureMigrationsTable(ctx, db); err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
migrationsPath := opts.MigrationsDir
if migrationsPath == "" {
wd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get working directory: %w", err)
}
migrationsPath = filepath.Join(wd, "migrations")
}
migrations, err := discoverMigrations(migrationsPath)
if err != nil {
return fmt.Errorf("failed to discover migrations: %w", err)
}
if !migrationsTableExists && opts.BootstrapTable != "" {
if err := bootstrapMigrationState(ctx, db, dialect, migrations, opts.BootstrapTable); err != nil {
return fmt.Errorf("failed to bootstrap migration state: %w", err)
}
}
applied, err := getAppliedMigrations(ctx, db)
if err != nil {
return fmt.Errorf("failed to get applied migrations: %w", err)
}
for _, migration := range migrations {
if applied[migration] {
continue
}
migrationSQL, err := os.ReadFile(filepath.Join(migrationsPath, migration))
if err != nil {
return fmt.Errorf("failed to read migration file %s: %w", migration, err)
}
if _, err := db.ExecContext(ctx, string(migrationSQL)); err != nil {
return fmt.Errorf("failed to execute migration %s: %w", migration, err)
}
if err := recordMigration(ctx, db, dialect, migration); err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
fmt.Printf("Applied migration: %s\n", migration)
}
return nil
}
// ensureMigrationsTable creates the schema_migrations table if it doesn't exist.
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`)
return err
}
// tableExists checks if a table exists in the database.
func tableExists(ctx context.Context, db *sql.DB, dialect Dialect, tableName string) bool {
var exists bool
err := db.QueryRowContext(ctx,
dialect.TableExistsQuery(),
tableName,
).Scan(&exists)
return err == nil && exists
}
// bootstrapMigrationState marks all discovered migrations as applied when the
// database already contains a schema (detected via sentinelTable) but has no
// schema_migrations table yet.
func bootstrapMigrationState(ctx context.Context, db *sql.DB, dialect Dialect, migrations []string, sentinelTable string) error {
if !tableExists(ctx, db, dialect, sentinelTable) {
return nil
}
for _, migration := range migrations {
if err := recordMigration(ctx, db, dialect, migration); err != nil {
return fmt.Errorf("failed to record bootstrapped migration %s: %w", migration, err)
}
fmt.Printf("Bootstrapped migration (already applied): %s\n", migration)
}
return nil
}
// getAppliedMigrations returns a set of already-applied migration filenames.
func getAppliedMigrations(ctx context.Context, db *sql.DB) (map[string]bool, error) {
applied := make(map[string]bool)
rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var version string
if err := rows.Scan(&version); err != nil {
return nil, err
}
applied[version] = true
}
return applied, rows.Err()
}
// recordMigration records a migration version as applied.
func recordMigration(ctx context.Context, db *sql.DB, dialect Dialect, version string) error {
_, err := db.ExecContext(ctx,
fmt.Sprintf("INSERT INTO schema_migrations (version, applied_at) VALUES (%s, %s)",
dialect.Placeholder(1), dialect.Placeholder(2)),
version, time.Now(),
)
return err
}
// discoverMigrations finds all *.up.sql files in the given directory and
// returns them sorted by filename (relies on numeric prefixes like
// 000001_, 000002_, etc.).
func discoverMigrations(migrationsPath string) ([]string, error) {
entries, err := os.ReadDir(migrationsPath)
if err != nil {
return nil, err
}
var migrations []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, ".up.sql") {
migrations = append(migrations, name)
}
}
sort.Strings(migrations)
return migrations, nil
}

View File

@@ -0,0 +1,15 @@
package postgres
import "fmt"
// Dialect implements migrate.Dialect for PostgreSQL.
type Dialect struct{}
// Placeholder returns PostgreSQL's positional bind parameter ($1, $2, …).
func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) }
// TableExistsQuery returns a query that checks whether a table exists in
// PostgreSQL using information_schema.
func (Dialect) TableExistsQuery() string {
return "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)"
}

View File

@@ -4,38 +4,14 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // Register pgx driver for database/sql _ "github.com/jackc/pgx/v5/stdlib" // Register pgx driver for database/sql
"git.nonahob.net/jacob/golibs/datastores/sql/migrate"
) )
// MigrateOptions configures the migration runner.
type MigrateOptions struct {
// MigrationsDir is the directory containing *.up.sql migration files.
// If empty, it defaults to "migrations" relative to the working directory.
MigrationsDir string
// BootstrapTable is the name of a table whose existence signals that the
// database already has a schema applied before the schema_migrations
// tracking table was introduced. When non-empty and this table exists but
// schema_migrations does not, all discovered migrations are recorded as
// already applied (bootstrapped). When empty, no bootstrapping is performed.
BootstrapTable string
}
// Migrate runs forward (up) migrations using the config's DSN. // Migrate runs forward (up) migrations using the config's DSN.
// It discovers *.up.sql files in the configured directory, tracks applied func (c *Config) Migrate(opts *migrate.Options) error {
// versions in a schema_migrations table, and applies any pending migrations
// in filename-sorted order.
func (c *Config) Migrate(opts *MigrateOptions) error {
if opts == nil {
opts = &MigrateOptions{}
}
db, err := sql.Open("pgx", c.DSN()) db, err := sql.Open("pgx", c.DSN())
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return fmt.Errorf("failed to open database: %w", err)
@@ -46,156 +22,5 @@ func (c *Config) Migrate(opts *MigrateOptions) error {
return fmt.Errorf("failed to ping database: %w", err) return fmt.Errorf("failed to ping database: %w", err)
} }
ctx := context.Background() return migrate.Run(context.Background(), db, Dialect{}, opts)
migrationsTableExists := tableExists(ctx, db, "schema_migrations")
if err := ensureMigrationsTable(ctx, db); err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
migrationsPath := opts.MigrationsDir
if migrationsPath == "" {
wd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get working directory: %w", err)
}
migrationsPath = filepath.Join(wd, "migrations")
}
migrations, err := discoverMigrations(migrationsPath)
if err != nil {
return fmt.Errorf("failed to discover migrations: %w", err)
}
// Bootstrap: if schema_migrations was just created but the database already
// has a schema, mark all discovered migrations as already applied.
if !migrationsTableExists && opts.BootstrapTable != "" {
if err := bootstrapMigrationState(ctx, db, migrations, opts.BootstrapTable); err != nil {
return fmt.Errorf("failed to bootstrap migration state: %w", err)
}
}
applied, err := getAppliedMigrations(ctx, db)
if err != nil {
return fmt.Errorf("failed to get applied migrations: %w", err)
}
for _, migration := range migrations {
if applied[migration] {
continue
}
migrationSQL, err := os.ReadFile(filepath.Join(migrationsPath, migration))
if err != nil {
return fmt.Errorf("failed to read migration file %s: %w", migration, err)
}
if _, err := db.ExecContext(ctx, string(migrationSQL)); err != nil {
return fmt.Errorf("failed to execute migration %s: %w", migration, err)
}
if err := recordMigration(ctx, db, migration); err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
fmt.Printf("Applied migration: %s\n", migration)
}
return nil
}
// ensureMigrationsTable creates the schema_migrations table if it doesn't exist.
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`)
return err
}
// tableExists checks if a table exists in the database.
func tableExists(ctx context.Context, db *sql.DB, tableName string) bool {
var exists bool
err := db.QueryRowContext(ctx,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)",
tableName,
).Scan(&exists)
return err == nil && exists
}
// bootstrapMigrationState marks all discovered migrations as applied when the
// database already contains a schema (detected via sentinelTable) but has no
// schema_migrations table yet.
func bootstrapMigrationState(ctx context.Context, db *sql.DB, migrations []string, sentinelTable string) error {
if !tableExists(ctx, db, sentinelTable) {
// Fresh database, no bootstrapping needed.
return nil
}
for _, migration := range migrations {
if err := recordMigration(ctx, db, migration); err != nil {
return fmt.Errorf("failed to record bootstrapped migration %s: %w", migration, err)
}
fmt.Printf("Bootstrapped migration (already applied): %s\n", migration)
}
return nil
}
// getAppliedMigrations returns a set of already-applied migration filenames.
func getAppliedMigrations(ctx context.Context, db *sql.DB) (map[string]bool, error) {
applied := make(map[string]bool)
rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var version string
if err := rows.Scan(&version); err != nil {
return nil, err
}
applied[version] = true
}
return applied, rows.Err()
}
// recordMigration records a migration version as applied.
func recordMigration(ctx context.Context, db *sql.DB, version string) error {
_, err := db.ExecContext(ctx,
"INSERT INTO schema_migrations (version, applied_at) VALUES ($1, $2)",
version, time.Now(),
)
return err
}
// discoverMigrations finds all *.up.sql files in the given directory and
// returns them sorted by filename (relies on numeric prefixes like
// 000001_, 000002_, etc.).
func discoverMigrations(migrationsPath string) ([]string, error) {
entries, err := os.ReadDir(migrationsPath)
if err != nil {
return nil, err
}
var migrations []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, ".up.sql") {
migrations = append(migrations, name)
}
}
sort.Strings(migrations)
return migrations, nil
} }