1
0

tx.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package db
  import (
  	"context"
  	"database/sql"
  	"errors"
  	"fmt"
  )
  7. type TxDB interface {
  8. BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
  9. }
  10. type Logger interface {
  11. ErrorContext(ctx context.Context, msg string, args ...any)
  12. }
  13. type Handler func(ctx context.Context) error
  14. type TxManager struct {
  15. db TxDB
  16. }
  17. func NewTransactionManager(db TxDB) *TxManager {
  18. return &TxManager{
  19. db: db,
  20. }
  21. }
  // transaction runs fn inside a database transaction opened with opts.
  //
  // If ctx already carries a *sql.Tx under TxKey{}, fn is called directly with
  // ctx: it joins that outer transaction, opts are ignored, and nothing is
  // committed or rolled back here. Otherwise a new transaction is begun on
  // tm.db and stored in the context handed to fn. It is committed when fn
  // returns nil and rolled back when fn returns an error or panics. A panic in
  // fn is recovered and reported as an error rather than propagated.
  //
  // Begin, commit and rollback failures are all returned to the caller. A
  // rollback failure is joined with the error that triggered the rollback, so
  // errors.Is matches either of them.
  func (tm *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Handler) (err error) {
  	if _, ok := ctx.Value(TxKey{}).(*sql.Tx); ok {
  		return fn(ctx)
  	}

  	tx, beginErr := tm.db.BeginTx(ctx, &opts)
  	if beginErr != nil {
  		return fmt.Errorf("begin transaction error: %w", beginErr)
  	}
  	ctx = context.WithValue(ctx, TxKey{}, tx)

  	// err is a named result so that the assignments made by this deferred
  	// function (panic, commit and rollback failures) reach the caller. With a
  	// plain local they were silently discarded.
  	defer func() {
  		if r := recover(); r != nil {
  			err = fmt.Errorf("panic recovered: %v", r)
  		}
  		if err != nil {
  			// sql.ErrTxDone means the transaction is already finished (e.g.
  			// database/sql rolled it back because ctx was canceled), which is
  			// not a rollback failure worth reporting.
  			if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
  				err = errors.Join(err, fmt.Errorf("transaction rollback error: %w", rbErr))
  			}
  			return
  		}
  		if commitErr := tx.Commit(); commitErr != nil {
  			err = fmt.Errorf("transaction commit error: %w", commitErr)
  		}
  	}()

  	if fnErr := fn(ctx); fnErr != nil {
  		return fmt.Errorf("failed executing code inside transaction: %w", fnErr)
  	}
  	return nil
  }
  // ReadCommitted runs f in a transaction opened with the READ COMMITTED
  // isolation level (sql.LevelReadCommitted) and returns the outcome.
  // If ctx already carries a transaction, f runs inside that one and the
  // level requested here is not applied.
  func (tm *TxManager) ReadCommitted(ctx context.Context, f Handler) error {
  	return tm.transaction(ctx, sql.TxOptions{Isolation: sql.LevelReadCommitted}, f)
  }
  58. func (tm *TxManager) RepeatableRead(ctx context.Context, f Handler) error {
  59. txOpts := sql.TxOptions{Isolation: sql.LevelRepeatableRead}
  60. return tm.transaction(ctx, txOpts, f)
  61. }
  62. func (tm *TxManager) Serializable(ctx context.Context, numAttempts int, f Handler) error {
  63. txOpts := sql.TxOptions{Isolation: sql.LevelSerializable}
  64. for i := 0; i < numAttempts; i++ {
  65. err := tm.transaction(ctx, txOpts, f)
  66. if err != nil {
  67. continue
  68. }
  69. return nil
  70. }
  71. return fmt.Errorf("serialization error after %d attempts", numAttempts)
  72. }