1
0

tx.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "github.com/jmoiron/sqlx"
  7. )
  8. type TxDB interface {
  9. BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error)
  10. }
  // Handler is a unit of work run by TxManager. When executed inside a
// transaction started by TxManager, the context it receives carries that
// transaction; a non-nil return value causes the transaction to be rolled back.
type Handler func(context.Context) error
  12. type TxManager struct {
  13. db TxDB
  14. }
  15. func NewTransactionManager(db TxDB) *TxManager {
  16. return &TxManager{
  17. db: db,
  18. }
  19. }
  20. func (s *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Handler) (err error) {
  21. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  22. if ok {
  23. return fn(ctx)
  24. }
  25. tx, err = s.db.BeginTx(ctx, &opts)
  26. if err != nil {
  27. return fmt.Errorf("begin transaction: %w", err)
  28. }
  29. ctx = context.WithValue(ctx, TxKey{}, tx)
  30. defer func() {
  31. if r := recover(); r != nil {
  32. err = fmt.Errorf("panic recovered: %v", r)
  33. }
  34. if err != nil {
  35. if errRollback := tx.Rollback(); errRollback != nil {
  36. err = fmt.Errorf("transaction rollback: %w", errRollback)
  37. }
  38. return
  39. }
  40. if nil == err {
  41. err = tx.Commit()
  42. if err != nil {
  43. err = fmt.Errorf("transaction commit: %w", err)
  44. }
  45. }
  46. }()
  47. if err = fn(ctx); err != nil {
  48. err = fmt.Errorf("failed executing code inside transaction: %w", err)
  49. }
  50. return err
  51. }
  52. func (s *TxManager) ReadCommitted(ctx context.Context, f Handler) error {
  53. txOpts := sql.TxOptions{Isolation: sql.LevelReadCommitted}
  54. return s.transaction(ctx, txOpts, f)
  55. }
  56. func (s *TxManager) RepeatableRead(ctx context.Context, f Handler) error {
  57. txOpts := sql.TxOptions{Isolation: sql.LevelRepeatableRead}
  58. return s.transaction(ctx, txOpts, f)
  59. }
  60. func (s *TxManager) Serializable(ctx context.Context, numAttempts int, f Handler) error {
  61. txOpts := sql.TxOptions{Isolation: sql.LevelSerializable}
  62. for i := 0; i < numAttempts; i++ {
  63. err := s.transaction(ctx, txOpts, f)
  64. if err != nil {
  65. continue
  66. }
  67. return nil
  68. }
  69. return fmt.Errorf("serialization error after %d attempts", numAttempts)
  70. }