package postgres import ( "context" "crypto/rand" "encoding/hex" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/exp/slog" ) type QueuedMessage struct { // ID of the message. Unique. ID string Topic string Payload []byte CreatedAt time.Time UpdatedAt time.Time // visibleAt utilizes visibility in the queue. This is used // if message was not commited for a long time. (Acked or Nacked) visibleAt time.Time // versionID of the message. For future use to provide consistency. // In case message is tried to be acked with different version id, // this message will be discarded and error will be returned. versionID uint64 } type DequeueParams struct { // Timeout sets visibleAfter value to the future. Used in case // retry is needed, if this message should be handled for sure // atleast once. // If timeout is 0, this message will be deleted from queue. Timeout time.Duration Topic string } type EnqueueParams struct { // Topic to which this message belongs to. Topic string // Payload of the message. Payload []byte // VisibleTimeout VisibleTimeout time.Duration } type QueueClient interface { // Enqueue a message into message bus. Enqueue(context.Context, EnqueueParams) (QueuedMessage, error) // Dequeue a message from message bus. Dequeue(context.Context, DequeueParams) (QueuedMessage, error) // Ack removes message from queue in case versionID matches. Ack(context.Context, QueuedMessage) error // Nack sets visibilityAfter to now() and also updates versionID. Nack(context.Context, QueuedMessage) error } func (c *client) Queue() QueueClient { return &queueClient{ pool: c.pool, logger: c.log.WithGroup("queue"), } } type queueClient struct { pool *pgxpool.Pool logger *slog.Logger } const tableCreateQuery = ` CREATE TABLE smth.queue ( id TEXT NOT NULL, version_id BIGINT NOT NULL, topic TEXT NOT NULL DEFAULT '', payload BYTEA NULL, created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL, updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL, visible_at TIMESTAMP WITHOUT TIME ZONE NOT NULL, CONSTRAINT PRIMARY KEY (id) ); CREATE INDEX queue_visible_at_idx ON smth.queue(visible_at ASC); CREATE INDEX queue_topic_idx ON smth.queue(topic);` func (c *queueClient) scanIntoMessage(row pgx.Row) (qm QueuedMessage, err error) { err = row.Scan( &qm.ID, &qm.Topic, &qm.Payload, &qm.CreatedAt, &qm.UpdatedAt, &qm.visibleAt, &qm.versionID, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return qm, ErrNoMessage } return qm, fmt.Errorf("scanning row: %w", err) } return qm, nil } func (c *queueClient) Enqueue(ctx context.Context, params EnqueueParams) (qm QueuedMessage, err error) { const initialVersion = 1 const query = `INSERT INTO smth.queue (id, topic, payload, created_at, updated_at, visible_at, version_id) VALUES ($1, $2, $3, NOW(), NOW(), $4, $5) RETURNING id, topic, payload, created_at, updated_at, visible_at, version_id` args := []any{ c.generateID(), params.Topic, params.Payload, time.Now().Add(params.VisibleTimeout).UTC(), initialVersion, } qt := traceMethod(ctx, c.logger, "Enqueue") defer qt.finish(&err) qt.query(query, args...) qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...)) if err != nil { return qm, fmt.Errorf("scanning row: %w", err) } return qm, nil } func (c *queueClient) Dequeue(ctx context.Context, params DequeueParams) (qm QueuedMessage, err error) { const queryDeletePrefix = `DELETE FROM smth.queue` const queryUpdatePrefix = `UPDATE smth.queue SET updated_at = NOW(), version_id = version_id + 1, visible_at = $2` const querySuffix = ` WHERE id = ( SELECT id FROM smth.queue WHERE topic = $1 and visible_at < NOW() ORDER BY visible_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED ) RETURNING id, topic, payload, created_at, updated_at, visible_at, version_id` var query string args := append( make([]any, 0, 2), params.Topic, ) if params.Timeout == 0 { query = queryDeletePrefix + querySuffix } else { query = queryUpdatePrefix + querySuffix args = append(args, time.Now().UTC().Add(params.Timeout)) } qt := traceMethod(ctx, c.logger, "Dequeue") defer qt.finish(&err) qt.query(query, args...) qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...)) if err != nil { return qm, fmt.Errorf("querying: %w", err) } return qm, nil } func (c *queueClient) Ack(ctx context.Context, qm QueuedMessage) (err error) { const query = `DELETE FROM smth.queue` + ` WHERE id = $1` + ` AND version_id = $2` return c.modifyByMessage(ctx, qm, "Ack", query) } func (c *queueClient) Nack(ctx context.Context, qm QueuedMessage) error { const query = `UPDATE smth.queue SET` + ` visible_at = TO_TIMESTMAP(0)` + `, updated_at = NOW()` + ` WHERE id = $1 AND version_id = $2` return c.modifyByMessage(ctx, qm, "Nack", query) } func (c *queueClient) modifyByMessage(ctx context.Context, qm QueuedMessage, method, query string) (err error) { if qm.versionID == 0 { panic("queued message was not fetched") } args := []any{ &qm.ID, &qm.versionID, } qt := traceMethod(ctx, c.logger, method) defer qt.finish(&err) qt.query(query, args...) tag, err := c.pool.Exec(ctx, query, args...) if err != nil { return fmt.Errorf("executing query: %w", err) } affected := tag.RowsAffected() qt.setResultCount(affected) if affected == 0 { return ErrVersionIDMismatch } return nil } func (c *queueClient) generateID() string { var idByte [8]byte _, _ = rand.Read(idByte[:]) return hex.EncodeToString(idByte[:]) } type queryTracer struct { ctx context.Context logger *slog.Logger start time.Time count int64 } func (qt *queryTracer) query(query string, args ...any) { qt.logger.DebugContext(qt.ctx, "executing query", slog.String("query", query), slog.Any("args", args)) } func (qt *queryTracer) setResultCount(count int64) { qt.count = count } func (qt *queryTracer) finish(errptr *error) { var err error if errptr != nil { err = *errptr } var level slog.Level if err == nil { level = slog.LevelDebug } else { level = slog.LevelDebug } qt.logger.Log(qt.ctx, level, "query finished", slog.Bool("success", err == nil), slog.Duration("elapsed", time.Since(qt.start)), slog.Int64("rows_count", qt.count), ) } func traceMethod(ctx context.Context, log *slog.Logger, method string) *queryTracer { return &queryTracer{ ctx: ctx, logger: log.With(slog.String("method", method)), start: time.Now(), } }