Search code examples
Tags: sql, database, go, pgx

Clean way to manage database transactions in golang


I am trying to find a good way to manage database transactions in Go and to share the same transaction between different services.

Let's say I am building a forum, and this forum has posts and comments.

I have comments_count column on my posts table in the database, which tracks the number of comments for the post.

When I create a comment for a given post, I also need to update the posts table and increase the comments_count column of the post.

My project is structured in a few layers: database / business / web.

Currently my code looks like this:

main.go

package main

import (
    "context"
    "fmt"
    "log"
    "net/http"

    "github.com/jackc/pgx/v5/pgxpool"

    "vkosev/stack/db/repository"
    "vkosev/stack/services"
    "vkosev/stack/web"
)

// main wires the dependencies together and serves HTTP on :8080. Any setup or
// server failure is logged and the process exits with a non-zero status.
func main() {
    if err := run(); err != nil {
        log.Fatal(err)
    }
}

// run builds the pool, repositories, services and routes, then blocks serving
// HTTP. It returns the error instead of exiting, so the deferred pool Close
// still runs before main reports the failure.
func run() error {
    ctx := context.Background()
    dbConString := "postgres://user:password@host:port/database"

    // Ignoring this error would leave dbPool nil and defer the crash to the
    // first query; fail at startup instead.
    dbPool, err := pgxpool.New(ctx, dbConString)
    if err != nil {
        return fmt.Errorf("creating database pool: %w", err)
    }
    defer dbPool.Close()

    postRepo := repository.NewPostRepository(dbPool)
    commentRepo := repository.NewCommentRepository(dbPool)

    postService := services.NewPostService(postRepo)
    commentService := services.NewCommentService(commentRepo)

    handler := web.NewHandler(postService, commentService)

    mux := http.NewServeMux()
    mux.HandleFunc("POST /comments/{postId}", handler.CreateComment)

    // ListenAndServe only returns when the server stops, always with a non-nil error.
    return fmt.Errorf("serving HTTP: %w", http.ListenAndServe(":8080", mux))
}

web.go

package web

// Handler serves the forum's HTTP endpoints by delegating to the post and
// comment services.
type Handler struct {
    postService    *services.PostService
    commentService *services.CommentService
}

// NewHandler returns a Handler that delegates to the given services.
func NewHandler(postService *services.PostService, commentService *services.CommentService) *Handler {
    h := new(Handler)
    h.postService = postService
    h.commentService = commentService
    return h
}

// CreateComment handles POST /comments/{postId}. It saves the comment from the
// request and then increments the post's comments count, responding with the
// new comment as JSON on success.
//
// NOTE(review): the two writes are not atomic. If IncreaseCount fails, the
// comment has already been saved; that is the problem a shared transaction
// is meant to solve.
func (h *Handler) CreateComment(w http.ResponseWriter, r *http.Request) {
    postId := getPostIdeFromRequest(r)
    comment := getCommentFromRequest(r)

    newComment := h.commentService.Create(comment, postId)

    if err := h.postService.IncreaseCount(postId); err != nil {
        // Return here. Otherwise the handler would fall through and also send
        // a 200 OK body for a request that failed.
        http.Error(w, "could not update comment count", http.StatusInternalServerError)
        return
    }

    writeJSON(w, http.StatusOK, newComment)
}

services.go

package services

// PostService holds the business logic for posts and delegates persistence to
// a PostRepository.
type PostService struct {
    postRepo *repository.PostRepository
}

// NewPostService builds a PostService backed by postRepo.
func NewPostService(postRepo *repository.PostRepository) *PostService {
    svc := new(PostService)
    svc.postRepo = postRepo
    return svc
}

// IncreaseCount asks the repository to bump the comments count of the post
// identified by postId. The repository's error is returned unchanged.
func (ps *PostService) IncreaseCount(postId int) error {
    repo := ps.postRepo
    err := repo.IncreaseCount(postId)
    return err
}

// CommentService holds the business logic for comments and delegates
// persistence to a CommentRepository.
type CommentService struct {
    commentRepo *repository.CommentRepository
}

// NewCommentService builds a CommentService backed by commentRepo.
func NewCommentService(commentRepo *repository.CommentRepository) *CommentService {
    svc := new(CommentService)
    svc.commentRepo = commentRepo
    return svc
}

// Create stores comment under the post identified by postId and returns
// whatever the repository's Save returns.
func (cs *CommentService) Create(comment models.Comment, postId int) *models.Comment {
    saved := cs.commentRepo.Save(comment, postId)
    return saved
}

repository.go

package repository

// PostRepository is the data-access layer for posts, backed by a pgx pool.
type PostRepository struct {
    pool *pgxpool.Pool
}

// NewPostRepository returns a PostRepository that works through pool.
func NewPostRepository(pool *pgxpool.Pool) *PostRepository {
    repo := new(PostRepository)
    repo.pool = pool
    return repo
}

// IncreaseCount is meant to increment, through pr.pool, the comments count of
// the post with the given ID.
//
// NOTE(review): placeholder. The body has no return statement, so this does
// not compile until the query is written and its error is returned.
func (pr *PostRepository) IncreaseCount(postId int) error {
    // call pr.pool and increase comments count for post with the given ID
}
// CommentRepository is the data-access layer for comments, backed by a pgx pool.
type CommentRepository struct {
    pool *pgxpool.Pool
}

// NewCommentRepository returns a CommentRepository that works through pool.
func NewCommentRepository(pool *pgxpool.Pool) *CommentRepository {
    repo := new(CommentRepository)
    repo.pool = pool
    return repo
}

// Save is meant to insert comment for the post postId through cr.pool and,
// presumably, return the stored comment. TODO confirm.
//
// NOTE(review): placeholder. The body has no return statement, so this does
// not compile yet. The signature also has no error result, so a failed insert
// cannot be reported to the caller.
func (cr *CommentRepository) Save(comment models.Comment, postId int) *models.Comment {
    // call cr.pool and insert comment into the DB
}

I initialize all the necessary dependencies in main.go and inject them into where I need them and then use handlers to handle every route.

Now I need a transaction, so that if for some reason I fail to update the comments count of a post, the creation of the comment is rolled back.

I guess the easiest way is to just pass Tx into the methods, but it seems ugly.

I was hoping for some way to abstract the database logic, so that the repositories do not care whether they are using a transaction or not.

I would also like to manage the transaction in the handler methods, so that I could have something like this:

// CreateComment sketches the desired handler. The comment insert and the count
// update share one transaction, which is rolled back if the count update fails
// and committed otherwise.
func (h *Handler) CreateComment(w http.ResponseWriter, r *http.Request) {
    postId := getPostIdeFromRequest(r)
    comment := getCommentFromRequest(r)

    // Begin transaction
    newComment := h.commentService.Create(comment, postId)

    if err := h.postService.IncreaseCount(postId); err != nil {
        // rollback the transaction

        http.Error(w, "could not update comment count", http.StatusInternalServerError)
        // Return so the failed request is neither committed nor answered with 200 OK.
        return
    }

    // commit the transaction
    writeJSON(w, http.StatusOK, newComment)
}

Solution

  • Your approach splitting services and repositories is a very good start. The following worked great for me:

    • Make use of the context API. Make all the methods in your services and repositories accept a context as their first parameter.
    • Create a new interface Session that would look like this:
    // Session abstracts a business transaction so that services can group
    // several operations without depending on the database layer.
    type Session interface {
        // Context returns the context bound to this session.
        Context() context.Context

        // Transaction runs f with a context that carries the transaction.
        Transaction(ctx context.Context, f func(context.Context) error) error

        // Begin starts a transaction and returns it as a child session.
        Begin(ctx context.Context) (Session, error)

        // Commit makes the session's changes permanent.
        Commit() error

        // Rollback discards the session's changes.
        Rollback() error
    }
    
    • Using such interface allows you to use any type of transaction system in your services (be it a DB or not, or any kind of DB).
    • Create an implementation of Session that wraps pgxpool.
    • Begin and Transaction should inject a DB instance into the context.Context.
    • Create a function that returns the transaction (pgx.Tx) stored in a context, falling back to the pool, like so:
    // dbKey is the context key under which the active pgx.Tx is stored.
    type dbKey struct{}
    
    // DB returns the transaction stored in ctx, if there is one. Otherwise it
    // begins a new transaction on fallback, and the caller then owns that
    // transaction and must Commit or Rollback it.
    //
    // The stored value is type-asserted without comma-ok on purpose. Silently
    // falling back would run the work outside the caller's transaction.
    func DB(ctx context.Context, fallback *pgxpool.Pool) (pgx.Tx, error) {
        if db := ctx.Value(dbKey{}); db != nil {
            return db.(pgx.Tx), nil
        }
        // pgxpool.Pool.Begin requires a context. Passing ctx also ties the new
        // transaction to the caller's cancellation.
        return fallback.Begin(ctx)
    }
    
    • In all your repositories, always use this function to retrieve a DB (passing repo.pool as the fallback, so that a new transaction is begun when the context carries none). This way, the repository doesn't know (and doesn't need to know) whether the operation is inside a transaction or not. The services can call multiple repositories and methods without worrying about the underlying mechanism. It also makes it easier to test your services without depending on your repositories or the DB. This should handle nested transactions fine too.
    • Add a Session parameter to the constructor of your services that need them.
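    • From your services, only use the Session when you need to execute multiple repository operations (that is, when you have a business transaction). For single operations, you can simply call the repository directly and it will use the fallback. Note that with the pgx DB helper above, that fallback begins a fresh transaction, which the repository must then commit or roll back itself. With Gorm, the fallback is the plain DB with no transaction.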

    I used this approach with Gorm and not with pgx directly, so it may not translate as cleanly: with pgx you need to work with pgx.Tx rather than a single type like *gorm.DB. I think it should work fine though, and I hope this helps you going forward. Good luck!


    Complete example

    Here is a more complete example based on the implementation I am using in my projects (using Gorm). In this example we have a user and a user action history. We want to create a "register" action record alongside the user record when the user registers.

    I know you are not using Gorm but the logic stays the same, you will only have to implement your Session and repositories differently.

    Session implementation

    package session
    
    import (
        "context"
        "database/sql"
    
        "gorm.io/gorm"
    )
    
    // Session lets services group several business operations into one unit
    // that is committed or rolled back as a whole, without importing the
    // database layer. The unit may be a database transaction or any other
    // transaction mechanism.
    //
    // Implementations provide a root session from a "New"-style constructor.
    // Child sessions come from Begin and Transaction, and nesting is expected
    // to be supported.
    type Session interface {
        // Context reports the context owned by this session. A root session
        // returns context.Background(). A child created by Begin returns a
        // context holding the transaction mechanism as a value.
        Context() context.Context

        // Begin starts a transaction and returns it as a new child session
        // bound to ctx, with the transaction mechanism stored as a value in the
        // child's context. Using the child must not affect the parent.
        Begin(ctx context.Context) (Session, error)

        // Transaction calls f inside a transaction whose mechanism is stored in
        // the context passed to f. An error from f rolls the transaction back.
        // Otherwise it is committed before Transaction returns.
        Transaction(ctx context.Context, f func(context.Context) error) error

        // Commit makes the transaction's changes permanent. This is final.
        Commit() error

        // Rollback discards the transaction's changes. This is final.
        Rollback() error
    }
    
    // Gorm is the Session implementation backed by a *gorm.DB. A root value is
    // built by GORM, and Begin derives child values whose db is an open
    // transaction.
    type Gorm struct {
        db        *gorm.DB        // root DB, or the open transaction of a child session
        TxOptions *sql.TxOptions  // options passed when starting transactions; may be nil
        ctx       context.Context // context.Background() for a root session
    }
    
    // GORM returns a root Gorm session over db. opt is optional and may be nil.
    func GORM(db *gorm.DB, opt *sql.TxOptions) Gorm {
        var s Gorm
        s.db = db
        s.TxOptions = opt
        s.ctx = context.Background()
        return s
    }
    
    // Begin returns a child session bound to ctx whose db is a newly started
    // transaction. The transaction is also stored in the child's context, so
    // DB(child.Context(), ...) returns it. The child is manual: call Commit or
    // Rollback on it before discarding it, or the transaction is left open.
    //
    // If ctx already carries a Gorm DB, that DB is used instead of this
    // session's DB.
    // NOTE(review): when that DB is itself a transaction, gorm's (*DB).Begin
    // does not nest. It reports ErrInvalidTransaction, so a Begin inside an
    // existing transaction fails here. For nested units of work, gorm's own
    // (*gorm.DB).Transaction uses savepoints. Verify against the gorm version
    // in use.
    func (s Gorm) Begin(ctx context.Context) (Session, error) {
        tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions)
        if tx.Error != nil {
            return nil, tx.Error
        }
        return Gorm{
            ctx:       context.WithValue(ctx, dbKey{}, tx),
            TxOptions: s.TxOptions,
            db:        tx,
        }, nil
    }
    
    // Rollback discards every change made in this session's transaction.
    // This action is final.
    func (s Gorm) Rollback() error {
        rolledBack := s.db.Rollback()
        return rolledBack.Error
    }
    
    // Commit persists every change made in this session's transaction.
    // This action is final.
    func (s Gorm) Commit() error {
        committed := s.db.Commit()
        return committed.Error
    }
    
    // Context reports the context this session was created with. For a root
    // session from GORM it is context.Background(). For a child from Begin it
    // also carries the transaction, so session.DB can retrieve it.
    func (s Gorm) Context() context.Context {
        ctx := s.ctx
        return ctx
    }
    
    // dbKey is the context key under which the active *gorm.DB is stored.
    type dbKey struct{}
    
    // Transaction runs f in a database transaction. The transaction is stored
    // in the context passed to f, so repositories calling session.DB(ctx, ...)
    // inside f use it. If f returns an error, the transaction is rolled back
    // and that error is returned. Otherwise it is committed before Transaction
    // returns, and any begin or commit error is returned.
    //
    // The work is delegated to gorm's (*gorm.DB).Transaction, which also:
    //   - rolls back when f panics (the panic still propagates), and
    //   - uses a savepoint when ctx already carries a transaction, so nested
    //     Transaction calls work instead of failing to Begin on a transaction.
    func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) error {
        base := DB(ctx, s.db).WithContext(ctx)
        return base.Transaction(func(tx *gorm.DB) error {
            return f(context.WithValue(ctx, dbKey{}, tx))
        }, s.TxOptions)
    }
    
    // DB returns the *gorm.DB stored in ctx by Begin or Transaction, or fallback
    // when ctx carries none. A value of any other type under the key panics
    // instead of silently running the caller outside its transaction.
    func DB(ctx context.Context, fallback *gorm.DB) *gorm.DB {
        if stored := ctx.Value(dbKey{}); stored != nil {
            return stored.(*gorm.DB)
        }
        return fallback
    }
    

    Service implementation

    package user
    
    import (
        "context"
        "example-module/database/model"
        "example-module/dto"
    
        "example-module/session"
        "example-module/typeutil"
    )
    
    // Repository is the persistence dependency of the user Service. Both
    // methods take the caller's ctx so implementations can look up an active
    // transaction (e.g. via session.DB).
    type Repository interface {
        // CreateHistory stores an action-history record and returns it.
        CreateHistory(ctx context.Context, history *model.History) (*model.History, error)

        // Create stores a new user and returns it.
        Create(ctx context.Context, user *model.User) (*model.User, error)
    }
    
    // Service implements the user use-cases. It uses session to group
    // repository calls into a single business transaction.
    type Service struct {
        session    session.Session
        repository Repository
    }
    
    // NewService returns a Service that runs its transactions through session
    // and persists data through repository.
    func NewService(session session.Session, repository Repository) *Service {
        svc := new(Service)
        svc.session = session
        svc.repository = repository
        return svc
    }
    
    // Register creates a user from the registration DTO together with a
    // "register" history record, in one session transaction: if either insert
    // fails, the transaction is rolled back. On success it returns the stored
    // user as a DTO. On failure it returns a nil user and the error.
    func (s *Service) Register(ctx context.Context, user *dto.RegisterUser) (*dto.User, error) {
        // Model mapping from DTO to model (using the `copier` library).
        u := typeutil.Copy(&model.User{}, user)

        // The closure's ctx deliberately shadows the outer one. It carries the
        // transaction, so repository calls made with it join that transaction.
        err := s.session.Transaction(ctx, func(ctx context.Context) error {
            // You can also call another service from here, not necessarily a repository.
            created, err := s.repository.Create(ctx, u)
            if err != nil {
                return err
            }
            u = created

            history := &model.History{
                UserID: u.ID,
                Action: "register",
            }
            _, err = s.repository.CreateHistory(ctx, history)
            return err
        })
        if err != nil {
            // Don't return a half-built (or nil-derived) user next to the error:
            // the transaction was rolled back, so there is nothing to describe.
            return nil, err
        }

        // Convert back to a DTO user using json marshal/unmarshal.
        return typeutil.MustConvert[*dto.User](u), nil
    }
    

    Repository implementation

    package repository
    
    import (
        "context"
        "example-module/database/model"
    
        "gorm.io/gorm"
        "gorm.io/gorm/clause"
        "example-module/session"
    )
    
    // User is the gorm-backed repository for users and their history records.
    type User struct {
        DB *gorm.DB
    }
    
    // NewUser returns a User repository that uses db whenever the context
    // carries no transaction.
    func NewUser(db *gorm.DB) *User {
        r := new(User)
        r.DB = db
        return r
    }
    
    // Create inserts user, skipping its associations. It uses the transaction
    // carried by ctx, or r.DB when there is none, and returns the same user
    // pointer together with the insert error.
    func (r *User) Create(ctx context.Context, user *model.User) (*model.User, error) {
        // user is already a pointer. Passing &user would hand gorm a
        // **model.User, which only works because gorm dereferences repeatedly.
        res := session.DB(ctx, r.DB).Omit(clause.Associations).Create(user)
        return user, res.Error
    }
    
    // CreateHistory inserts history, skipping its associations. It uses the
    // transaction carried by ctx, or r.DB when there is none, and returns the
    // same history pointer together with the insert error.
    func (r *User) CreateHistory(ctx context.Context, history *model.History) (*model.History, error) {
        // history is already a pointer; pass it directly rather than &history.
        res := session.DB(ctx, r.DB).Omit(clause.Associations).Create(history)
        return history, res.Error
    }
    

    Main (init)

    session := session.GORM(myDB, nil)
    userRepository := repository.NewUser(myDB)
    userService := user.NewService(session, userRepository)