202 lines
5.7 KiB
Go
202 lines
5.7 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/jackc/pgx/v5/stdlib" // Register pgx driver for database/sql
|
|
)
|
|
|
|
// MigrateOptions holds the optional settings accepted by Config.Migrate.
// The zero value is valid: every field falls back to a default when empty.
type MigrateOptions struct {
	// MigrationsDir names the directory scanned for *.up.sql files.
	// An empty value means the "migrations" directory under the current
	// working directory.
	MigrationsDir string

	// BootstrapTable, when set, names a table used to recognise a database
	// whose schema predates the schema_migrations bookkeeping table. If
	// schema_migrations is absent but this table is present, every discovered
	// migration is recorded as applied without being executed. Leave it empty
	// to disable bootstrapping entirely.
	BootstrapTable string
}
|
|
|
|
// Migrate runs forward (up) migrations using the config's DSN.
|
|
// It discovers *.up.sql files in the configured directory, tracks applied
|
|
// versions in a schema_migrations table, and applies any pending migrations
|
|
// in filename-sorted order.
|
|
func (c *Config) Migrate(opts *MigrateOptions) error {
|
|
if opts == nil {
|
|
opts = &MigrateOptions{}
|
|
}
|
|
|
|
db, err := sql.Open("pgx", c.DSN())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
migrationsTableExists := tableExists(ctx, db, "schema_migrations")
|
|
|
|
if err := ensureMigrationsTable(ctx, db); err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
migrationsPath := opts.MigrationsDir
|
|
if migrationsPath == "" {
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get working directory: %w", err)
|
|
}
|
|
migrationsPath = filepath.Join(wd, "migrations")
|
|
}
|
|
|
|
migrations, err := discoverMigrations(migrationsPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to discover migrations: %w", err)
|
|
}
|
|
|
|
// Bootstrap: if schema_migrations was just created but the database already
|
|
// has a schema, mark all discovered migrations as already applied.
|
|
if !migrationsTableExists && opts.BootstrapTable != "" {
|
|
if err := bootstrapMigrationState(ctx, db, migrations, opts.BootstrapTable); err != nil {
|
|
return fmt.Errorf("failed to bootstrap migration state: %w", err)
|
|
}
|
|
}
|
|
|
|
applied, err := getAppliedMigrations(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
if applied[migration] {
|
|
continue
|
|
}
|
|
|
|
migrationSQL, err := os.ReadFile(filepath.Join(migrationsPath, migration))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migration file %s: %w", migration, err)
|
|
}
|
|
|
|
if _, err := db.ExecContext(ctx, string(migrationSQL)); err != nil {
|
|
return fmt.Errorf("failed to execute migration %s: %w", migration, err)
|
|
}
|
|
|
|
if err := recordMigration(ctx, db, migration); err != nil {
|
|
return fmt.Errorf("failed to record migration %s: %w", migration, err)
|
|
}
|
|
|
|
fmt.Printf("Applied migration: %s\n", migration)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ensureMigrationsTable creates the schema_migrations bookkeeping table
// (a version primary key plus an applied_at timestamp) unless it already
// exists. Thanks to IF NOT EXISTS the statement is idempotent and safe to
// run on every start-up.
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
	const createMigrationsTable = `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version VARCHAR(255) PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
		)
	`
	if _, err := db.ExecContext(ctx, createMigrationsTable); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// tableExists reports whether tableName resolves to an existing relation.
//
// The name is resolved with PostgreSQL's to_regclass, i.e. the same way a
// table reference inside a migration would be. A plain name is looked up
// through the session's search_path, a schema-qualified name is honoured,
// and unquoted identifiers are folded to lower case. Matching
// information_schema.tables on table_name alone would also hit same-named
// tables in unrelated schemas, which could wrongly skip or trigger
// bootstrapping.
//
// Any query error (including a name PostgreSQL cannot parse) is reported as
// false; the error itself is not surfaced to the caller.
func tableExists(ctx context.Context, db *sql.DB, tableName string) bool {
	var exists bool
	err := db.QueryRowContext(ctx,
		"SELECT to_regclass($1) IS NOT NULL",
		tableName,
	).Scan(&exists)
	return err == nil && exists
}
|
|
|
|
// bootstrapMigrationState marks all discovered migrations as applied when the
|
|
// database already contains a schema (detected via sentinelTable) but has no
|
|
// schema_migrations table yet.
|
|
func bootstrapMigrationState(ctx context.Context, db *sql.DB, migrations []string, sentinelTable string) error {
|
|
if !tableExists(ctx, db, sentinelTable) {
|
|
// Fresh database, no bootstrapping needed.
|
|
return nil
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
if err := recordMigration(ctx, db, migration); err != nil {
|
|
return fmt.Errorf("failed to record bootstrapped migration %s: %w", migration, err)
|
|
}
|
|
fmt.Printf("Bootstrapped migration (already applied): %s\n", migration)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getAppliedMigrations loads every version stored in schema_migrations and
// returns them as a set (version -> true). The versions are the migration
// filenames written by recordMigration.
//
// If the query itself fails, the map is nil. If an error surfaces while
// iterating rows, it is returned alongside the versions read so far.
func getAppliedMigrations(ctx context.Context, db *sql.DB) (map[string]bool, error) {
	rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	versions := map[string]bool{}
	for rows.Next() {
		var version string
		if scanErr := rows.Scan(&version); scanErr != nil {
			return nil, scanErr
		}
		versions[version] = true
	}

	if iterErr := rows.Err(); iterErr != nil {
		// Partial results deliberately accompany an iteration error.
		return versions, iterErr
	}
	return versions, nil
}
|
|
|
|
// recordMigration inserts version into schema_migrations. applied_at is
// stamped with the client's current time (time.Now) rather than the
// column's database-side default. Because version is the table's primary
// key, recording the same version twice surfaces as the database's
// unique-violation error.
func recordMigration(ctx context.Context, db *sql.DB, version string) error {
	const insertVersion = "INSERT INTO schema_migrations (version, applied_at) VALUES ($1, $2)"

	appliedAt := time.Now()
	if _, err := db.ExecContext(ctx, insertVersion, version, appliedAt); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// discoverMigrations lists the non-directory entries of migrationsPath whose
// names end in ".up.sql" and returns their bare file names in ascending
// lexical order.
//
// Lexical order only matches application order when names carry
// fixed-width numeric prefixes such as 000001_, 000002_, and so on. A
// directory read error is returned unchanged; a directory with no matching
// files yields a nil slice.
func discoverMigrations(migrationsPath string) ([]string, error) {
	entries, err := os.ReadDir(migrationsPath)
	if err != nil {
		return nil, err
	}

	var upFiles []string
	for i := range entries {
		switch entry := entries[i]; {
		case entry.IsDir():
			// Directories are never migrations, whatever they are named.
		case strings.HasSuffix(entry.Name(), ".up.sql"):
			upFiles = append(upFiles, entry.Name())
		}
	}

	sort.Strings(upFiles)
	return upFiles, nil
}
|