implement
This commit is contained in:
73
internal/postgres/client.go
Normal file
73
internal/postgres/client.go
Normal file
@ -0,0 +1,73 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type DatabaseStats struct {
|
||||
MaxConnections int32
|
||||
TotalConnections int32
|
||||
AcquiredConnections int32
|
||||
IdleConnections int32
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
io.Closer
|
||||
|
||||
Queue() QueueClient
|
||||
|
||||
GetStats() DatabaseStats
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
MaxConns int64
|
||||
MaxIdleConns int64
|
||||
MasterDSN string
|
||||
}
|
||||
|
||||
type client struct {
|
||||
pool *pgxpool.Pool
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func New(ctx context.Context, config Config, logger *slog.Logger) (*client, error) {
|
||||
pgconfig, err := pgxpool.ParseConfig(config.MasterDSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, pgconfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("making new connection: %w", err)
|
||||
}
|
||||
|
||||
return &client{
|
||||
pool: pool,
|
||||
log: logger.With(slog.String("name", "db")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
if c.pool == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) GetStats() DatabaseStats {
|
||||
stat := c.pool.Stat()
|
||||
|
||||
return DatabaseStats{
|
||||
MaxConnections: stat.MaxConns(),
|
||||
TotalConnections: stat.TotalConns(),
|
||||
AcquiredConnections: stat.AcquiredConns(),
|
||||
IdleConnections: stat.IdleConns(),
|
||||
}
|
||||
}
|
||||
11
internal/postgres/error.go
Normal file
11
internal/postgres/error.go
Normal file
@ -0,0 +1,11 @@
|
||||
package postgres
|
||||
|
||||
type Error string
|
||||
|
||||
func (err Error) Error() string { return string(err) }
|
||||
|
||||
const (
|
||||
ErrVersionIDMismatch Error = "version ids does not match"
|
||||
ErrNoMessage Error = "no messages in queue"
|
||||
ErrNotImplemented Error = "not implemented"
|
||||
)
|
||||
275
internal/postgres/queue.go
Normal file
275
internal/postgres/queue.go
Normal file
@ -0,0 +1,275 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type QueuedMessage struct {
|
||||
// ID of the message. Unique.
|
||||
ID string
|
||||
Topic string
|
||||
Payload []byte
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// visibleAt utilizes visibility in the queue. This is used
|
||||
// if message was not commited for a long time. (Acked or Nacked)
|
||||
visibleAt time.Time
|
||||
// versionID of the message. For future use to provide consistency.
|
||||
// In case message is tried to be acked with different version id,
|
||||
// this message will be discarded and error will be returned.
|
||||
versionID uint64
|
||||
}
|
||||
|
||||
type DequeueParams struct {
|
||||
// Timeout sets visibleAfter value to the future. Used in case
|
||||
// retry is needed, if this message should be handled for sure
|
||||
// atleast once.
|
||||
// If timeout is 0, this message will be deleted from queue.
|
||||
Timeout time.Duration
|
||||
Topic string
|
||||
}
|
||||
|
||||
type EnqueueParams struct {
|
||||
// Topic to which this message belongs to.
|
||||
Topic string
|
||||
// Payload of the message.
|
||||
Payload []byte
|
||||
// VisibleTimeout
|
||||
VisibleTimeout time.Duration
|
||||
}
|
||||
|
||||
type QueueClient interface {
|
||||
// Enqueue a message into message bus.
|
||||
Enqueue(context.Context, EnqueueParams) (QueuedMessage, error)
|
||||
// Dequeue a message from message bus.
|
||||
Dequeue(context.Context, DequeueParams) (QueuedMessage, error)
|
||||
|
||||
// Ack removes message from queue in case versionID matches.
|
||||
Ack(context.Context, QueuedMessage) error
|
||||
// Nack sets visibilityAfter to now() and also updates versionID.
|
||||
Nack(context.Context, QueuedMessage) error
|
||||
}
|
||||
|
||||
func (c *client) Queue() QueueClient {
|
||||
return &queueClient{
|
||||
pool: c.pool,
|
||||
logger: c.log.WithGroup("queue"),
|
||||
}
|
||||
}
|
||||
|
||||
type queueClient struct {
|
||||
pool *pgxpool.Pool
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
const tableCreateQuery = `
|
||||
CREATE TABLE smth.queue (
|
||||
id TEXT NOT NULL,
|
||||
version_id BIGINT NOT NULL,
|
||||
topic TEXT NOT NULL DEFAULT '',
|
||||
payload BYTEA NULL,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
visible_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
|
||||
CONSTRAINT PRIMARY KEY (id)
|
||||
);
|
||||
CREATE INDEX queue_visible_at_idx ON smth.queue(visible_at ASC);
|
||||
CREATE INDEX queue_topic_idx ON smth.queue(topic);`
|
||||
|
||||
func (c *queueClient) scanIntoMessage(row pgx.Row) (qm QueuedMessage, err error) {
|
||||
err = row.Scan(
|
||||
&qm.ID,
|
||||
&qm.Topic,
|
||||
&qm.Payload,
|
||||
&qm.CreatedAt,
|
||||
&qm.UpdatedAt,
|
||||
&qm.visibleAt,
|
||||
&qm.versionID,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return qm, ErrNoMessage
|
||||
}
|
||||
|
||||
return qm, fmt.Errorf("scanning row: %w", err)
|
||||
}
|
||||
|
||||
return qm, nil
|
||||
}
|
||||
|
||||
func (c *queueClient) Enqueue(ctx context.Context, params EnqueueParams) (qm QueuedMessage, err error) {
|
||||
const initialVersion = 1
|
||||
const query = `INSERT INTO smth.queue
|
||||
(id, topic, payload, created_at, updated_at, visible_at, version_id)
|
||||
VALUES
|
||||
($1, $2, $3, NOW(), NOW(), $4, $5)
|
||||
RETURNING
|
||||
id, topic, payload, created_at, updated_at, visible_at, version_id`
|
||||
|
||||
args := []any{
|
||||
c.generateID(),
|
||||
params.Topic,
|
||||
params.Payload,
|
||||
time.Now().Add(params.VisibleTimeout).UTC(),
|
||||
initialVersion,
|
||||
}
|
||||
|
||||
qt := traceMethod(ctx, c.logger, "Enqueue")
|
||||
defer qt.finish(&err)
|
||||
|
||||
qt.query(query, args...)
|
||||
|
||||
qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...))
|
||||
if err != nil {
|
||||
return qm, fmt.Errorf("scanning row: %w", err)
|
||||
}
|
||||
|
||||
return qm, nil
|
||||
}
|
||||
|
||||
func (c *queueClient) Dequeue(ctx context.Context, params DequeueParams) (qm QueuedMessage, err error) {
|
||||
const queryDeletePrefix = `DELETE FROM smth.queue`
|
||||
const queryUpdatePrefix = `UPDATE smth.queue SET updated_at = NOW(), version_id = version_id + 1, visible_at = $2`
|
||||
const querySuffix = ` WHERE id = (
|
||||
SELECT id
|
||||
FROM smth.queue
|
||||
WHERE
|
||||
topic = $1
|
||||
and visible_at < NOW()
|
||||
ORDER BY visible_at ASC
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
) RETURNING id, topic, payload, created_at, updated_at, visible_at, version_id`
|
||||
|
||||
var query string
|
||||
args := append(
|
||||
make([]any, 0, 2),
|
||||
params.Topic,
|
||||
)
|
||||
if params.Timeout == 0 {
|
||||
query = queryDeletePrefix + querySuffix
|
||||
} else {
|
||||
query = queryUpdatePrefix + querySuffix
|
||||
args = append(args, time.Now().UTC().Add(params.Timeout))
|
||||
}
|
||||
|
||||
qt := traceMethod(ctx, c.logger, "Dequeue")
|
||||
defer qt.finish(&err)
|
||||
|
||||
qt.query(query, args...)
|
||||
|
||||
qm, err = c.scanIntoMessage(c.pool.QueryRow(ctx, query, args...))
|
||||
if err != nil {
|
||||
return qm, fmt.Errorf("querying: %w", err)
|
||||
}
|
||||
|
||||
return qm, nil
|
||||
}
|
||||
|
||||
func (c *queueClient) Ack(ctx context.Context, qm QueuedMessage) (err error) {
|
||||
const query = `DELETE FROM smth.queue` +
|
||||
` WHERE id = $1` +
|
||||
` AND version_id = $2`
|
||||
|
||||
return c.modifyByMessage(ctx, qm, "Ack", query)
|
||||
}
|
||||
|
||||
func (c *queueClient) Nack(ctx context.Context, qm QueuedMessage) error {
|
||||
const query = `UPDATE smth.queue SET` +
|
||||
` visible_at = TO_TIMESTMAP(0)` +
|
||||
`, updated_at = NOW()` +
|
||||
` WHERE id = $1 AND version_id = $2`
|
||||
|
||||
return c.modifyByMessage(ctx, qm, "Nack", query)
|
||||
}
|
||||
|
||||
func (c *queueClient) modifyByMessage(ctx context.Context, qm QueuedMessage, method, query string) (err error) {
|
||||
if qm.versionID == 0 {
|
||||
panic("queued message was not fetched")
|
||||
}
|
||||
|
||||
args := []any{
|
||||
&qm.ID,
|
||||
&qm.versionID,
|
||||
}
|
||||
|
||||
qt := traceMethod(ctx, c.logger, method)
|
||||
defer qt.finish(&err)
|
||||
|
||||
qt.query(query, args...)
|
||||
|
||||
tag, err := c.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("executing query: %w", err)
|
||||
}
|
||||
|
||||
affected := tag.RowsAffected()
|
||||
qt.setResultCount(affected)
|
||||
|
||||
if affected == 0 {
|
||||
return ErrVersionIDMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *queueClient) generateID() string {
|
||||
var idByte [8]byte
|
||||
_, _ = rand.Read(idByte[:])
|
||||
return hex.EncodeToString(idByte[:])
|
||||
}
|
||||
|
||||
type queryTracer struct {
|
||||
ctx context.Context
|
||||
logger *slog.Logger
|
||||
start time.Time
|
||||
count int64
|
||||
}
|
||||
|
||||
func (qt *queryTracer) query(query string, args ...any) {
|
||||
qt.logger.DebugContext(qt.ctx, "executing query", slog.String("query", query), slog.Any("args", args))
|
||||
}
|
||||
|
||||
func (qt *queryTracer) setResultCount(count int64) {
|
||||
qt.count = count
|
||||
}
|
||||
|
||||
func (qt *queryTracer) finish(errptr *error) {
|
||||
var err error
|
||||
if errptr != nil {
|
||||
err = *errptr
|
||||
}
|
||||
|
||||
var level slog.Level
|
||||
if err == nil {
|
||||
level = slog.LevelDebug
|
||||
} else {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
qt.logger.Log(qt.ctx, level,
|
||||
"query finished",
|
||||
slog.Bool("success", err == nil),
|
||||
slog.Duration("elapsed", time.Since(qt.start)),
|
||||
slog.Int64("rows_count", qt.count),
|
||||
)
|
||||
}
|
||||
|
||||
func traceMethod(ctx context.Context, log *slog.Logger, method string) *queryTracer {
|
||||
return &queryTracer{
|
||||
ctx: ctx,
|
||||
logger: log.With(slog.String("method", method)),
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user