Dima 7 months ago
parent
commit
7612794c8f
4 changed files with 209 additions and 0 deletions
  1. 1 0
      .gitignore
  2. 105 0
      db/db.go
  3. 10 0
      db/error.go
  4. 93 0
      db/tx.go

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+/.idea/

+ 105 - 0
db/db.go

@@ -0,0 +1,105 @@
+package db
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"time"
+)
+
+type TxKey struct{}
+
+type Config interface {
+	Driver() string
+
+	User() string
+	Password() string
+	Name() string
+	Host() string
+	Port() string
+	SSLMode() string
+
+	MaxOpenConns() int
+	MaxIdleConns() int
+	MaxConnLifetime() int     // in seconds
+	MaxIdleConnLifetime() int // in seconds
+}
+
+type DB struct {
+	db *sql.DB
+}
+
+func NewDB(c Config) (*DB, error) {
+	source := "user=" + c.User() +
+		" password=" + c.Password() +
+		" dbname=" + c.Name() +
+		" host=" + c.Host() +
+		" port=" + c.Port() +
+		" sslmode=" + c.SSLMode()
+
+	db, err := sql.Open(c.Driver(), source)
+	if err != nil {
+		return nil, fmt.Errorf("open DB connection error: %w", err)
+	}
+
+	if c.MaxOpenConns() > 0 {
+		db.SetMaxOpenConns(c.MaxOpenConns())
+	}
+
+	if c.MaxIdleConns() > 0 {
+		db.SetMaxIdleConns(c.MaxIdleConns())
+	}
+
+	if c.MaxConnLifetime() > 0 {
+		db.SetConnMaxLifetime(time.Second * time.Duration(c.MaxConnLifetime()))
+	}
+
+	if c.MaxIdleConnLifetime() > 0 {
+		db.SetConnMaxIdleTime(time.Second * time.Duration(c.MaxIdleConnLifetime()))
+	}
+
+	if err = db.Ping(); err != nil {
+		return nil, fmt.Errorf("DB ping error: %w", err)
+	}
+
+	return &DB{db: db}, nil
+}
+
+func (db *DB) Ping() error {
+	return db.db.Ping()
+}
+
+func (db *DB) Close() error {
+	return db.db.Close()
+}
+
+func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+	return db.db.BeginTx(ctx, opts)
+}
+
+func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+	if ok {
+		return tx.QueryContext(ctx, query, args...)
+	}
+
+	return db.db.QueryContext(ctx, query, args...)
+}
+
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+	if ok {
+		return tx.QueryRowContext(ctx, query, args...)
+	}
+
+	return db.db.QueryRowContext(ctx, query, args...)
+}
+
+func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+	if ok {
+		return tx.ExecContext(ctx, query, args...)
+	}
+
+	return db.db.ExecContext(ctx, query, args...)
+}

+ 10 - 0
db/error.go

@@ -0,0 +1,10 @@
+package db
+
+import (
+	"database/sql"
+	"errors"
+)
+
+func IsNotFoundError(err error) bool {
+	return errors.Is(err, sql.ErrNoRows)
+}

+ 93 - 0
db/tx.go

@@ -0,0 +1,93 @@
+package db
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+)
+
+type TxDB interface {
+	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+}
+
+type Logger interface {
+	ErrorContext(ctx context.Context, msg string, args ...any)
+}
+
+type Handler func(ctx context.Context) error
+
+type TxManager struct {
+	db TxDB
+}
+
+func NewTransactionManager(db TxDB) *TxManager {
+	return &TxManager{
+		db: db,
+	}
+}
+
+func (tm *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Handler) error {
+	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+	if ok {
+		return fn(ctx)
+	}
+
+	tx, err := tm.db.BeginTx(ctx, &opts)
+	if err != nil {
+		return fmt.Errorf("begin transaction  error: %w", err)
+	}
+
+	ctx = context.WithValue(ctx, TxKey{}, tx)
+
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panic recovered: %v", r)
+		}
+
+		if err != nil {
+			if errRollback := tx.Rollback(); errRollback != nil {
+				err = fmt.Errorf("transaction rollback error: %w", errRollback)
+			}
+
+			return
+		}
+
+		if err == nil {
+			err = tx.Commit()
+			if err != nil {
+				err = fmt.Errorf("transaction commit error: %w", err)
+			}
+		}
+	}()
+
+	if err = fn(ctx); err != nil {
+		err = fmt.Errorf("failed executing code inside transaction: %w", err)
+	}
+
+	return err
+}
+
+func (tm *TxManager) ReadCommitted(ctx context.Context, f Handler) error {
+	txOpts := sql.TxOptions{Isolation: sql.LevelReadCommitted}
+	return tm.transaction(ctx, txOpts, f)
+}
+
+func (tm *TxManager) RepeatableRead(ctx context.Context, f Handler) error {
+	txOpts := sql.TxOptions{Isolation: sql.LevelRepeatableRead}
+	return tm.transaction(ctx, txOpts, f)
+}
+
+func (tm *TxManager) Serializable(ctx context.Context, numAttempts int, f Handler) error {
+	txOpts := sql.TxOptions{Isolation: sql.LevelSerializable}
+
+	for i := 0; i < numAttempts; i++ {
+		err := tm.transaction(ctx, txOpts, f)
+		if err != nil {
+			continue
+		}
+
+		return nil
+	}
+
+	return fmt.Errorf("serialization error after %d attempts", numAttempts)
+}