Extract generic migration logic into datastores/sql/migrate package
Move database-agnostic migration logic (file discovery, version tracking, bootstrap detection) into a shared migrate package behind a Dialect interface, leaving postgres as a thin wrapper. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
197
datastores/sql/migrate/migrate.go
Normal file
197
datastores/sql/migrate/migrate.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialect abstracts the few SQL differences between database engines.
|
||||
type Dialect interface {
|
||||
// Placeholder returns the bind parameter for position n (1-indexed).
|
||||
// Postgres: "$1", "$2". MySQL/SQLite: "?".
|
||||
Placeholder(n int) string
|
||||
|
||||
// TableExistsQuery returns SQL that checks whether a table exists.
|
||||
// The query must accept a single bind parameter (the table name) and
|
||||
// return a single boolean column.
|
||||
TableExistsQuery() string
|
||||
}
|
||||
|
||||
// Options configures the migration runner.
|
||||
type Options struct {
|
||||
// MigrationsDir is the directory containing *.up.sql migration files.
|
||||
// If empty, it defaults to "migrations" relative to the working directory.
|
||||
MigrationsDir string
|
||||
|
||||
// BootstrapTable is the name of a table whose existence signals that the
|
||||
// database already has a schema applied before the schema_migrations
|
||||
// tracking table was introduced. When non-empty and this table exists but
|
||||
// schema_migrations does not, all discovered migrations are recorded as
|
||||
// already applied (bootstrapped). When empty, no bootstrapping is performed.
|
||||
BootstrapTable string
|
||||
}
|
||||
|
||||
// Run executes forward migrations against db using the given dialect.
|
||||
// It discovers *.up.sql files in the configured directory, tracks applied
|
||||
// versions in a schema_migrations table, and applies any pending migrations
|
||||
// in filename-sorted order.
|
||||
func Run(ctx context.Context, db *sql.DB, dialect Dialect, opts *Options) error {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
migrationsTableExists := tableExists(ctx, db, dialect, "schema_migrations")
|
||||
|
||||
if err := ensureMigrationsTable(ctx, db); err != nil {
|
||||
return fmt.Errorf("failed to create migrations table: %w", err)
|
||||
}
|
||||
|
||||
migrationsPath := opts.MigrationsDir
|
||||
if migrationsPath == "" {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
migrationsPath = filepath.Join(wd, "migrations")
|
||||
}
|
||||
|
||||
migrations, err := discoverMigrations(migrationsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to discover migrations: %w", err)
|
||||
}
|
||||
|
||||
if !migrationsTableExists && opts.BootstrapTable != "" {
|
||||
if err := bootstrapMigrationState(ctx, db, dialect, migrations, opts.BootstrapTable); err != nil {
|
||||
return fmt.Errorf("failed to bootstrap migration state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
applied, err := getAppliedMigrations(ctx, db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if applied[migration] {
|
||||
continue
|
||||
}
|
||||
|
||||
migrationSQL, err := os.ReadFile(filepath.Join(migrationsPath, migration))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration file %s: %w", migration, err)
|
||||
}
|
||||
|
||||
if _, err := db.ExecContext(ctx, string(migrationSQL)); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %s: %w", migration, err)
|
||||
}
|
||||
|
||||
if err := recordMigration(ctx, db, dialect, migration); err != nil {
|
||||
return fmt.Errorf("failed to record migration %s: %w", migration, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Applied migration: %s\n", migration)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureMigrationsTable creates the schema_migrations table if it doesn't exist.
|
||||
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version VARCHAR(255) PRIMARY KEY,
|
||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// tableExists checks if a table exists in the database.
|
||||
func tableExists(ctx context.Context, db *sql.DB, dialect Dialect, tableName string) bool {
|
||||
var exists bool
|
||||
err := db.QueryRowContext(ctx,
|
||||
dialect.TableExistsQuery(),
|
||||
tableName,
|
||||
).Scan(&exists)
|
||||
return err == nil && exists
|
||||
}
|
||||
|
||||
// bootstrapMigrationState marks all discovered migrations as applied when the
|
||||
// database already contains a schema (detected via sentinelTable) but has no
|
||||
// schema_migrations table yet.
|
||||
func bootstrapMigrationState(ctx context.Context, db *sql.DB, dialect Dialect, migrations []string, sentinelTable string) error {
|
||||
if !tableExists(ctx, db, dialect, sentinelTable) {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if err := recordMigration(ctx, db, dialect, migration); err != nil {
|
||||
return fmt.Errorf("failed to record bootstrapped migration %s: %w", migration, err)
|
||||
}
|
||||
fmt.Printf("Bootstrapped migration (already applied): %s\n", migration)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAppliedMigrations returns a set of already-applied migration filenames.
|
||||
func getAppliedMigrations(ctx context.Context, db *sql.DB) (map[string]bool, error) {
|
||||
applied := make(map[string]bool)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var version string
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applied[version] = true
|
||||
}
|
||||
|
||||
return applied, rows.Err()
|
||||
}
|
||||
|
||||
// recordMigration records a migration version as applied.
|
||||
func recordMigration(ctx context.Context, db *sql.DB, dialect Dialect, version string) error {
|
||||
_, err := db.ExecContext(ctx,
|
||||
fmt.Sprintf("INSERT INTO schema_migrations (version, applied_at) VALUES (%s, %s)",
|
||||
dialect.Placeholder(1), dialect.Placeholder(2)),
|
||||
version, time.Now(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// discoverMigrations finds all *.up.sql files in the given directory and
|
||||
// returns them sorted by filename (relies on numeric prefixes like
|
||||
// 000001_, 000002_, etc.).
|
||||
func discoverMigrations(migrationsPath string) ([]string, error) {
|
||||
entries, err := os.ReadDir(migrationsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var migrations []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".up.sql") {
|
||||
migrations = append(migrations, name)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(migrations)
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
Reference in New Issue
Block a user