I am trying to figure out a good way to manage database transactions in Go and to use the same transaction across different services.
Let's say I am building a forum, and this forum has posts and comments.
I have a comments_count column in my posts table, which tracks the number of comments for each post.
When I create a comment for a given post, I also need to update the posts table and increment the comments_count column of that post.
My project is structured in a couple of layers: database / business / web.
Currently my code looks like this:
main.go
package main

import (
	"context"
	"log"
	"net/http"

	"github.com/jackc/pgx/v5/pgxpool"

	"vkosev/stack/db/repository"
	"vkosev/stack/services"
	"vkosev/stack/web"
)

func main() {
	dbConString := "postgres://user:password@host:port/database"
	dbPool, err := pgxpool.New(context.Background(), dbConString)
	if err != nil {
		log.Fatal(err)
	}

	postRepo := repository.NewPostRepository(dbPool)
	commentRepo := repository.NewCommentRepository(dbPool)
	postService := services.NewPostService(postRepo)
	commentService := services.NewCommentService(commentRepo)
	handler := web.NewHandler(postService, commentService)

	mux := http.NewServeMux()
	mux.HandleFunc("POST /comments/{postId}", handler.CreateComment)
	log.Fatal(http.ListenAndServe(":8080", mux))
}
web.go
package web

import (
	"net/http"

	"vkosev/stack/services"
)

type Handler struct {
	postService    *services.PostService
	commentService *services.CommentService
}

func NewHandler(postService *services.PostService, commentService *services.CommentService) *Handler {
	return &Handler{
		postService:    postService,
		commentService: commentService,
	}
}

func (h *Handler) CreateComment(w http.ResponseWriter, r *http.Request) {
	postId := getPostIdeFromRequest(r)
	comment := getCommentFromRequest(r)
	newComment := h.commentService.Create(comment, postId)
	err := h.postService.IncreaseCount(postId)
	if err != nil {
		// write some error message
		return
	}
	writeJSON(w, http.StatusOK, newComment)
}
services.go
package services

import (
	"vkosev/stack/db/repository"
	"vkosev/stack/models" // assumed import path; the question does not show where models lives
)

type PostService struct {
	postRepo *repository.PostRepository
}

func NewPostService(postRepo *repository.PostRepository) *PostService {
	return &PostService{postRepo: postRepo}
}

func (ps *PostService) IncreaseCount(postId int) error {
	return ps.postRepo.IncreaseCount(postId)
}

type CommentService struct {
	commentRepo *repository.CommentRepository
}

func NewCommentService(commentRepo *repository.CommentRepository) *CommentService {
	return &CommentService{commentRepo: commentRepo}
}

func (cs *CommentService) Create(comment models.Comment, postId int) *models.Comment {
	return cs.commentRepo.Save(comment, postId)
}
repository.go
package repository

import (
	"context"

	"github.com/jackc/pgx/v5/pgxpool"

	"vkosev/stack/models" // assumed import path for the models package
)

type PostRepository struct {
	pool *pgxpool.Pool
}

func NewPostRepository(pool *pgxpool.Pool) *PostRepository {
	return &PostRepository{pool: pool}
}

func (pr *PostRepository) IncreaseCount(postId int) error {
	// Assumes the primary key column of the posts table is named id.
	_, err := pr.pool.Exec(context.Background(),
		"UPDATE posts SET comments_count = comments_count + 1 WHERE id = $1", postId)
	return err
}

type CommentRepository struct {
	pool *pgxpool.Pool
}

func NewCommentRepository(pool *pgxpool.Pool) *CommentRepository {
	return &CommentRepository{pool: pool}
}

func (cr *CommentRepository) Save(comment models.Comment, postId int) *models.Comment {
	// call cr.pool and insert comment into the DB
}
I initialize all the necessary dependencies in main.go, inject them where they are needed, and then use handlers to handle every route.
Now I need a transaction, so that if, for some reason, I fail to update the comments count of a post, the creation of the comment is rolled back.
I guess the easiest way is to just pass a Tx into the methods, but it seems ugly.
I was hoping for some way to abstract the database logic, so that the repositories do not care whether they are using a transaction or not, and also to manage the transaction in the handler methods.
That way I could have something like this:
func (h *Handler) CreateComment(w http.ResponseWriter, r *http.Request) {
	postId := getPostIdeFromRequest(r)
	comment := getCommentFromRequest(r)
	// Begin transaction
	newComment := h.commentService.Create(comment, postId)
	err := h.postService.IncreaseCount(postId)
	if err != nil {
		// rollback the transaction
		// write some error message
		return
	}
	// commit the transaction
	writeJSON(w, http.StatusOK, newComment)
}
Your approach of splitting services and repositories is a very good start. The following worked great for me:
1. Use the context API. Make all the methods in your services and repositories accept a context.Context as their first parameter.
2. Create a Session interface that would look like this:
type Session interface {
	Begin(ctx context.Context) (Session, error)
	Transaction(ctx context.Context, f func(context.Context) error) error
	Rollback() error
	Commit() error
	Context() context.Context
}
3. Create an implementation of Session that wraps pgxpool. Begin and Transaction should inject a DB instance (a pgx.Tx) into the context.Context.
4. In your repositories, get the DB from the context like so:
// Use dbKey as the context value key
type dbKey struct{}

// DBTX is satisfied by both *pgxpool.Pool and pgx.Tx (pgconn is github.com/jackc/pgx/v5/pgconn).
type DBTX interface {
	Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}

// DB returns the transaction stored in ctx, or the pool itself when there is none.
func DB(ctx context.Context, fallback *pgxpool.Pool) DBTX {
	if tx, ok := ctx.Value(dbKey{}).(pgx.Tx); ok {
		return tx
	}
	return fallback
}
Call it with repo.pool as the fallback. This way, the repository doesn't know (and doesn't have to know) whether the operation is inside a transaction or not. The services can call multiple repositories and methods without worrying at all about the underlying mechanism. It also helps with writing tests for your services without depending on your repositories or the DB. This should handle nested transactions fine too.
5. Add a Session parameter to the constructor of the services that need it.
6. Use the Session when you need to execute multiple repository operations (that is to say, a business transaction). For single operations, you can simply call the repository right away and it will use the database without a Tx thanks to the fallback.
I used this approach with Gorm, not with pgx directly, so it may not fit as well: with pgx you deal with two types (pgx.Tx and *pgxpool.Pool) instead of a single *gorm.DB type. I think it should work fine though, and I hope this helps you going forward. Good luck!
Here is a more complete example based on the implementation I am using in my projects (using Gorm). In this example we have a user and a user action history. We want to create a "register" action record alongside the user record when the user registers.
I know you are not using Gorm, but the logic stays the same; you will only have to implement your Session and repositories differently.
package session

import (
	"context"
	"database/sql"

	"gorm.io/gorm"
)

// Session aims at facilitating business transactions while abstracting the underlying mechanism,
// be it a database transaction or another transaction mechanism. This allows services to execute
// multiple business use-cases and easily roll back changes in case of error, without creating a
// dependency on the database layer.
//
// Sessions should consist of a root session created with a "New"-type constructor and allow
// the creation of child sessions with `Begin()` and `Transaction()`. Nested transactions should be supported
// as well.
type Session interface {
	// Begin returns a new session with the given context and a started transaction.
	// Using the returned session should have no side-effect on the parent session.
	// The underlying transaction mechanism is injected as a value into the new session's context.
	Begin(ctx context.Context) (Session, error)

	// Transaction executes a transaction. If the given function returns an error, the transaction
	// is rolled back. Otherwise it is automatically committed before `Transaction()` returns.
	// The underlying transaction mechanism is injected into the context as a value.
	Transaction(ctx context.Context, f func(context.Context) error) error

	// Rollback the changes in the transaction. This action is final.
	Rollback() error

	// Commit the changes in the transaction. This action is final.
	Commit() error

	// Context returns the session's context. If it's the root session, `context.Background()` is returned.
	// If it's a child session started with `Begin()`, then the context will contain the associated
	// transaction mechanism as a value.
	Context() context.Context
}
// Gorm session implementation.
type Gorm struct {
	db        *gorm.DB
	TxOptions *sql.TxOptions
	ctx       context.Context
}

// GORM creates a new root session for Gorm.
// The transaction options are optional.
func GORM(db *gorm.DB, opt *sql.TxOptions) Gorm {
	return Gorm{
		db:        db,
		TxOptions: opt,
		ctx:       context.Background(),
	}
}
// Begin returns a new session with the given context and a started DB transaction.
// The returned session has manual controls. Make sure a call to `Rollback()` or `Commit()`
// is executed before the session expires (becomes eligible for garbage collection).
// The Gorm DB associated with this session is injected as a value into the new session's context.
// If a Gorm DB is found in the given context, it is used instead of this Session's DB.
// Note that Gorm's Begin() reports gorm.ErrInvalidTransaction when that DB is already a
// transaction, so nesting through `Begin()` is not supported; Gorm's own Transaction()
// uses savepoints for nested transactions instead.
func (s Gorm) Begin(ctx context.Context) (Session, error) {
	tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions)
	if tx.Error != nil {
		return nil, tx.Error
	}
	return Gorm{
		ctx:       context.WithValue(ctx, dbKey{}, tx),
		TxOptions: s.TxOptions,
		db:        tx,
	}, nil
}
// Rollback the changes in the transaction. This action is final.
func (s Gorm) Rollback() error {
	return s.db.Rollback().Error
}

// Commit the changes in the transaction. This action is final.
func (s Gorm) Commit() error {
	return s.db.Commit().Error
}

// Context returns the session's context. If it's the root session, `context.Background()`
// is returned. If it's a child session started with `Begin()`, then the context will contain
// the associated Gorm DB and can be used in combination with `session.DB()`.
func (s Gorm) Context() context.Context {
	return s.ctx
}

// dbKey is the key used to store the database in the context.
type dbKey struct{}
// Transaction executes a transaction. If the given function returns an error (or panics), the
// transaction is rolled back. Otherwise it is automatically committed before `Transaction()` returns.
//
// The Gorm DB associated with this session is injected into the context as a value so `session.DB()`
// can be used to retrieve it. This delegates to Gorm's own Transaction(), which uses a savepoint
// when the DB found in the context is already a transaction, so nested calls are supported.
func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) error {
	return DB(ctx, s.db).WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return f(context.WithValue(ctx, dbKey{}, tx))
	}, s.TxOptions)
}
// DB returns the Gorm instance stored in the given context. Returns the given fallback
// if no Gorm DB could be found in the context.
func DB(ctx context.Context, fallback *gorm.DB) *gorm.DB {
	db := ctx.Value(dbKey{})
	if db == nil {
		return fallback
	}
	return db.(*gorm.DB)
}
package user

import (
	"context"

	"example-module/database/model"
	"example-module/dto"
	"example-module/session"
	"example-module/typeutil"
)

type Repository interface {
	Create(ctx context.Context, user *model.User) (*model.User, error)
	CreateHistory(ctx context.Context, history *model.History) (*model.History, error)
}

type Service struct {
	session    session.Session
	repository Repository
}

func NewService(session session.Session, repository Repository) *Service {
	return &Service{
		session:    session,
		repository: repository,
	}
}

// Register creates a new user with an associated "register" history.
func (s *Service) Register(ctx context.Context, user *dto.RegisterUser) (*dto.User, error) {
	// Model mapping from DTO to model (using the `copier` library)
	u := typeutil.Copy(&model.User{}, user)
	err := s.session.Transaction(ctx, func(ctx context.Context) error {
		// You can also call another service from here, not necessarily a repository.
		var err error
		u, err = s.repository.Create(ctx, u)
		if err != nil {
			return err
		}
		history := &model.History{
			UserID: u.ID,
			Action: "register",
		}
		_, err = s.repository.CreateHistory(ctx, history)
		return err
	})
	if err != nil {
		return nil, err
	}
	// Convert back to a DTO user using json marshal/unmarshal
	return typeutil.MustConvert[*dto.User](u), nil
}
package repository

import (
	"context"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"example-module/database/model"
	"example-module/session"
)

type User struct {
	DB *gorm.DB
}

func NewUser(db *gorm.DB) *User {
	return &User{
		DB: db,
	}
}

func (r *User) Create(ctx context.Context, user *model.User) (*model.User, error) {
	// user is already a pointer, so it is passed as-is (not &user).
	db := session.DB(ctx, r.DB).WithContext(ctx).Omit(clause.Associations).Create(user)
	return user, db.Error
}

func (r *User) CreateHistory(ctx context.Context, history *model.History) (*model.History, error) {
	db := session.DB(ctx, r.DB).WithContext(ctx).Omit(clause.Associations).Create(history)
	return history, db.Error
}
sess := session.GORM(myDB, nil)
userRepository := repository.NewUser(myDB)
userService := user.NewService(sess, userRepository)