package db

import (
	"context"
	"embed"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgxpool"
)

// migrationsFS holds the SQL migration files compiled into the binary.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// Pool wraps a pgxpool.Pool with migration support.
type Pool struct {
	*pgxpool.Pool
}

// Connect opens a connection pool, verifies connectivity with a ping, and
// runs any pending embedded migrations. On any failure the pool is closed
// and a wrapped error is returned. If logger is nil, log.Default() is used
// for migration progress messages.
func Connect(ctx context.Context, databaseURL string, logger *log.Logger) (*Pool, error) {
	pool, err := pgxpool.New(ctx, databaseURL)
	if err != nil {
		return nil, fmt.Errorf("db: open pool: %w", err)
	}
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, fmt.Errorf("db: ping: %w", err)
	}

	p := &Pool{pool}
	if err := p.migrate(ctx, logger); err != nil {
		pool.Close()
		return nil, fmt.Errorf("db: migrate: %w", err)
	}
	return p, nil
}

// migrate applies every embedded migration that is not yet recorded in
// schema_migrations, in file-name order. Already-applied migrations are
// skipped, so calling it repeatedly is safe.
//
// A migration's version is its 1-based position in the listing of
// migrations/*.sql (embed.FS.ReadDir returns entries sorted by file name).
// It is not a number parsed from the file name. New migration files must
// therefore sort after all existing ones; inserting, renaming or removing a
// file shifts the version of every file after it.
//
// Each migration's SQL and its schema_migrations row are committed in one
// transaction. A failure leaves neither behind, and the migration is retried
// on the next run. As a consequence, statements that cannot run inside a
// transaction block (e.g. CREATE INDEX CONCURRENTLY) are not supported.
func (p *Pool) migrate(ctx context.Context, logger *log.Logger) error {
	if logger == nil {
		logger = log.Default()
	}

	// Ensure schema_migrations table exists first.
	_, err := p.Exec(ctx, `
		create table if not exists schema_migrations (
			version    integer primary key,
			applied_at timestamptz not null default now()
		)`)
	if err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	entries, err := migrationsFS.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("read migrations dir: %w", err)
	}

	for i, e := range entries {
		version := i + 1

		var applied bool
		err := p.QueryRow(ctx,
			"select exists(select 1 from schema_migrations where version=$1)",
			version,
		).Scan(&applied)
		if err != nil {
			return fmt.Errorf("check migration %d: %w", version, err)
		}
		if applied {
			continue
		}

		if err := p.applyMigration(ctx, version, e.Name()); err != nil {
			return err
		}
		logger.Printf("event=migration_applied version=%d file=%s", version, e.Name())
	}
	return nil
}

// applyMigration executes the embedded file migrations/<name> and records
// version in schema_migrations, both within a single transaction.
func (p *Pool) applyMigration(ctx context.Context, version int, name string) error {
	body, err := migrationsFS.ReadFile("migrations/" + name)
	if err != nil {
		return fmt.Errorf("read migration %s: %w", name, err)
	}

	tx, err := p.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin migration %s: %w", name, err)
	}
	// Roll back on every early return. After a successful Commit, Rollback
	// is a no-op that returns an error we intentionally ignore.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(ctx, string(body)); err != nil {
		return fmt.Errorf("run migration %s: %w", name, err)
	}
	if _, err := tx.Exec(ctx,
		"insert into schema_migrations(version) values($1)",
		version,
	); err != nil {
		return fmt.Errorf("record migration %d: %w", version, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit migration %s: %w", name, err)
	}
	return nil
}