261 lines
8.8 KiB
Go
261 lines
8.8 KiB
Go
package store
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"golang.org/x/crypto/bcrypt"
)
|
|
|
|
// ------------------------------------------------------------------
|
|
// Domain types
|
|
// ------------------------------------------------------------------
|
|
|
|
// User is an account stored by AuthStore. Each user belongs to a single
// tenant, referenced by TenantID.
type User struct {
	// ID is the user's primary key.
	ID string `json:"id"`
	// TenantID references the owning tenant.
	TenantID string `json:"tenant_id"`
	// TenantSlug is the owning tenant's slug; AuthStore queries fill it in
	// and leave it "" when no matching tenant row exists.
	TenantSlug string `json:"tenant_slug"`
	// Username is the login name.
	Username string `json:"username"`
	// PasswordHash is the bcrypt hash of the password. The "-" tag keeps it
	// out of every JSON encoding of a User.
	PasswordHash string `json:"-"`
	// Role is the user's role, e.g. "admin", "screen_user" or "restricted".
	Role string `json:"role"`
	// CreatedAt is when the user row was created.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// Session is a login session row created by AuthStore.CreateSession.
// GetSessionUser only accepts a session while ExpiresAt is still in the future.
type Session struct {
	// ID identifies the session (assigned by the database on insert);
	// presumably handed to clients as the session token — confirm at call site.
	ID string `json:"id"`
	// UserID references the user the session authenticates.
	UserID string `json:"user_id"`
	// CreatedAt is when the session row was inserted.
	CreatedAt time.Time `json:"created_at"`
	// ExpiresAt is the instant after which the session is no longer valid.
	ExpiresAt time.Time `json:"expires_at"`
}
|
|
|
|
// ------------------------------------------------------------------
|
|
// AuthStore
|
|
// ------------------------------------------------------------------
|
|
|
|
// AuthStore handles user authentication and session management.
|
|
type AuthStore struct{ pool *pgxpool.Pool }
|
|
|
|
// NewAuthStore creates a new AuthStore backed by pool.
|
|
func NewAuthStore(pool *pgxpool.Pool) *AuthStore { return &AuthStore{pool} }
|
|
|
|
// GetUserByUsername returns the user with the given username or pgx.ErrNoRows.
|
|
// TenantSlug is populated via LEFT JOIN on tenants.
|
|
func (s *AuthStore) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
|
row := s.pool.QueryRow(ctx,
|
|
`select u.id, u.tenant_id, coalesce(t.slug, ''), u.username, u.password_hash, u.role, u.created_at
|
|
from users u
|
|
left join tenants t on t.id = u.tenant_id
|
|
where u.username = $1`, username)
|
|
return scanUserWithSlug(row)
|
|
}
|
|
|
|
// CreateSession inserts a new session for userID with the given TTL and returns the session.
|
|
func (s *AuthStore) CreateSession(ctx context.Context, userID string, ttl time.Duration) (*Session, error) {
|
|
expiresAt := time.Now().Add(ttl)
|
|
row := s.pool.QueryRow(ctx,
|
|
`insert into sessions(user_id, expires_at)
|
|
values($1, $2)
|
|
returning id, user_id, created_at, expires_at`,
|
|
userID, expiresAt)
|
|
return scanSession(row)
|
|
}
|
|
|
|
// GetSessionUser returns the user associated with sessionID if the session is still valid.
|
|
// Returns pgx.ErrNoRows when the session does not exist or has expired.
|
|
// TenantSlug is populated via JOIN on tenants.
|
|
func (s *AuthStore) GetSessionUser(ctx context.Context, sessionID string) (*User, error) {
|
|
row := s.pool.QueryRow(ctx,
|
|
`select u.id, u.tenant_id, coalesce(t.slug, ''), u.username, u.password_hash, u.role, u.created_at
|
|
from sessions se
|
|
join users u on u.id = se.user_id
|
|
left join tenants t on t.id = u.tenant_id
|
|
where se.id = $1
|
|
and se.expires_at > now()`, sessionID)
|
|
return scanUserWithSlug(row)
|
|
}
|
|
|
|
// DeleteSession removes the session with the given ID.
|
|
func (s *AuthStore) DeleteSession(ctx context.Context, sessionID string) error {
|
|
_, err := s.pool.Exec(ctx, `delete from sessions where id = $1`, sessionID)
|
|
return err
|
|
}
|
|
|
|
// CleanExpiredSessions removes all sessions whose expires_at is in the past.
|
|
func (s *AuthStore) CleanExpiredSessions(ctx context.Context) error {
|
|
_, err := s.pool.Exec(ctx, `delete from sessions where expires_at <= now()`)
|
|
return err
|
|
}
|
|
|
|
// VerifyPassword checks if the provided password matches the hashed password for a user.
|
|
func (s *AuthStore) VerifyPassword(ctx context.Context, userID, password string) (bool, error) {
|
|
var passwordHash string
|
|
err := s.pool.QueryRow(ctx,
|
|
`select password_hash from users where id = $1`, userID).
|
|
Scan(&passwordHash)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("auth: get password hash: %w", err)
|
|
}
|
|
|
|
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
|
|
return err == nil, nil
|
|
}
|
|
|
|
// EnsureAdminUser creates an 'admin' user for the tenant identified by tenantSlug
|
|
// if no user with username 'admin' already exists. The password is hashed with bcrypt.
|
|
// bcrypt cost factor 12 is used (minimum recommended for production).
|
|
func (s *AuthStore) EnsureAdminUser(ctx context.Context, tenantSlug, password string) error {
|
|
// Check whether 'admin' user already exists for this specific tenant.
|
|
// The check must be scoped to the tenant to avoid false positives when
|
|
// another tenant already has an 'admin' user.
|
|
var exists bool
|
|
err := s.pool.QueryRow(ctx,
|
|
`select exists(
|
|
select 1 from users u
|
|
join tenants t on t.id = u.tenant_id
|
|
where u.username = $1 and t.slug = $2
|
|
)`,
|
|
"admin", tenantSlug,
|
|
).Scan(&exists)
|
|
if err != nil {
|
|
return fmt.Errorf("auth: check admin user: %w", err)
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
if err != nil {
|
|
return fmt.Errorf("auth: hash password: %w", err)
|
|
}
|
|
|
|
var tenantID string
|
|
err = s.pool.QueryRow(ctx, `select id from tenants where slug = $1`, tenantSlug).Scan(&tenantID)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return fmt.Errorf("auth: tenant not found: %s", tenantSlug)
|
|
}
|
|
return fmt.Errorf("auth: resolve tenant: %w", err)
|
|
}
|
|
|
|
_, err = s.pool.Exec(ctx,
|
|
`insert into users(tenant_id, username, password_hash, role)
|
|
values($1, 'admin', $2, 'admin')`,
|
|
tenantID, string(hash))
|
|
if err != nil {
|
|
return fmt.Errorf("auth: create admin user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateScreenUser creates a new user with the given role for the tenant
|
|
// identified by tenantSlug. role must be "screen_user" or "restricted".
|
|
// The password is hashed with bcrypt (cost 12).
|
|
// Returns pgx.ErrNoRows if the tenant does not exist, or a wrapped error if
|
|
// the username is already taken (unique constraint violation).
|
|
func (s *AuthStore) CreateScreenUser(ctx context.Context, tenantSlug, username, password, role string) (*User, error) {
|
|
if role != "screen_user" && role != "restricted" {
|
|
return nil, fmt.Errorf("auth: invalid role: %s", role)
|
|
}
|
|
var tenantID string
|
|
err := s.pool.QueryRow(ctx, `select id from tenants where slug = $1`, tenantSlug).Scan(&tenantID)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, fmt.Errorf("auth: tenant not found: %s", tenantSlug)
|
|
}
|
|
return nil, fmt.Errorf("auth: resolve tenant: %w", err)
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auth: hash password: %w", err)
|
|
}
|
|
|
|
row := s.pool.QueryRow(ctx,
|
|
`insert into users(tenant_id, username, password_hash, role)
|
|
values($1, $2, $3, $4)
|
|
returning id, tenant_id, $5::text, username, password_hash, role, created_at`,
|
|
tenantID, username, string(hash), role, tenantSlug)
|
|
u, err := scanUserWithSlug(row)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auth: create screen user: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// ListScreenUsers returns all users with role 'screen_user' or 'restricted' for the given tenant.
|
|
func (s *AuthStore) ListScreenUsers(ctx context.Context, tenantSlug string) ([]*User, error) {
|
|
rows, err := s.pool.Query(ctx,
|
|
`select u.id, u.tenant_id, coalesce(t.slug, ''), u.username, u.password_hash, u.role, u.created_at
|
|
from users u
|
|
left join tenants t on t.id = u.tenant_id
|
|
where t.slug = $1 and u.role IN ('screen_user', 'restricted')
|
|
order by u.username`, tenantSlug)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auth: list screen users: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
var out []*User
|
|
for rows.Next() {
|
|
u, err := scanUserWithSlug(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, u)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// DeleteUser removes a user and all their session + screen permission records (CASCADE).
|
|
// It refuses to delete users with role 'admin' to prevent lockout.
|
|
func (s *AuthStore) DeleteUser(ctx context.Context, userID string) error {
|
|
tag, err := s.pool.Exec(ctx,
|
|
`delete from users where id = $1 and role != 'admin'`, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("auth: delete user: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return fmt.Errorf("auth: delete user: not found or is admin")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ------------------------------------------------------------------
|
|
// scan helpers
|
|
// ------------------------------------------------------------------
|
|
|
|
// scanUserWithSlug scans a row that includes tenant_slug as the third column.
|
|
func scanUserWithSlug(row interface {
|
|
Scan(dest ...any) error
|
|
}) (*User, error) {
|
|
var u User
|
|
err := row.Scan(&u.ID, &u.TenantID, &u.TenantSlug, &u.Username, &u.PasswordHash, &u.Role, &u.CreatedAt)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, pgx.ErrNoRows
|
|
}
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
func scanSession(row interface {
|
|
Scan(dest ...any) error
|
|
}) (*Session, error) {
|
|
var se Session
|
|
err := row.Scan(&se.ID, &se.UserID, &se.CreatedAt, &se.ExpiresAt)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, pgx.ErrNoRows
|
|
}
|
|
return nil, fmt.Errorf("scan session: %w", err)
|
|
}
|
|
return &se, nil
|
|
}
|