package sqlite

import (
	"cmp"
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"path"
	"slices"
	"strconv"
	"strings"
	"time"

	"git.loyso.art/frx/kurious/internal/common/xcontext"
)

// migrationUnit describes a single migration file discovered by RunMigrations.
type migrationUnit struct {
	// num is the numeric prefix of the file name (the part before the first "_").
	num int
	// name is the remainder of the file name with the ".sql" suffix trimmed.
	name string
	// path is the location of the file inside the migrations filesystem.
	path string
}

// apply reads the migration file from the package-level migrations filesystem
// and executes its whole content with a single ExecContext call inside tx.
func (u migrationUnit) apply(ctx context.Context, tx *sql.Tx) error {
	content, err := fs.ReadFile(migrations, u.path)
	if err != nil {
		return fmt.Errorf("reading file: %w", err)
	}

	_, err = tx.ExecContext(ctx, string(content))
	if err != nil {
		return fmt.Errorf("executing query: %w", err)
	}

	return nil
}

// sortMigrationUnit orders migration units by ascending number; it is meant
// to be passed to slices.SortFunc.
func sortMigrationUnit(lhs, rhs migrationUnit) int {
	return cmp.Compare(lhs.num, rhs.num)
}

// parseMigrationUnit builds a migrationUnit from a file name of the form
// "<number>_<name>.sql", where number must be a positive decimal integer.
func parseMigrationUnit(fileName string) (migrationUnit, error) {
	numPart, namePart, ok := strings.Cut(fileName, "_")
	if !ok {
		// Without a separator the name splits into exactly one part.
		return migrationUnit{}, errors.New("bad number of parts, expected 2, got 1")
	}

	num, err := strconv.Atoi(numPart)
	if err != nil {
		return migrationUnit{}, fmt.Errorf("parsing migration number: %w", err)
	}
	if num < 1 {
		return migrationUnit{}, fmt.Errorf("migration number expected to be greater than 0, but got %d", num)
	}

	return migrationUnit{
		num:  num,
		name: strings.TrimSuffix(namePart, ".sql"),
		// path.Join with a single element only cleans the path.
		path: path.Join(fileName),
	}, nil
}

// RunMigrations applies every pending migration returned by getMigrationEntries
// to db.
//
// Migration files must be named "<number>_<name>.sql" with a positive number
// that is unique across all files; directory entries are skipped. Units are
// applied in ascending number order, starting right after the highest number
// recorded in the migration_meta table, and all pending units run inside a
// single transaction that is rolled back if any of them fails.
func RunMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
	items, err := getMigrationEntries()
	if err != nil {
		return fmt.Errorf("reading directory: %w", err)
	}

	units := make([]migrationUnit, 0, len(items))
	for _, item := range items {
		if item.IsDir() {
			continue
		}

		itemName := item.Name()
		unit, err := parseMigrationUnit(itemName)
		if err != nil {
			return fmt.Errorf("parsing migration file %q: %w", itemName, err)
		}

		xcontext.LogDebug(ctx, log, "found migration unit", slog.Any("unit", unit))
		units = append(units, unit)
	}

	slices.SortFunc(units, sortMigrationUnit)

	// Duplicate numbers make the apply order undefined (SortFunc is not
	// stable), make resuming from the stored number ambiguous, and would fail
	// on the migration_meta primary key mid-transaction. Reject them up front.
	for i := 1; i < len(units); i++ {
		if units[i].num == units[i-1].num {
			return fmt.Errorf(
				"duplicate migration number %d in files %q and %q",
				units[i].num, units[i-1].path, units[i].path,
			)
		}
	}

	mr := &metaRepository{
		db:  db,
		log: log,
	}

	if err := mr.prepare(ctx); err != nil {
		return fmt.Errorf("preparing meta repository: %w", err)
	}

	count, err := mr.run(ctx, units...)
	if err != nil {
		return fmt.Errorf("running transaction: %w", err)
	}

	if count > 0 {
		xcontext.LogInfo(ctx, log, "some new migrations has been applied", slog.Int("count", count))
	} else {
		xcontext.LogDebug(ctx, log, "no new migrations has been applied")
	}

	return nil
}

// metaRepository tracks applied migrations in the migration_meta table.
type metaRepository struct {
	db  *sql.DB
	log *slog.Logger

	// lastAppliedNumber is the highest migration number recorded in
	// migration_meta (0 when none), loaded by prepare and advanced by
	// adjustMigrationApplied.
	lastAppliedNumber int
}

// run applies, in order, every unit positioned after r.lastAppliedNumber and
// records each one in migration_meta. units must be sorted by ascending number.
//
// All work happens in one transaction. It is committed only when every unit
// succeeds; otherwise it is rolled back. An error is returned when the number
// stored in meta is not among the provided units. The returned count is the
// number of units applied and committed; it is 0 whenever err is non-nil.
func (r *metaRepository) run(ctx context.Context, units ...migrationUnit) (count int, err error) {
	idx, found := slices.BinarySearchFunc(units, r.lastAppliedNumber, func(mu migrationUnit, num int) int {
		return cmp.Compare(mu.num, num)
	})
	if r.lastAppliedNumber > 0 {
		if !found {
			return 0, fmt.Errorf("migration %d stored in meta was not found in provided migrations", r.lastAppliedNumber)
		}
		// idx points at the already applied unit; resume from the next one.
		idx++
	}

	xcontext.LogDebug(
		ctx, r.log, "starting to apply migrations",
		slog.Int("last_applied_migration", r.lastAppliedNumber),
		slog.Int("next_migration_idx", idx),
	)

	tx, err := r.db.BeginTx(ctx, &sql.TxOptions{
		Isolation: sql.LevelDefault,
		ReadOnly:  false,
	})
	if err != nil {
		return 0, fmt.Errorf("starting transaction: %w", err)
	}

	// Finish the transaction according to the final value of err, and merge any
	// rollback/commit failure into the returned error.
	defer func() {
		var errtx error
		if err != nil {
			xcontext.LogError(ctx, r.log, "rolling back migration changes due to error")
			errtx = tx.Rollback()
		} else {
			xcontext.LogDebug(ctx, r.log, "commiting migration changes")
			errtx = tx.Commit()
		}

		if errtx != nil {
			// The changes were not committed, so none of them count as applied.
			count = 0
		}

		err = errors.Join(err, errtx)
	}()

	for _, unit := range units[idx:] {
		if err = unit.apply(ctx, tx); err != nil {
			return 0, fmt.Errorf("unable to apply migration %q: %w", unit.name, err)
		}

		if err = r.adjustMigrationApplied(ctx, tx, unit); err != nil {
			return 0, fmt.Errorf("storing migration process info: %w", err)
		}

		xcontext.LogInfo(
			ctx, r.log, "migration unit applied",
			slog.Int("number", unit.num),
			slog.String("name", unit.name),
		)

		count++
	}

	return count, nil
}

// prepare ensures the migration_meta table exists and loads the highest
// recorded migration number into r.lastAppliedNumber.
func (r *metaRepository) prepare(ctx context.Context) error {
	if err := r.makeTable(ctx); err != nil {
		return fmt.Errorf("making table: %w", err)
	}

	if err := r.loadLastAppliedMigration(ctx); err != nil {
		return fmt.Errorf("loading last applied migration: %w", err)
	}

	return nil
}

// makeTable creates the migration_meta table if it does not already exist.
// Columns: id (migration number), name, applied_at (Unix seconds).
func (r *metaRepository) makeTable(ctx context.Context) error {
	const query = `CREATE TABLE IF NOT EXISTS migration_meta (` +
		` id INT PRIMARY KEY NOT NULL` +
		`, name TEXT NOT NULL` +
		`, applied_at INT NOT NULL` +
		`);`

	_, err := r.db.ExecContext(ctx, query)
	if err != nil {
		return fmt.Errorf("executing query: %w", err)
	}

	return nil
}

// loadLastAppliedMigration stores the highest id from migration_meta into
// r.lastAppliedNumber, or 0 when the table is empty.
func (r *metaRepository) loadLastAppliedMigration(ctx context.Context) error {
	const query = `SELECT COALESCE(MAX(id), 0) FROM migration_meta;`

	err := r.db.QueryRowContext(ctx, query).Scan(&r.lastAppliedNumber)
	// The aggregate always yields one row, so ErrNoRows is not expected; it is
	// tolerated defensively.
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("executing query: %w", err)
	}

	return nil
}

// adjustMigrationApplied records unit in migration_meta within tx, stamped with
// the current Unix time in seconds, and advances r.lastAppliedNumber.
// Note: the in-memory number is updated before tx is committed.
func (r *metaRepository) adjustMigrationApplied(ctx context.Context, tx *sql.Tx, unit migrationUnit) error {
	const query = `INSERT INTO migration_meta (id, name, applied_at) VALUES (?, ?, ?)`

	args := []any{
		unit.num,
		unit.name,
		// Unix() already drops sub-second precision.
		time.Now().Unix(),
	}

	_, err := tx.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("executing query: %w", err)
	}

	r.lastAppliedNumber = unit.num

	return nil
}