fix: critical bugs from code review (data corruption, error contract, HTTP hardening) #5

Merged
frx merged 7 commits from refactor/code-review-fixes into master 2026-06-28 14:04:25 +00:00
20 changed files with 138 additions and 86 deletions

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
@ -54,6 +55,7 @@ func setupHTTP(cfg config.HTTP, srv xhttp.Server, log *slog.Logger) *http.Server
router := mux.NewRouter() router := mux.NewRouter()
router.Use( router.Use(
middlewareRecovery(log),
middlewareCustomWriterInjector(), middlewareCustomWriterInjector(),
mux.CORSMethodMiddleware(router), mux.CORSMethodMiddleware(router),
middlewareLogger(log), middlewareLogger(log),
@ -101,6 +103,29 @@ func setupHTTP(cfg config.HTTP, srv xhttp.Server, log *slog.Logger) *http.Server
} }
} }
func middlewareRecovery(log *slog.Logger) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
rec := recover()
if rec == nil {
return
}
if rec == http.ErrAbortHandler {
panic(rec)
}
xcontext.LogWithError(
r.Context(), log, fmt.Errorf("%v", rec), "recovered from panic",
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
)
http.Error(w, "internal server error", http.StatusInternalServerError)
}()
next.ServeHTTP(w, r)
})
}
}
func middlewareCustomWriterInjector() mux.MiddlewareFunc { func middlewareCustomWriterInjector() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -50,7 +50,7 @@ func setupOtelSDK(ctx context.Context, cfg config.Trace) (shutdown shutdownFunc,
return err return err
} }
resource, err := makeServiceResource(ctx) resource, err := makeServiceResource(ctx, cfg.Environment)
if err != nil { if err != nil {
return shutdown, fmt.Errorf("making service resource: %w", err) return shutdown, fmt.Errorf("making service resource: %w", err)
} }
@ -102,7 +102,10 @@ type TraceProviderParams struct {
Type config.TraceClientType Type config.TraceClientType
} }
func makeServiceResource(ctx context.Context) (*resource.Resource, error) { func makeServiceResource(ctx context.Context, environment string) (*resource.Resource, error) {
if environment == "" {
environment = "development"
}
r, err := resource.New( r, err := resource.New(
ctx, ctx,
resource.WithDetectors( resource.WithDetectors(
@ -113,7 +116,7 @@ func makeServiceResource(ctx context.Context) (*resource.Resource, error) {
resource.WithHost(), resource.WithHost(),
resource.WithAttributes( resource.WithAttributes(
semconv.ServiceName("bigstats:kuriweb"), semconv.ServiceName("bigstats:kuriweb"),
semconv.DeploymentEnvironment("production"), semconv.DeploymentEnvironment(environment),
), ),
) )
if err != nil { if err != nil {

View File

@ -33,6 +33,7 @@ type Trace struct {
APIKey string `json:"api_key"` APIKey string `json:"api_key"`
APIHeader string `json:"api_header"` APIHeader string `json:"api_header"`
Type TraceClientType `json:"type"` Type TraceClientType `json:"type"`
Environment string `json:"environment"`
ShowMetrics bool `json:"show_metrics"` ShowMetrics bool `json:"show_metrics"`
} }

View File

@ -5,7 +5,7 @@ import (
"testing" "testing"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks" mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"

View File

@ -0,0 +1,42 @@
package adapters
import (
"context"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/kurious/domain"
)
type NotImplementedOrganizationRepository struct{}
func (NotImplementedOrganizationRepository) ListStats(
context.Context,
domain.ListOrganizationsParams,
) ([]domain.OrganizationStat, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) List(context.Context, domain.ListOrganizationsParams) ([]domain.Organization, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Get(context.Context, domain.GetOrganizationParams) (domain.Organization, error) {
return domain.Organization{}, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Create(context.Context, domain.CreateOrganizationParams) (domain.Organization, error) {
return domain.Organization{}, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Delete(ctx context.Context, id string) error {
return cerrors.ErrNotImplemented
}
type NotImplementedLearningCategory struct{}
func (NotImplementedLearningCategory) Upsert(context.Context, domain.LearningCategory) error {
return cerrors.ErrNotImplemented
}
func (NotImplementedLearningCategory) List(context.Context) ([]domain.LearningCategory, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedLearningCategory) Get(context.Context, string) (domain.LearningCategory, error) {
return domain.LearningCategory{}, cerrors.ErrNotImplemented
}

View File

@ -10,7 +10,6 @@ import (
"time" "time"
"git.loyso.art/frx/kurious/internal/common/nullable" "git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/common/xcontext"
"git.loyso.art/frx/kurious/internal/common/xslices" "git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
@ -112,7 +111,7 @@ func (r *sqliteCourseRepository) List(
result.Count, err = r.listCount(ctx, params) result.Count, err = r.listCount(ctx, params)
if err != nil { if err != nil {
xcontext.LogWithWarnError(ctx, r.log, err, "unable to list count") return result, fmt.Errorf("listing count: %w", err)
} }
span.SetAttributes( span.SetAttributes(

View File

@ -8,6 +8,7 @@ import (
"log/slog" "log/slog"
"strings" "strings"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/xslices" "git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
@ -156,7 +157,7 @@ func (r *sqliteLearingCategoryRepository) Get(ctx context.Context, id string) (c
err = r.db.GetContext(ctx, &cdb, query, id) err = r.db.GetContext(ctx, &cdb, query, id)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return domain.LearningCategory{}, domain.ErrNotFound return domain.LearningCategory{}, cerrors.ErrNotFound
} }
return domain.LearningCategory{}, fmt.Errorf("executing query: %w", err) return domain.LearningCategory{}, fmt.Errorf("executing query: %w", err)
} }

View File

@ -3,6 +3,7 @@ package adapters
import ( import (
"testing" "testing"
"git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/nullable" "git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
@ -75,7 +76,7 @@ func (s *sqliteLearningCategoriesRepositorySuite) TestUpsert() {
const categoryID = "test-id-1" const categoryID = "test-id-1"
repo := s.connection.LearningCategory() repo := s.connection.LearningCategory()
gotCategory, err := repo.Get(s.ctx, categoryID) gotCategory, err := repo.Get(s.ctx, categoryID)
s.ErrorIs(err, domain.ErrNotFound) s.ErrorIs(err, errors.ErrNotFound)
s.Empty(gotCategory) s.Empty(gotCategory)
createdCategory := domain.LearningCategory{ createdCategory := domain.LearningCategory{

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/xslices" "git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
@ -228,7 +229,7 @@ func (r *sqliteOrganizationRepository) Get(ctx context.Context, params domain.Ge
err = r.db.GetContext(ctx, &orgdb, query, args...) err = r.db.GetContext(ctx, &orgdb, query, args...)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return out, domain.ErrNotFound return out, cerrors.ErrNotFound
} }
return out, fmt.Errorf("executing query: %w", err) return out, fmt.Errorf("executing query: %w", err)
} }
@ -306,14 +307,14 @@ func (r *sqliteOrganizationRepository) Delete(ctx context.Context, id string) (e
result, err := r.db.ExecContext(ctx, query, id) result, err := r.db.ExecContext(ctx, query, id)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return domain.ErrNotFound return cerrors.ErrNotFound
} }
return fmt.Errorf("executing query: %w", err) return fmt.Errorf("executing query: %w", err)
} }
affected, _ := result.RowsAffected() affected, _ := result.RowsAffected()
if affected == 0 { if affected == 0 {
return domain.ErrNotFound return cerrors.ErrNotFound
} }
return nil return nil

View File

@ -110,11 +110,11 @@ func (conn *YDBConnection) Close() error {
} }
func (conn *YDBConnection) Organization() domain.OrganizationRepository { func (conn *YDBConnection) Organization() domain.OrganizationRepository {
return domain.NotImplementedOrganizationRepository{} return NotImplementedOrganizationRepository{}
} }
func (conn *YDBConnection) LearningCategory() domain.LearningCategoryRepository { func (conn *YDBConnection) LearningCategory() domain.LearningCategoryRepository {
return domain.NotImplementedLearningCategory{} return NotImplementedLearningCategory{}
} }
func (conn *YDBConnection) CourseRepository() *ydbCourseRepository { func (conn *YDBConnection) CourseRepository() *ydbCourseRepository {

View File

@ -9,7 +9,7 @@ import (
"git.loyso.art/frx/kurious/internal/common/nullable" "git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks" mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
"git.loyso.art/frx/kurious/internal/kurious/domain" "git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks" mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"

View File

@ -1,12 +0,0 @@
package domain
const (
ErrNotFound PlainError = "not found"
ErrNotImplemented PlainError = "not implemented"
)
type PlainError string
func (err PlainError) Error() string {
return string(err)
}

View File

@ -80,7 +80,7 @@ type ListStatisticsResult struct {
LearningTypeStatistics []StatisticUnit LearningTypeStatistics []StatisticUnit
} }
//go:generate mockery --name CourseRepository //go:generate mockery --name CourseRepository --output ../adapters/mocks
type CourseRepository interface { type CourseRepository interface {
// List courses by specifid parameters. // List courses by specifid parameters.
List(context.Context, ListCoursesParams) (ListCoursesResult, error) List(context.Context, ListCoursesParams) (ListCoursesResult, error)
@ -126,7 +126,7 @@ type ListOrganizationsParams struct {
IDs []string IDs []string
} }
//go:generate mockery --name OrganizationRepository //go:generate mockery --name OrganizationRepository --output ../adapters/mocks
type OrganizationRepository interface { type OrganizationRepository interface {
ListStats(context.Context, ListOrganizationsParams) ([]OrganizationStat, error) ListStats(context.Context, ListOrganizationsParams) ([]OrganizationStat, error)
List(context.Context, ListOrganizationsParams) ([]Organization, error) List(context.Context, ListOrganizationsParams) ([]Organization, error)
@ -135,44 +135,10 @@ type OrganizationRepository interface {
Delete(ctx context.Context, id string) error Delete(ctx context.Context, id string) error
} }
type NotImplementedOrganizationRepository struct{} //go:generate mockery --name LearningCategoryRepository --output ../adapters/mocks
func (NotImplementedOrganizationRepository) ListStats(
context.Context,
ListOrganizationsParams,
) ([]OrganizationStat, error) {
return nil, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) List(context.Context, ListOrganizationsParams) ([]Organization, error) {
return nil, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Get(context.Context, GetOrganizationParams) (Organization, error) {
return Organization{}, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Create(context.Context, CreateOrganizationParams) (Organization, error) {
return Organization{}, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Delete(ctx context.Context, id string) error {
return ErrNotImplemented
}
//go:generate mockery --name LearningCategoryRepository
type LearningCategoryRepository interface { type LearningCategoryRepository interface {
Upsert(context.Context, LearningCategory) error Upsert(context.Context, LearningCategory) error
List(context.Context) ([]LearningCategory, error) List(context.Context) ([]LearningCategory, error)
Get(context.Context, string) (LearningCategory, error) Get(context.Context, string) (LearningCategory, error)
} }
type NotImplementedLearningCategory struct{}
func (NotImplementedLearningCategory) Upsert(context.Context, LearningCategory) error {
return ErrNotImplemented
}
func (NotImplementedLearningCategory) List(context.Context) ([]LearningCategory, error) {
return nil, ErrNotImplemented
}
func (NotImplementedLearningCategory) Get(context.Context, string) (LearningCategory, error) {
return LearningCategory{}, ErrNotImplemented
}

View File

@ -85,6 +85,7 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
courses := make([]sravni.Course, 0, 1024) courses := make([]sravni.Course, 0, 1024)
buffer := make([]sravni.Course, 0, 512) buffer := make([]sravni.Course, 0, 512)
organizations := make([]sravni.Organization, 0, 256) organizations := make([]sravni.Organization, 0, 256)
var insertErr error
for _, learningType := range learningTypes.Fields { for _, learningType := range learningTypes.Fields {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -174,22 +175,22 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
var insertCourseSuccess bool var insertCourseSuccess bool
if len(courses) > 0 { if len(courses) > 0 {
err = h.insertCourses(lctx, courses) if cerr := h.insertCourses(lctx, courses); cerr != nil {
if err != nil { xcontext.LogWithError(lctx, h.log, cerr, "unable to insert courses")
xcontext.LogWithError(lctx, h.log, err, "unable to insert courses") insertErr = errors.Join(insertErr, cerr)
} else {
insertCourseSuccess = true
} }
insertCourseSuccess = err == nil
} }
var insertOrgsSuccess bool var insertOrgsSuccess bool
if len(organizations) > 0 { if len(organizations) > 0 {
err = h.insertOrganizations(lctx, organizations) if oerr := h.insertOrganizations(lctx, organizations); oerr != nil {
if err != nil { xcontext.LogWithError(lctx, h.log, oerr, "unable to insert organizations")
xcontext.LogWithError(lctx, h.log, err, "unable to insert courses") insertErr = errors.Join(insertErr, oerr)
} else {
insertOrgsSuccess = true
} }
insertOrgsSuccess = err == nil
} }
elapsed = time.Since(start) - elapsed elapsed = time.Since(start) - elapsed
@ -205,7 +206,7 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
) )
} }
return nil return insertErr
} }
func (h *syncSravniHandler) loadEducationalProducts(ctx context.Context, learningType, courseThematic string, buf []sravni.Course) ([]sravni.Course, map[string]sravni.Organization, error) { func (h *syncSravniHandler) loadEducationalProducts(ctx context.Context, learningType, courseThematic string, buf []sravni.Course) ([]sravni.Course, map[string]sravni.Organization, error) {
@ -384,8 +385,12 @@ func courseAsCreateCourseParams(course sravni.Course) command.CreateCourse {
startAt = *course.DateStart startAt = *course.DateStart
} }
if course.TimeStart != nil { if course.TimeStart != nil {
startAtUnix := startAt.Unix() + course.TimeStart.Unix() clock := *course.TimeStart
startAt = time.Unix(startAtUnix, 0) startAt = time.Date(
startAt.Year(), startAt.Month(), startAt.Day(),
clock.Hour(), clock.Minute(), clock.Second(), clock.Nanosecond(),
startAt.Location(),
)
} }
var courseDuration time.Duration var courseDuration time.Duration

View File

@ -23,7 +23,7 @@ import (
var ( var (
paramsAttr = attribute.Key("params") paramsAttr = attribute.Key("params")
webtracer = otel.Tracer("http") webtracer = otel.Tracer("kuriweb.http")
) )
type courseTemplServer struct { type courseTemplServer struct {
@ -213,6 +213,15 @@ func (c courseTemplServer) List(w http.ResponseWriter, r *http.Request) {
} }
}) })
totalPages := 0
if pathParams.PerPage > 0 {
totalPages = listCoursesResult.Count / pathParams.PerPage
}
currentPage := pathParams.Page
if currentPage > 0 && totalPages > 0 && currentPage > totalPages {
currentPage = totalPages
}
params = bootstrap.ListCoursesParams{ params = bootstrap.ListCoursesParams{
FilterForm: bootstrap.FilterFormParams{ FilterForm: bootstrap.FilterFormParams{
Render: true, Render: true,
@ -233,8 +242,8 @@ func (c courseTemplServer) List(w http.ResponseWriter, r *http.Request) {
Courses: params.Courses, Courses: params.Courses,
Categories: params.Categories, Categories: params.Categories,
Pagination: bootstrap.Pagination{ Pagination: bootstrap.Pagination{
Page: pathParams.Page, Page: currentPage,
TotalPages: listCoursesResult.Count / pathParams.PerPage, TotalPages: totalPages,
BaseURL: r.URL.Path, BaseURL: r.URL.Path,
}, },
} }
@ -285,7 +294,10 @@ func (c courseTemplServer) Index(w http.ResponseWriter, r *http.Request) {
stats := bootstrap.MakeNewStats(1, 2, 3) stats := bootstrap.MakeNewStats(1, 2, 3)
coursesResult, err := c.app.Queries.ListCourses.Handle(ctx, query.ListCourse{}) const indexCoursesLimit = 200
coursesResult, err := c.app.Queries.ListCourses.Handle(ctx, query.ListCourse{
Limit: indexCoursesLimit,
})
if handleError(ctx, err, w, c.log, "unable to list courses") { if handleError(ctx, err, w, c.log, "unable to list courses") {
return return
} }

View File

@ -87,14 +87,22 @@ func parsePaginationFromQuery(r *http.Request) (out pagination, err error) {
} else { } else {
out.PerPage = 20 out.PerPage = 20
} }
if out.PerPage < 1 {
out.PerPage = 1
} else if out.PerPage > 100 {
out.PerPage = 100
}
if query.Has("page") { if query.Has("page") {
out.Page, err = strconv.Atoi(query.Get("page")) out.Page, err = strconv.Atoi(query.Get("page"))
if err != nil { if err != nil {
return out, errors.NewValidationError("page", "bad per_page value") return out, errors.NewValidationError("page", "bad page value")
} }
} else if !query.Has("next") { } else if !query.Has("next") {
out.Page = 1 out.Page = 1
} }
if out.Page < 1 && !query.Has("next") {
out.Page = 1
}
return out, nil return out, nil
} }