update to add tests

This commit is contained in:
2026-02-20 20:54:33 -05:00
parent aceea44c90
commit af85de2226
5 changed files with 450 additions and 2 deletions

View File

@@ -0,0 +1,206 @@
package migrate
import (
"context"
"database/sql"
"os"
"path/filepath"
"testing"
_ "modernc.org/sqlite"
)
// sqliteDialect implements Dialect for SQLite.
type sqliteDialect struct{}
func (sqliteDialect) Placeholder(n int) string { return "?" }
func (sqliteDialect) TableExistsQuery() string {
return "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name = ?)"
}
func openTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
t.Cleanup(func() { db.Close() })
return db
}
// writeMigrations creates .up.sql files in dir and returns the directory path.
func writeMigrations(t *testing.T, files map[string]string) string {
t.Helper()
dir := t.TempDir()
for name, content := range files {
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
t.Fatalf("write migration %s: %v", name, err)
}
}
return dir
}
func TestDiscoverMigrations(t *testing.T) {
dir := writeMigrations(t, map[string]string{
"000002_create_posts.up.sql": "CREATE TABLE posts (id INTEGER);",
"000001_create_users.up.sql": "CREATE TABLE users (id INTEGER);",
"000001_create_users.down.sql": "DROP TABLE users;",
"README.md": "not a migration",
})
got, err := discoverMigrations(dir)
if err != nil {
t.Fatalf("discoverMigrations: %v", err)
}
want := []string{
"000001_create_users.up.sql",
"000002_create_posts.up.sql",
}
if len(got) != len(want) {
t.Fatalf("got %d migrations, want %d", len(got), len(want))
}
for i := range want {
if got[i] != want[i] {
t.Errorf("migration[%d] = %q, want %q", i, got[i], want[i])
}
}
}
func TestDiscoverMigrations_empty(t *testing.T) {
dir := t.TempDir()
got, err := discoverMigrations(dir)
if err != nil {
t.Fatalf("discoverMigrations: %v", err)
}
if got != nil {
t.Fatalf("got %v, want nil", got)
}
}
func TestRun(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
dialect := sqliteDialect{}
dir := writeMigrations(t, map[string]string{
"000001_create_users.up.sql": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);",
"000002_create_posts.up.sql": "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER);",
})
err := Run(ctx, db, dialect, &Options{MigrationsDir: dir})
if err != nil {
t.Fatalf("Run: %v", err)
}
// Verify tables were created.
for _, table := range []string{"users", "posts", "schema_migrations"} {
if !tableExists(ctx, db, dialect, table) {
t.Errorf("table %q should exist", table)
}
}
// Verify migration records.
applied, err := getAppliedMigrations(ctx, db)
if err != nil {
t.Fatalf("getAppliedMigrations: %v", err)
}
if !applied["000001_create_users.up.sql"] || !applied["000002_create_posts.up.sql"] {
t.Errorf("applied = %v, want both migrations", applied)
}
}
func TestRun_idempotent(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
dialect := sqliteDialect{}
dir := writeMigrations(t, map[string]string{
"000001_create_users.up.sql": "CREATE TABLE users (id INTEGER PRIMARY KEY);",
})
opts := &Options{MigrationsDir: dir}
if err := Run(ctx, db, dialect, opts); err != nil {
t.Fatalf("first Run: %v", err)
}
if err := Run(ctx, db, dialect, opts); err != nil {
t.Fatalf("second Run: %v", err)
}
// Still only one record.
applied, err := getAppliedMigrations(ctx, db)
if err != nil {
t.Fatalf("getAppliedMigrations: %v", err)
}
if len(applied) != 1 {
t.Errorf("got %d applied migrations, want 1", len(applied))
}
}
func TestRun_nilOpts(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
dialect := sqliteDialect{}
// nil opts should not panic; it will use cwd/migrations which won't exist,
// so we expect an error but not a panic.
err := Run(ctx, db, dialect, nil)
if err == nil {
t.Fatal("expected error with nil opts (no migrations dir), got nil")
}
}
func TestRun_bootstrap(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
dialect := sqliteDialect{}
// Create the sentinel table to simulate an existing schema.
if _, err := db.ExecContext(ctx, "CREATE TABLE my_app (id INTEGER)"); err != nil {
t.Fatalf("create sentinel: %v", err)
}
dir := writeMigrations(t, map[string]string{
"000001_create_users.up.sql": "CREATE TABLE users (id INTEGER PRIMARY KEY);",
"000002_create_posts.up.sql": "CREATE TABLE posts (id INTEGER PRIMARY KEY);",
})
err := Run(ctx, db, dialect, &Options{
MigrationsDir: dir,
BootstrapTable: "my_app",
})
if err != nil {
t.Fatalf("Run: %v", err)
}
// Migrations should be recorded but NOT executed (tables should not exist).
applied, err := getAppliedMigrations(ctx, db)
if err != nil {
t.Fatalf("getAppliedMigrations: %v", err)
}
if len(applied) != 2 {
t.Fatalf("got %d applied, want 2 (bootstrapped)", len(applied))
}
// The actual tables should NOT have been created.
if tableExists(ctx, db, dialect, "users") {
t.Error("users table should not exist after bootstrap")
}
if tableExists(ctx, db, dialect, "posts") {
t.Error("posts table should not exist after bootstrap")
}
}
func TestRun_noMigrationsDir(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
dialect := sqliteDialect{}
err := Run(ctx, db, dialect, &Options{MigrationsDir: "/nonexistent/path"})
if err == nil {
t.Fatal("expected error for nonexistent dir, got nil")
}
}

View File

@@ -1,6 +1,12 @@
package postgres
import "fmt"
import (
"fmt"
"git.nonahob.net/jacob/golibs/datastores/sql/migrate"
)
var _ migrate.Dialect = Dialect{}
// Dialect implements migrate.Dialect for PostgreSQL.
type Dialect struct{}

View File

@@ -0,0 +1,164 @@
package postgres
import (
"context"
"testing"
"time"
)
func TestDSN(t *testing.T) {
c := &Config{
Host: "db.example.com",
Port: "5433",
User: "alice",
Password: "s3cret",
DBName: "mydb",
SSLMode: "require",
}
want := "host=db.example.com port=5433 user=alice password=s3cret dbname=mydb sslmode=require"
if got := c.DSN(); got != want {
t.Errorf("DSN() = %q, want %q", got, want)
}
}
func TestNewConfig_defaults(t *testing.T) {
// Clear all env vars that NewConfig reads so we get pure defaults.
for _, key := range []string{"DB_HOST", "DB_PORT", "DB_USER", "DB_PASSWORD", "DB_NAME", "DB_SSL_MODE"} {
t.Setenv(key, "")
}
c := NewConfig()
if c.Host != "localhost" {
t.Errorf("Host = %q, want %q", c.Host, "localhost")
}
if c.Port != "5432" {
t.Errorf("Port = %q, want %q", c.Port, "5432")
}
if c.User != "postgres" {
t.Errorf("User = %q, want %q", c.User, "postgres")
}
if c.Password != "postgres" {
t.Errorf("Password = %q, want %q", c.Password, "postgres")
}
if c.DBName != "" {
t.Errorf("DBName = %q, want empty", c.DBName)
}
if c.SSLMode != "disable" {
t.Errorf("SSLMode = %q, want %q", c.SSLMode, "disable")
}
}
func TestNewConfig_envOverrides(t *testing.T) {
t.Setenv("DB_HOST", "remotehost")
t.Setenv("DB_PORT", "9999")
t.Setenv("DB_USER", "bob")
t.Setenv("DB_PASSWORD", "hunter2")
t.Setenv("DB_NAME", "testdb")
t.Setenv("DB_SSL_MODE", "verify-full")
c := NewConfig()
if c.Host != "remotehost" {
t.Errorf("Host = %q, want %q", c.Host, "remotehost")
}
if c.Port != "9999" {
t.Errorf("Port = %q, want %q", c.Port, "9999")
}
if c.User != "bob" {
t.Errorf("User = %q, want %q", c.User, "bob")
}
if c.Password != "hunter2" {
t.Errorf("Password = %q, want %q", c.Password, "hunter2")
}
if c.DBName != "testdb" {
t.Errorf("DBName = %q, want %q", c.DBName, "testdb")
}
if c.SSLMode != "verify-full" {
t.Errorf("SSLMode = %q, want %q", c.SSLMode, "verify-full")
}
}
func TestNewConfig_withContext(t *testing.T) {
ctx := context.WithValue(context.Background(), struct{}{}, "test")
c := NewConfig(WithContext(ctx))
if c.Context() != ctx {
t.Error("WithContext option did not set the context")
}
}
func TestContext_default(t *testing.T) {
c := &Config{}
ctx := c.Context()
if ctx == nil {
t.Fatal("Context() returned nil")
}
// Should return background context.
if ctx != context.Background() {
t.Error("Context() should return context.Background() when none set")
}
}
func TestContext_set(t *testing.T) {
ctx := t.Context()
c := &Config{ctx: ctx}
if c.Context() != ctx {
t.Error("Context() did not return the set context")
}
}
func TestDialect_Placeholder(t *testing.T) {
d := Dialect{}
tests := []struct {
n int
want string
}{
{1, "$1"},
{2, "$2"},
{10, "$10"},
{100, "$100"},
}
for _, tt := range tests {
if got := d.Placeholder(tt.n); got != tt.want {
t.Errorf("Placeholder(%d) = %q, want %q", tt.n, got, tt.want)
}
}
}
func TestDialect_TableExistsQuery(t *testing.T) {
q := Dialect{}.TableExistsQuery()
if q == "" {
t.Fatal("TableExistsQuery() returned empty string")
}
// Should reference information_schema and use $1 placeholder.
if got := q; got != "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)" {
t.Errorf("TableExistsQuery() = %q", got)
}
}
func TestDefaultPoolConfig(t *testing.T) {
pc := DefaultPoolConfig()
if pc.MaxConns != 10 {
t.Errorf("MaxConns = %d, want 10", pc.MaxConns)
}
if pc.MinConns != 2 {
t.Errorf("MinConns = %d, want 2", pc.MinConns)
}
if pc.MaxConnLifetime != time.Hour {
t.Errorf("MaxConnLifetime = %v, want %v", pc.MaxConnLifetime, time.Hour)
}
if pc.MaxConnIdleTime != 30*time.Minute {
t.Errorf("MaxConnIdleTime = %v, want %v", pc.MaxConnIdleTime, 30*time.Minute)
}
if pc.HealthCheckPeriod != time.Minute {
t.Errorf("HealthCheckPeriod = %v, want %v", pc.HealthCheckPeriod, time.Minute)
}
}
func TestClosePool_nil(t *testing.T) {
// Should not panic.
ClosePool(nil)
}