Files
smthqueue/internal/postgres/queue.go
2023-10-31 23:26:48 +03:00

276 lines
6.5 KiB
Go

package postgres
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/exp/slog"
)
// QueuedMessage is a message row of the smth.queue table, as returned by
// Enqueue and Dequeue and passed back to Ack and Nack.
type QueuedMessage struct {
	// ID of the message. Unique.
	ID string
	// Topic the message was enqueued to.
	Topic string
	// Payload is the message body, stored as-is.
	Payload []byte
	// CreatedAt is when the message was enqueued.
	CreatedAt time.Time
	// UpdatedAt is when the row was last modified.
	UpdatedAt time.Time
	// visibleAt is the point in time from which Dequeue may return the
	// message. Dequeue with a non-zero timeout pushes it into the future, so
	// a message that is never committed (acked or nacked) becomes visible
	// again and is redelivered.
	visibleAt time.Time
	// versionID is the row version this copy of the message was read at.
	// It starts at 1 and is incremented by Dequeue. Ack and Nack only act
	// on the row while the stored version still matches this value; on a
	// mismatch the stored message is left untouched and an error is
	// returned to the caller.
	versionID uint64
}
// DequeueParams configures a single Dequeue call.
type DequeueParams struct {
	// Timeout selects the delivery guarantee of the dequeued message.
	//
	// If Timeout is non-zero, the message stays in the queue but is hidden
	// from other Dequeue calls for Timeout. If it is not acked within that
	// window it becomes visible again and is redelivered, so the message is
	// handled at least once.
	//
	// If Timeout is 0, the message is deleted from the queue as part of the
	// dequeue itself and cannot be redelivered.
	Timeout time.Duration
	// Topic to take the message from.
	Topic string
}
// EnqueueParams describes a message to be put into the queue.
type EnqueueParams struct {
	// Topic to which this message belongs.
	Topic string
	// Payload of the message.
	Payload []byte
	// VisibleTimeout delays delivery: the message is not returned by Dequeue
	// until VisibleTimeout has elapsed after it was enqueued. Zero makes it
	// visible right away.
	VisibleTimeout time.Duration
}
// QueueClient is a topic-based message queue stored in PostgreSQL.
type QueueClient interface {
	// Enqueue stores a new message and returns it as persisted, including
	// its generated ID.
	Enqueue(context.Context, EnqueueParams) (QueuedMessage, error)
	// Dequeue takes the visible message of the requested topic with the
	// earliest visibility time. When there is none, the returned error wraps
	// ErrNoMessage. See DequeueParams.Timeout for delete-vs-hide semantics.
	Dequeue(context.Context, DequeueParams) (QueuedMessage, error)
	// Ack removes the message from the queue, provided its versionID still
	// matches the stored one; otherwise ErrVersionIDMismatch is returned.
	Ack(context.Context, QueuedMessage) error
	// Nack hands the message back to the queue so that it is visible to
	// Dequeue again, provided its versionID still matches the stored one;
	// otherwise ErrVersionIDMismatch is returned.
	Nack(context.Context, QueuedMessage) error
}
// Queue returns a QueueClient that shares this client's connection pool and
// logs under the "queue" group of the client's logger.
func (c *client) Queue() QueueClient {
	queueLogger := c.log.WithGroup("queue")
	return &queueClient{logger: queueLogger, pool: c.pool}
}
// queueClient is the QueueClient implementation handed out by client.Queue.
type queueClient struct {
	// pool is the parent client's connection pool; every query runs on it.
	pool *pgxpool.Pool
	// logger is the parent client's logger scoped to the "queue" group.
	logger *slog.Logger
}
// tableCreateQuery is the DDL for the queue table and its indexes.
//
// Timestamps use TIMESTAMP WITH TIME ZONE: the queue compares visible_at
// against NOW(), which is a timestamptz, and with zone-less columns that
// comparison silently depends on the session's TimeZone setting (values
// written from Go as UTC wall time would be read back as local time).
//
// The primary key constraint must be named (or written without the
// CONSTRAINT keyword); an unnamed "CONSTRAINT PRIMARY KEY" is a syntax error.
//
// NOTE(review): tables created from an earlier definition keep zone-less
// columns; migrating them needs ALTER TABLE ... ALTER COLUMN ... TYPE
// timestamptz.
const tableCreateQuery = `
CREATE TABLE smth.queue (
	id TEXT NOT NULL,
	version_id BIGINT NOT NULL,
	topic TEXT NOT NULL DEFAULT '',
	payload BYTEA NULL,
	created_at TIMESTAMP WITH TIME ZONE NOT NULL,
	updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
	visible_at TIMESTAMP WITH TIME ZONE NOT NULL,
	CONSTRAINT queue_pkey PRIMARY KEY (id)
);
CREATE INDEX queue_visible_at_idx ON smth.queue(visible_at ASC);
CREATE INDEX queue_topic_idx ON smth.queue(topic);`
// scanIntoMessage reads a single queue row into a QueuedMessage. The row must
// expose the columns in the order id, topic, payload, created_at,
// updated_at, visible_at, version_id.
//
// A missing row (pgx.ErrNoRows) is reported as ErrNoMessage; any other scan
// failure is wrapped with context.
func (c *queueClient) scanIntoMessage(row pgx.Row) (qm QueuedMessage, err error) {
	columns := []any{
		&qm.ID,
		&qm.Topic,
		&qm.Payload,
		&qm.CreatedAt,
		&qm.UpdatedAt,
		&qm.visibleAt,
		&qm.versionID,
	}
	switch err = row.Scan(columns...); {
	case err == nil:
		return qm, nil
	case errors.Is(err, pgx.ErrNoRows):
		return qm, ErrNoMessage
	default:
		return qm, fmt.Errorf("scanning row: %w", err)
	}
}
// Enqueue inserts a new message into the queue and returns the stored row.
//
// The message gets a fresh random ID and version 1. Its visibility time is
// computed by the database as NOW() + params.VisibleTimeout, so it is on the
// same clock, and in the same time zone, as the NOW() that Dequeue compares
// against. Computing it in Go would make visibility depend on application
// clock skew and on the session TimeZone.
func (c *queueClient) Enqueue(ctx context.Context, params EnqueueParams) (qm QueuedMessage, err error) {
	const initialVersion = 1
	const query = `INSERT INTO smth.queue
	(id, topic, payload, created_at, updated_at, visible_at, version_id)
VALUES
	($1, $2, $3, NOW(), NOW(), NOW() + make_interval(secs => $4::double precision), $5)
RETURNING
	id, topic, payload, created_at, updated_at, visible_at, version_id`
	args := []any{
		c.generateID(),
		params.Topic,
		params.Payload,
		params.VisibleTimeout.Seconds(),
		initialVersion,
	}

	qt := traceMethod(ctx, c.logger, "Enqueue")
	defer qt.finish(&err)
	qt.query(query, args...)

	qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...))
	if err != nil {
		// scanIntoMessage already says "scanning row"; describe this layer.
		return QueuedMessage{}, fmt.Errorf("inserting message: %w", err)
	}
	qt.setResultCount(1)
	return qm, nil
}
// Dequeue takes the visible message of params.Topic with the earliest
// visibility time. FOR UPDATE SKIP LOCKED lets concurrent consumers pick
// different rows instead of blocking on each other.
//
// With params.Timeout == 0 the row is deleted and returned. Otherwise the row
// is kept, its version_id is incremented and it is hidden until
// NOW() + params.Timeout. That deadline is computed by the database so it
// uses the same clock as the visibility check. If the message is not acked
// by then, it is redelivered.
//
// If no message is available, the returned error wraps ErrNoMessage.
func (c *queueClient) Dequeue(ctx context.Context, params DequeueParams) (qm QueuedMessage, err error) {
	const queryDeletePrefix = `DELETE FROM smth.queue`
	const queryUpdatePrefix = `UPDATE smth.queue SET updated_at = NOW(), version_id = version_id + 1, visible_at = NOW() + make_interval(secs => $2::double precision)`
	const querySuffix = ` WHERE id = (
	SELECT id
	FROM smth.queue
	WHERE
		topic = $1
		AND visible_at <= NOW()
	ORDER BY visible_at ASC
	LIMIT 1
	FOR UPDATE SKIP LOCKED
) RETURNING id, topic, payload, created_at, updated_at, visible_at, version_id`

	query := queryDeletePrefix + querySuffix
	args := []any{params.Topic}
	if params.Timeout != 0 {
		query = queryUpdatePrefix + querySuffix
		args = append(args, params.Timeout.Seconds())
	}

	qt := traceMethod(ctx, c.logger, "Dequeue")
	defer qt.finish(&err)
	qt.query(query, args...)

	qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...))
	if err != nil {
		return QueuedMessage{}, fmt.Errorf("querying: %w", err)
	}
	qt.setResultCount(1)
	return qm, nil
}
// Ack deletes the message from the queue. The delete is conditioned on the
// stored version_id being equal to qm's, so an acknowledgement from a
// delivery that has since been superseded matches no row and is rejected by
// modifyByMessage.
func (c *queueClient) Ack(ctx context.Context, qm QueuedMessage) (err error) {
	const query = `DELETE FROM smth.queue WHERE id = $1 AND version_id = $2`
	err = c.modifyByMessage(ctx, qm, "Ack", query)
	return err
}
// Nack gives the message back to the queue without processing it.
//
// visible_at is reset to the Unix epoch (TO_TIMESTAMP(0)). This makes the
// message visible immediately and, because Dequeue orders by visible_at,
// sorts it before every message with a later visibility time. version_id is
// incremented, as on a fresh dequeue, so the copy held by the caller can no
// longer Ack or Nack the message a second time. The update only applies while
// the stored version_id still equals qm's.
func (c *queueClient) Nack(ctx context.Context, qm QueuedMessage) error {
	const query = `UPDATE smth.queue SET` +
		` visible_at = TO_TIMESTAMP(0)` +
		`, updated_at = NOW()` +
		`, version_id = version_id + 1` +
		` WHERE id = $1 AND version_id = $2`
	return c.modifyByMessage(ctx, qm, "Nack", query)
}
// modifyByMessage executes query, a statement parameterized by
// $1 = message ID and $2 = expected version_id, and returns
// ErrVersionIDMismatch when it affected no row. That happens when the
// message no longer exists or has been re-delivered under a newer version
// since qm was read.
//
// method names the operation in the trace log and in the error returned for a
// message that was never fetched.
func (c *queueClient) modifyByMessage(ctx context.Context, qm QueuedMessage, method, query string) (err error) {
	if qm.versionID == 0 {
		// Stored versions start at 1, so a zero version means qm did not come
		// from Enqueue/Dequeue. Report the misuse instead of crashing the caller.
		return fmt.Errorf("%s: queued message was not fetched", method)
	}
	// Pass values rather than pointers: the args are also written to the
	// debug log, where pointers would be printed as memory addresses.
	args := []any{qm.ID, qm.versionID}

	qt := traceMethod(ctx, c.logger, method)
	defer qt.finish(&err)
	qt.query(query, args...)

	tag, err := c.pool.Exec(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("executing query: %w", err)
	}

	affected := tag.RowsAffected()
	qt.setResultCount(affected)
	if affected == 0 {
		return ErrVersionIDMismatch
	}
	return nil
}
// generateID returns 8 random bytes from crypto/rand, hex-encoded as a
// 16-character string, for use as a message ID.
func (c *queueClient) generateID() string {
	raw := make([]byte, 8)
	// The error from rand.Read is ignored.
	// NOTE(review): if it ever failed, raw would stay all zeros and the
	// INSERT would collide on the primary key; consider surfacing it.
	_, _ = rand.Read(raw)
	return hex.EncodeToString(raw)
}
// queryTracer logs one queue operation at debug level. It records the SQL and
// arguments being executed and, when the operation returns, its outcome,
// elapsed time and row count.
type queryTracer struct {
	ctx    context.Context
	logger *slog.Logger
	start  time.Time
	count  int64
}

// query logs the statement about to be executed together with its arguments.
func (qt *queryTracer) query(query string, args ...any) {
	qt.logger.DebugContext(qt.ctx, "executing query", slog.String("query", query), slog.Any("args", args))
}

// setResultCount records how many rows the operation returned or affected;
// finish reports it as rows_count.
func (qt *queryTracer) setResultCount(count int64) {
	qt.count = count
}

// finish logs the outcome of the operation. errptr points at the caller's
// named error result, so that `defer qt.finish(&err)` observes the final
// error value; a nil errptr is treated as success.
//
// Successes and failures are both logged at debug level. On failure, the
// error itself is attached so the log line says why the operation failed.
func (qt *queryTracer) finish(errptr *error) {
	var err error
	if errptr != nil {
		err = *errptr
	}

	attrs := []any{
		slog.Bool("success", err == nil),
		slog.Duration("elapsed", time.Since(qt.start)),
		slog.Int64("rows_count", qt.count),
	}
	if err != nil {
		attrs = append(attrs, slog.Any("error", err))
	}
	qt.logger.Log(qt.ctx, slog.LevelDebug, "query finished", attrs...)
}

// traceMethod starts tracing a QueueClient method. Every line logged through
// the returned tracer carries the method name, and elapsed time is measured
// from this call.
func traceMethod(ctx context.Context, log *slog.Logger, method string) *queryTracer {
	return &queryTracer{
		ctx:    ctx,
		logger: log.With(slog.String("method", method)),
		start:  time.Now(),
	}
}