tx.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "github.com/jmoiron/sqlx"
  7. )
  8. type TxDB interface {
  9. BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error)
  10. }
  11. type TxManager struct {
  12. db TxDB
  13. }
  14. func NewTransactionManager(db TxDB) *TxManager {
  15. return &TxManager{
  16. db: db,
  17. }
  18. }
  19. func (s *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn func(ctx context.Context) error) (err error) {
  20. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  21. if ok {
  22. return fn(ctx)
  23. }
  24. tx, err = s.db.BeginTx(ctx, &opts)
  25. if err != nil {
  26. return fmt.Errorf("begin transaction: %w", err)
  27. }
  28. ctx = context.WithValue(ctx, TxKey{}, tx)
  29. defer func() {
  30. if r := recover(); r != nil {
  31. err = fmt.Errorf("panic recovered: %v", r)
  32. }
  33. if err != nil {
  34. if errRollback := tx.Rollback(); errRollback != nil {
  35. err = fmt.Errorf("transaction rollback: %w", errRollback)
  36. }
  37. return
  38. }
  39. if nil == err {
  40. err = tx.Commit()
  41. if err != nil {
  42. err = fmt.Errorf("transaction commit: %w", err)
  43. }
  44. }
  45. }()
  46. if err = fn(ctx); err != nil {
  47. err = fmt.Errorf("failed executing code inside transaction: %w", err)
  48. }
  49. return err
  50. }
  51. func (s *TxManager) ReadCommitted(ctx context.Context, f func(ctx context.Context) error) error {
  52. txOpts := sql.TxOptions{Isolation: sql.LevelReadCommitted}
  53. return s.transaction(ctx, txOpts, f)
  54. }
  55. func (s *TxManager) RepeatableRead(ctx context.Context, f func(ctx context.Context) error) error {
  56. txOpts := sql.TxOptions{Isolation: sql.LevelRepeatableRead}
  57. return s.transaction(ctx, txOpts, f)
  58. }
  59. func (s *TxManager) Serializable(ctx context.Context, numAttempts int, f func(ctx context.Context) error) error {
  60. txOpts := sql.TxOptions{Isolation: sql.LevelSerializable}
  61. for i := 0; i < numAttempts; i++ {
  62. err := s.transaction(ctx, txOpts, f)
  63. if err != nil {
  64. continue
  65. }
  66. return nil
  67. }
  68. return fmt.Errorf("serialization error after %d attempts", numAttempts)
  69. }