fix: critical bugs from code review (data corruption, error contract, HTTP hardening) (#5)

Co-authored-by: hermes <hermes@noreply.localhost>
Co-committed-by: hermes <hermes@noreply.localhost>
This commit is contained in:
2026-06-28 14:04:25 +00:00
committed by Aleksandr Trushkin
parent 84656c6c56
commit 5c529ef060
20 changed files with 138 additions and 86 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"fmt"
"log/slog"
"net/http"
"strings"
@ -54,6 +55,7 @@ func setupHTTP(cfg config.HTTP, srv xhttp.Server, log *slog.Logger) *http.Server
router := mux.NewRouter()
router.Use(
middlewareRecovery(log),
middlewareCustomWriterInjector(),
mux.CORSMethodMiddleware(router),
middlewareLogger(log),
@ -101,6 +103,29 @@ func setupHTTP(cfg config.HTTP, srv xhttp.Server, log *slog.Logger) *http.Server
}
}
func middlewareRecovery(log *slog.Logger) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
rec := recover()
if rec == nil {
return
}
if rec == http.ErrAbortHandler {
panic(rec)
}
xcontext.LogWithError(
r.Context(), log, fmt.Errorf("%v", rec), "recovered from panic",
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
)
http.Error(w, "internal server error", http.StatusInternalServerError)
}()
next.ServeHTTP(w, r)
})
}
}
func middlewareCustomWriterInjector() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -50,7 +50,7 @@ func setupOtelSDK(ctx context.Context, cfg config.Trace) (shutdown shutdownFunc,
return err
}
resource, err := makeServiceResource(ctx)
resource, err := makeServiceResource(ctx, cfg.Environment)
if err != nil {
return shutdown, fmt.Errorf("making service resource: %w", err)
}
@ -102,7 +102,10 @@ type TraceProviderParams struct {
Type config.TraceClientType
}
func makeServiceResource(ctx context.Context) (*resource.Resource, error) {
func makeServiceResource(ctx context.Context, environment string) (*resource.Resource, error) {
if environment == "" {
environment = "development"
}
r, err := resource.New(
ctx,
resource.WithDetectors(
@ -113,7 +116,7 @@ func makeServiceResource(ctx context.Context) (*resource.Resource, error) {
resource.WithHost(),
resource.WithAttributes(
semconv.ServiceName("bigstats:kuriweb"),
semconv.DeploymentEnvironment("production"),
semconv.DeploymentEnvironment(environment),
),
)
if err != nil {

View File

@ -29,10 +29,11 @@ func (t *TraceClientType) UnmarshalText(data []byte) error {
}
type Trace struct {
Endpoint string `json:"endpoint"`
APIKey string `json:"api_key"`
APIHeader string `json:"api_header"`
Type TraceClientType `json:"type"`
Endpoint string `json:"endpoint"`
APIKey string `json:"api_key"`
APIHeader string `json:"api_header"`
Type TraceClientType `json:"type"`
Environment string `json:"environment"`
ShowMetrics bool `json:"show_metrics"`
}

View File

@ -5,7 +5,7 @@ import (
"testing"
"git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

View File

@ -0,0 +1,42 @@
package adapters
import (
"context"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/kurious/domain"
)
type NotImplementedOrganizationRepository struct{}
func (NotImplementedOrganizationRepository) ListStats(
context.Context,
domain.ListOrganizationsParams,
) ([]domain.OrganizationStat, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) List(context.Context, domain.ListOrganizationsParams) ([]domain.Organization, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Get(context.Context, domain.GetOrganizationParams) (domain.Organization, error) {
return domain.Organization{}, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Create(context.Context, domain.CreateOrganizationParams) (domain.Organization, error) {
return domain.Organization{}, cerrors.ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Delete(ctx context.Context, id string) error {
return cerrors.ErrNotImplemented
}
type NotImplementedLearningCategory struct{}
func (NotImplementedLearningCategory) Upsert(context.Context, domain.LearningCategory) error {
return cerrors.ErrNotImplemented
}
func (NotImplementedLearningCategory) List(context.Context) ([]domain.LearningCategory, error) {
return nil, cerrors.ErrNotImplemented
}
func (NotImplementedLearningCategory) Get(context.Context, string) (domain.LearningCategory, error) {
return domain.LearningCategory{}, cerrors.ErrNotImplemented
}

View File

@ -10,7 +10,6 @@ import (
"time"
"git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/common/xcontext"
"git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain"
@ -112,7 +111,7 @@ func (r *sqliteCourseRepository) List(
result.Count, err = r.listCount(ctx, params)
if err != nil {
xcontext.LogWithWarnError(ctx, r.log, err, "unable to list count")
return result, fmt.Errorf("listing count: %w", err)
}
span.SetAttributes(

View File

@ -8,6 +8,7 @@ import (
"log/slog"
"strings"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain"
@ -156,7 +157,7 @@ func (r *sqliteLearingCategoryRepository) Get(ctx context.Context, id string) (c
err = r.db.GetContext(ctx, &cdb, query, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return domain.LearningCategory{}, domain.ErrNotFound
return domain.LearningCategory{}, cerrors.ErrNotFound
}
return domain.LearningCategory{}, fmt.Errorf("executing query: %w", err)
}

View File

@ -3,6 +3,7 @@ package adapters
import (
"testing"
"git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/kurious/domain"
@ -75,7 +76,7 @@ func (s *sqliteLearningCategoriesRepositorySuite) TestUpsert() {
const categoryID = "test-id-1"
repo := s.connection.LearningCategory()
gotCategory, err := repo.Get(s.ctx, categoryID)
s.ErrorIs(err, domain.ErrNotFound)
s.ErrorIs(err, errors.ErrNotFound)
s.Empty(gotCategory)
createdCategory := domain.LearningCategory{

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
cerrors "git.loyso.art/frx/kurious/internal/common/errors"
"git.loyso.art/frx/kurious/internal/common/xslices"
"git.loyso.art/frx/kurious/internal/kurious/domain"
"go.opentelemetry.io/otel/attribute"
@ -228,7 +229,7 @@ func (r *sqliteOrganizationRepository) Get(ctx context.Context, params domain.Ge
err = r.db.GetContext(ctx, &orgdb, query, args...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return out, domain.ErrNotFound
return out, cerrors.ErrNotFound
}
return out, fmt.Errorf("executing query: %w", err)
}
@ -306,14 +307,14 @@ func (r *sqliteOrganizationRepository) Delete(ctx context.Context, id string) (e
result, err := r.db.ExecContext(ctx, query, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return domain.ErrNotFound
return cerrors.ErrNotFound
}
return fmt.Errorf("executing query: %w", err)
}
affected, _ := result.RowsAffected()
if affected == 0 {
return domain.ErrNotFound
return cerrors.ErrNotFound
}
return nil

View File

@ -110,11 +110,11 @@ func (conn *YDBConnection) Close() error {
}
func (conn *YDBConnection) Organization() domain.OrganizationRepository {
return domain.NotImplementedOrganizationRepository{}
return NotImplementedOrganizationRepository{}
}
func (conn *YDBConnection) LearningCategory() domain.LearningCategoryRepository {
return domain.NotImplementedLearningCategory{}
return NotImplementedLearningCategory{}
}
func (conn *YDBConnection) CourseRepository() *ydbCourseRepository {

View File

@ -9,7 +9,7 @@ import (
"git.loyso.art/frx/kurious/internal/common/nullable"
"git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

View File

@ -8,7 +8,7 @@ import (
"testing"
"git.loyso.art/frx/kurious/internal/kurious/domain"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/domain/mocks"
mockrepo "git.loyso.art/frx/kurious/internal/kurious/adapters/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

View File

@ -1,12 +0,0 @@
package domain
const (
ErrNotFound PlainError = "not found"
ErrNotImplemented PlainError = "not implemented"
)
type PlainError string
func (err PlainError) Error() string {
return string(err)
}

View File

@ -80,7 +80,7 @@ type ListStatisticsResult struct {
LearningTypeStatistics []StatisticUnit
}
//go:generate mockery --name CourseRepository
//go:generate mockery --name CourseRepository --output ../adapters/mocks
type CourseRepository interface {
// List courses by specifid parameters.
List(context.Context, ListCoursesParams) (ListCoursesResult, error)
@ -126,7 +126,7 @@ type ListOrganizationsParams struct {
IDs []string
}
//go:generate mockery --name OrganizationRepository
//go:generate mockery --name OrganizationRepository --output ../adapters/mocks
type OrganizationRepository interface {
ListStats(context.Context, ListOrganizationsParams) ([]OrganizationStat, error)
List(context.Context, ListOrganizationsParams) ([]Organization, error)
@ -135,44 +135,10 @@ type OrganizationRepository interface {
Delete(ctx context.Context, id string) error
}
type NotImplementedOrganizationRepository struct{}
func (NotImplementedOrganizationRepository) ListStats(
context.Context,
ListOrganizationsParams,
) ([]OrganizationStat, error) {
return nil, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) List(context.Context, ListOrganizationsParams) ([]Organization, error) {
return nil, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Get(context.Context, GetOrganizationParams) (Organization, error) {
return Organization{}, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Create(context.Context, CreateOrganizationParams) (Organization, error) {
return Organization{}, ErrNotImplemented
}
func (NotImplementedOrganizationRepository) Delete(ctx context.Context, id string) error {
return ErrNotImplemented
}
//go:generate mockery --name LearningCategoryRepository
//go:generate mockery --name LearningCategoryRepository --output ../adapters/mocks
type LearningCategoryRepository interface {
Upsert(context.Context, LearningCategory) error
List(context.Context) ([]LearningCategory, error)
Get(context.Context, string) (LearningCategory, error)
}
type NotImplementedLearningCategory struct{}
func (NotImplementedLearningCategory) Upsert(context.Context, LearningCategory) error {
return ErrNotImplemented
}
func (NotImplementedLearningCategory) List(context.Context) ([]LearningCategory, error) {
return nil, ErrNotImplemented
}
func (NotImplementedLearningCategory) Get(context.Context, string) (LearningCategory, error) {
return LearningCategory{}, ErrNotImplemented
}

View File

@ -85,6 +85,7 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
courses := make([]sravni.Course, 0, 1024)
buffer := make([]sravni.Course, 0, 512)
organizations := make([]sravni.Organization, 0, 256)
var insertErr error
for _, learningType := range learningTypes.Fields {
select {
case <-ctx.Done():
@ -174,22 +175,22 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
var insertCourseSuccess bool
if len(courses) > 0 {
err = h.insertCourses(lctx, courses)
if err != nil {
xcontext.LogWithError(lctx, h.log, err, "unable to insert courses")
if cerr := h.insertCourses(lctx, courses); cerr != nil {
xcontext.LogWithError(lctx, h.log, cerr, "unable to insert courses")
insertErr = errors.Join(insertErr, cerr)
} else {
insertCourseSuccess = true
}
insertCourseSuccess = err == nil
}
var insertOrgsSuccess bool
if len(organizations) > 0 {
err = h.insertOrganizations(lctx, organizations)
if err != nil {
xcontext.LogWithError(lctx, h.log, err, "unable to insert courses")
if oerr := h.insertOrganizations(lctx, organizations); oerr != nil {
xcontext.LogWithError(lctx, h.log, oerr, "unable to insert organizations")
insertErr = errors.Join(insertErr, oerr)
} else {
insertOrgsSuccess = true
}
insertOrgsSuccess = err == nil
}
elapsed = time.Since(start) - elapsed
@ -205,7 +206,7 @@ func (h *syncSravniHandler) Handle(ctx context.Context) (err error) {
)
}
return nil
return insertErr
}
func (h *syncSravniHandler) loadEducationalProducts(ctx context.Context, learningType, courseThematic string, buf []sravni.Course) ([]sravni.Course, map[string]sravni.Organization, error) {
@ -384,8 +385,12 @@ func courseAsCreateCourseParams(course sravni.Course) command.CreateCourse {
startAt = *course.DateStart
}
if course.TimeStart != nil {
startAtUnix := startAt.Unix() + course.TimeStart.Unix()
startAt = time.Unix(startAtUnix, 0)
clock := *course.TimeStart
startAt = time.Date(
startAt.Year(), startAt.Month(), startAt.Day(),
clock.Hour(), clock.Minute(), clock.Second(), clock.Nanosecond(),
startAt.Location(),
)
}
var courseDuration time.Duration

View File

@ -23,7 +23,7 @@ import (
var (
paramsAttr = attribute.Key("params")
webtracer = otel.Tracer("http")
webtracer = otel.Tracer("kuriweb.http")
)
type courseTemplServer struct {
@ -213,6 +213,15 @@ func (c courseTemplServer) List(w http.ResponseWriter, r *http.Request) {
}
})
totalPages := 0
if pathParams.PerPage > 0 {
totalPages = listCoursesResult.Count / pathParams.PerPage
}
currentPage := pathParams.Page
if currentPage > 0 && totalPages > 0 && currentPage > totalPages {
currentPage = totalPages
}
params = bootstrap.ListCoursesParams{
FilterForm: bootstrap.FilterFormParams{
Render: true,
@ -233,8 +242,8 @@ func (c courseTemplServer) List(w http.ResponseWriter, r *http.Request) {
Courses: params.Courses,
Categories: params.Categories,
Pagination: bootstrap.Pagination{
Page: pathParams.Page,
TotalPages: listCoursesResult.Count / pathParams.PerPage,
Page: currentPage,
TotalPages: totalPages,
BaseURL: r.URL.Path,
},
}
@ -285,7 +294,10 @@ func (c courseTemplServer) Index(w http.ResponseWriter, r *http.Request) {
stats := bootstrap.MakeNewStats(1, 2, 3)
coursesResult, err := c.app.Queries.ListCourses.Handle(ctx, query.ListCourse{})
const indexCoursesLimit = 200
coursesResult, err := c.app.Queries.ListCourses.Handle(ctx, query.ListCourse{
Limit: indexCoursesLimit,
})
if handleError(ctx, err, w, c.log, "unable to list courses") {
return
}

View File

@ -87,14 +87,22 @@ func parsePaginationFromQuery(r *http.Request) (out pagination, err error) {
} else {
out.PerPage = 20
}
if out.PerPage < 1 {
out.PerPage = 1
} else if out.PerPage > 100 {
out.PerPage = 100
}
if query.Has("page") {
out.Page, err = strconv.Atoi(query.Get("page"))
if err != nil {
return out, errors.NewValidationError("page", "bad per_page value")
return out, errors.NewValidationError("page", "bad page value")
}
} else if !query.Has("next") {
out.Page = 1
}
if out.Page < 1 && !query.Has("next") {
out.Page = 1
}
return out, nil
}