241 lines
5.2 KiB
Go
241 lines
5.2 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"log/slog"
|
|
"path"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.loyso.art/frx/kurious/internal/common/xcontext"
|
|
)
|
|
|
|
// migrationUnit describes a single migration file discovered by
// RunMigrations.
type migrationUnit struct {
	// num is the numeric prefix of the file name; units are ordered by it
	// and it is stored as the id of the migration_meta row.
	num int
	// name is the part of the file name after the first "_" with a trailing
	// ".sql" removed.
	name string
	// path is the location handed to fs.ReadFile when the unit is applied.
	path string
}
|
|
|
|
func (u migrationUnit) apply(ctx context.Context, tx *sql.Tx) error {
|
|
content, err := fs.ReadFile(migrations, u.path)
|
|
if err != nil {
|
|
return fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
query := string(content)
|
|
_, err = tx.ExecContext(ctx, query)
|
|
if err != nil {
|
|
return fmt.Errorf("executing query: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func sortMigrationUnit(lhs, rhs migrationUnit) int {
|
|
if lhs.num < rhs.num {
|
|
return -1
|
|
} else if lhs.num > rhs.num {
|
|
return 1
|
|
} else {
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func RunMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
|
|
items, err := getMigrationEntries()
|
|
if err != nil {
|
|
return fmt.Errorf("reading directory: %w", err)
|
|
}
|
|
|
|
units := make([]migrationUnit, 0, len(items))
|
|
for _, item := range items {
|
|
if item.IsDir() {
|
|
continue
|
|
}
|
|
|
|
itemName := item.Name()
|
|
splitted := strings.SplitN(itemName, "_", 2)
|
|
if len(splitted) != 2 {
|
|
return fmt.Errorf("bad number of parts, expected 2, got %d", len(splitted))
|
|
}
|
|
|
|
splittedNum, err := strconv.Atoi(splitted[0])
|
|
if err != nil {
|
|
return fmt.Errorf("parsing migration number: %w", err)
|
|
}
|
|
|
|
if splittedNum < 1 {
|
|
return fmt.Errorf("migration number expected to be greater than 0, but got %d", splittedNum)
|
|
}
|
|
|
|
unit := migrationUnit{
|
|
num: splittedNum,
|
|
name: strings.TrimSuffix(splitted[1], ".sql"),
|
|
path: path.Join(itemName),
|
|
}
|
|
|
|
xcontext.LogDebug(ctx, log, "found migration unit", slog.Any("unit", unit))
|
|
|
|
units = append(units, unit)
|
|
}
|
|
|
|
slices.SortFunc(units, sortMigrationUnit)
|
|
|
|
mr := &metaRepository{
|
|
db: db,
|
|
log: log,
|
|
}
|
|
err = mr.prepare(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("preparing meta repository: %w", err)
|
|
}
|
|
|
|
count, err := mr.run(ctx, units...)
|
|
if err != nil {
|
|
return fmt.Errorf("running transaction: %w", err)
|
|
}
|
|
|
|
if count > 0 {
|
|
xcontext.LogInfo(ctx, log, "some new migrations has been applied", slog.Int("count", count))
|
|
} else {
|
|
xcontext.LogDebug(ctx, log, "no new migrations has been applied")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// metaRepository keeps track of applied migrations through the
// migration_meta table and applies pending migration units.
type metaRepository struct {
	// db is the database that holds migration_meta and receives migrations.
	db *sql.DB
	// log is the logger passed to the xcontext logging helpers.
	log *slog.Logger
	// lastAppliedNumber is the highest migration id known to be applied;
	// zero means no migration has been applied yet.
	lastAppliedNumber int
}
|
|
|
|
func (r *metaRepository) run(ctx context.Context, units ...migrationUnit) (count int, err error) {
|
|
idx, found := slices.BinarySearchFunc(units, r.lastAppliedNumber, func(mu migrationUnit, i int) int {
|
|
return cmp.Compare(mu.num, i)
|
|
})
|
|
if !found && r.lastAppliedNumber > 0 {
|
|
return 0, fmt.Errorf("migration %d stored in meta was not found in provided migrations", r.lastAppliedNumber)
|
|
} else if r.lastAppliedNumber > 0 {
|
|
idx++
|
|
}
|
|
|
|
xcontext.LogDebug(
|
|
ctx, r.log,
|
|
"starting to apply migrations",
|
|
slog.Int("last_applied_migration", r.lastAppliedNumber),
|
|
slog.Int("next_migration_idx", idx),
|
|
)
|
|
|
|
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{
|
|
Isolation: sql.LevelDefault,
|
|
ReadOnly: false,
|
|
})
|
|
if err != nil {
|
|
return 0, fmt.Errorf("starting transaction: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
var errtx error
|
|
if err != nil {
|
|
xcontext.LogError(ctx, r.log, "rolling back migration changes due to error")
|
|
errtx = tx.Rollback()
|
|
} else {
|
|
xcontext.LogDebug(ctx, r.log, "commiting migration changes")
|
|
errtx = tx.Commit()
|
|
}
|
|
|
|
err = errors.Join(err, errtx)
|
|
}()
|
|
|
|
for i := idx; i < len(units); i++ {
|
|
unit := units[i]
|
|
err = unit.apply(ctx, tx)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("unable to apply migration %q: %w", unit.name, err)
|
|
}
|
|
|
|
err = r.adjustMigrationApplied(ctx, tx, unit)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("storing migration process info: %w", err)
|
|
}
|
|
|
|
xcontext.LogInfo(
|
|
ctx, r.log, "migration unit applied",
|
|
slog.Int("number", unit.num),
|
|
slog.String("name", unit.name),
|
|
)
|
|
count++
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (r *metaRepository) prepare(ctx context.Context) error {
|
|
err := r.makeTable(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("making table: %w", err)
|
|
}
|
|
|
|
err = r.loadLastAppliedMigration(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("loading last applied migration: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *metaRepository) makeTable(ctx context.Context) error {
|
|
const query = `CREATE TABLE IF NOT EXISTS migration_meta (` +
|
|
` id INT PRIMARY KEY NOT NULL` +
|
|
`, name TEXT NOT NULL` +
|
|
`, applied_at INT NOT NULL` +
|
|
`);`
|
|
|
|
_, err := r.db.ExecContext(ctx, query)
|
|
if err != nil {
|
|
return fmt.Errorf("executing query: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *metaRepository) loadLastAppliedMigration(ctx context.Context) error {
|
|
const query = `SELECT COALESCE(MAX(id), 0) FROM migration_meta;`
|
|
|
|
err := r.db.QueryRowContext(ctx, query).Scan(&r.lastAppliedNumber)
|
|
if err != nil {
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
return fmt.Errorf("executing query: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *metaRepository) adjustMigrationApplied(ctx context.Context, tx *sql.Tx, unit migrationUnit) error {
|
|
const query = `INSERT INTO migration_meta (id, name, applied_at) VALUES (?, ?, ?)`
|
|
|
|
args := []any{
|
|
unit.num,
|
|
unit.name,
|
|
time.Now().Truncate(time.Second).Unix(),
|
|
}
|
|
|
|
_, err := tx.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("executing query: %w", err)
|
|
}
|
|
|
|
r.lastAppliedNumber = unit.num
|
|
|
|
return nil
|
|
}
|