package migrate import ( "context" "database/sql" "fmt" "os" "path/filepath" "sort" "strings" "time" ) // Dialect abstracts the few SQL differences between database engines. type Dialect interface { // Placeholder returns the bind parameter for position n (1-indexed). // Postgres: "$1", "$2". MySQL/SQLite: "?". Placeholder(n int) string // TableExistsQuery returns SQL that checks whether a table exists. // The query must accept a single bind parameter (the table name) and // return a single boolean column. TableExistsQuery() string } // Options configures the migration runner. type Options struct { // MigrationsDir is the directory containing *.up.sql migration files. // If empty, it defaults to "migrations" relative to the working directory. MigrationsDir string // BootstrapTable is the name of a table whose existence signals that the // database already has a schema applied before the schema_migrations // tracking table was introduced. When non-empty and this table exists but // schema_migrations does not, all discovered migrations are recorded as // already applied (bootstrapped). When empty, no bootstrapping is performed. BootstrapTable string } // Run executes forward migrations against db using the given dialect. // It discovers *.up.sql files in the configured directory, tracks applied // versions in a schema_migrations table, and applies any pending migrations // in filename-sorted order. func Run(ctx context.Context, db *sql.DB, dialect Dialect, opts *Options) error { if opts == nil { opts = &Options{} } migrationsTableExists := tableExists(ctx, db, dialect, "schema_migrations") if err := ensureMigrationsTable(ctx, db); err != nil { return fmt.Errorf("failed to create migrations table: %w", err) } migrationsPath := opts.MigrationsDir if migrationsPath == "" { wd, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get working directory: %w", err) } migrationsPath = filepath.Join(wd, "migrations") } migrations, err := discoverMigrations(migrationsPath) if err != nil { return fmt.Errorf("failed to discover migrations: %w", err) } if !migrationsTableExists && opts.BootstrapTable != "" { if err := bootstrapMigrationState(ctx, db, dialect, migrations, opts.BootstrapTable); err != nil { return fmt.Errorf("failed to bootstrap migration state: %w", err) } } applied, err := getAppliedMigrations(ctx, db) if err != nil { return fmt.Errorf("failed to get applied migrations: %w", err) } for _, migration := range migrations { if applied[migration] { continue } migrationSQL, err := os.ReadFile(filepath.Join(migrationsPath, migration)) if err != nil { return fmt.Errorf("failed to read migration file %s: %w", migration, err) } if _, err := db.ExecContext(ctx, string(migrationSQL)); err != nil { return fmt.Errorf("failed to execute migration %s: %w", migration, err) } if err := recordMigration(ctx, db, dialect, migration); err != nil { return fmt.Errorf("failed to record migration %s: %w", migration, err) } fmt.Printf("Applied migration: %s\n", migration) } return nil } // ensureMigrationsTable creates the schema_migrations table if it doesn't exist. func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version VARCHAR(255) PRIMARY KEY, applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ) `) return err } // tableExists checks if a table exists in the database. func tableExists(ctx context.Context, db *sql.DB, dialect Dialect, tableName string) bool { var exists bool err := db.QueryRowContext(ctx, dialect.TableExistsQuery(), tableName, ).Scan(&exists) return err == nil && exists } // bootstrapMigrationState marks all discovered migrations as applied when the // database already contains a schema (detected via sentinelTable) but has no // schema_migrations table yet. func bootstrapMigrationState(ctx context.Context, db *sql.DB, dialect Dialect, migrations []string, sentinelTable string) error { if !tableExists(ctx, db, dialect, sentinelTable) { return nil } for _, migration := range migrations { if err := recordMigration(ctx, db, dialect, migration); err != nil { return fmt.Errorf("failed to record bootstrapped migration %s: %w", migration, err) } fmt.Printf("Bootstrapped migration (already applied): %s\n", migration) } return nil } // getAppliedMigrations returns a set of already-applied migration filenames. func getAppliedMigrations(ctx context.Context, db *sql.DB) (map[string]bool, error) { applied := make(map[string]bool) rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations") if err != nil { return nil, err } defer rows.Close() for rows.Next() { var version string if err := rows.Scan(&version); err != nil { return nil, err } applied[version] = true } return applied, rows.Err() } // recordMigration records a migration version as applied. func recordMigration(ctx context.Context, db *sql.DB, dialect Dialect, version string) error { _, err := db.ExecContext(ctx, fmt.Sprintf("INSERT INTO schema_migrations (version, applied_at) VALUES (%s, %s)", dialect.Placeholder(1), dialect.Placeholder(2)), version, time.Now(), ) return err } // discoverMigrations finds all *.up.sql files in the given directory and // returns them sorted by filename (relies on numeric prefixes like // 000001_, 000002_, etc.). func discoverMigrations(migrationsPath string) ([]string, error) { entries, err := os.ReadDir(migrationsPath) if err != nil { return nil, err } var migrations []string for _, entry := range entries { if entry.IsDir() { continue } name := entry.Name() if strings.HasSuffix(name, ".up.sql") { migrations = append(migrations, name) } } sort.Strings(migrations) return migrations, nil }