Dmitriy Gnatenko преди 4 седмици
родител
ревизия
b319cf7ffa
променени са 4 файла, в които са добавени 74 реда и са изтрити 36 реда
  1. 45 18
      db/db.go
  2. 17 18
      db/tx.go
  3. 2 0
      go.mod
  4. 10 0
      go.sum

+ 45 - 18
db/db.go

@@ -6,12 +6,14 @@ import (
 	"errors"
 	"fmt"
 	"strconv"
+
+	"github.com/jmoiron/sqlx"
 )
 
 type TxKey struct{}
 
 type DB struct {
-	db *sql.DB
+	db *sqlx.DB
 }
 
 func NewDB(c Config) (*DB, error) {
@@ -50,11 +52,13 @@ func NewDB(c Config) (*DB, error) {
 		" port=" + strconv.Itoa(int(c.port)) +
 		" sslmode=" + c.sslMode
 
-	db, err := sql.Open(c.driver, source)
+	sqlConn, err := sql.Open(c.driver, source)
 	if err != nil {
 		return nil, fmt.Errorf("open DB connection error: %w", err)
 	}
 
+	db := sqlx.NewDb(sqlConn, c.driver)
+
 	if c.maxOpenConns > 0 {
 		db.SetMaxOpenConns(int(c.maxOpenConns))
 	}
@@ -78,41 +82,64 @@ func NewDB(c Config) (*DB, error) {
 	return &DB{db: db}, nil
 }
 
-func (s *DB) Ping() error {
-	return s.db.Ping()
-}
-
 func (s *DB) Close() error {
 	return s.db.Close()
 }
 
-func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
-	return s.db.BeginTx(ctx, opts)
+func (s *DB) Ping() error {
+	return s.db.Ping()
 }
 
-func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
-	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
+	return s.db.BeginTxx(ctx, opts)
+}
+
+func (s *DB) SelectContext(
+	ctx context.Context,
+	dest interface{},
+	query string,
+	args ...interface{},
+) error {
+	tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
 	if ok {
-		return tx.QueryContext(ctx, query, args...)
+		return tx.SelectContext(ctx, dest, query, args...)
 	}
 
-	return s.db.QueryContext(ctx, query, args...)
+	return s.db.SelectContext(ctx, dest, query, args...)
 }
 
-func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
-	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+func (s *DB) GetContext(
+	ctx context.Context,
+	dest interface{},
+	query string,
+	args ...interface{},
+) error {
+	tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
 	if ok {
-		return tx.QueryRowContext(ctx, query, args...)
+		return tx.GetContext(ctx, dest, query, args...)
 	}
 
-	return s.db.QueryRowContext(ctx, query, args...)
+	return s.db.GetContext(ctx, dest, query, args...)
 }
 
-func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
-	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+func (s *DB) ExecContext(
+	ctx context.Context,
+	query string,
+	args ...interface{},
+) (sql.Result, error) {
+	tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
 	if ok {
 		return tx.ExecContext(ctx, query, args...)
 	}
 
 	return s.db.ExecContext(ctx, query, args...)
 }
+
+func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+	tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
+	if ok {
+		return tx.QueryRowContext(ctx, query, args...)
+	}
+
+	return s.db.QueryRowContext(ctx, query, args...)
+}

+ 17 - 18
db/tx.go

@@ -4,14 +4,12 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
+
+	"github.com/jmoiron/sqlx"
 )
 
 type TxDB interface {
-	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
-}
-
-type Logger interface {
-	ErrorContext(ctx context.Context, msg string, args ...any)
+	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error)
 }
 
 type Handler func(ctx context.Context) error
@@ -26,15 +24,15 @@ func NewTransactionManager(db TxDB) *TxManager {
 	}
 }
 
-func (tm *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Handler) (err error) {
-	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
+func (s *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Handler) (err error) {
+	tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
 	if ok {
 		return fn(ctx)
 	}
 
-	tx, err = tm.db.BeginTx(ctx, &opts)
+	tx, err = s.db.BeginTx(ctx, &opts)
 	if err != nil {
-		return fmt.Errorf("begin transaction  error: %w", err)
+		return fmt.Errorf("begin transaction: %w", err)
 	}
 
 	ctx = context.WithValue(ctx, TxKey{}, tx)
@@ -46,16 +44,16 @@ func (tm *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Han
 
 		if err != nil {
 			if errRollback := tx.Rollback(); errRollback != nil {
-				err = fmt.Errorf("transaction rollback error: %w", errRollback)
+				err = fmt.Errorf("transaction rollback: %w", errRollback)
 			}
 
 			return
 		}
 
-		if err == nil {
+		if nil == err {
 			err = tx.Commit()
 			if err != nil {
-				err = fmt.Errorf("transaction commit error: %w", err)
+				err = fmt.Errorf("transaction commit: %w", err)
 			}
 		}
 	}()
@@ -67,21 +65,22 @@ func (tm *TxManager) transaction(ctx context.Context, opts sql.TxOptions, fn Han
 	return err
 }
 
-func (tm *TxManager) ReadCommitted(ctx context.Context, f Handler) error {
+func (s *TxManager) ReadCommitted(ctx context.Context, f Handler) error {
 	txOpts := sql.TxOptions{Isolation: sql.LevelReadCommitted}
-	return tm.transaction(ctx, txOpts, f)
+	return s.transaction(ctx, txOpts, f)
 }
 
-func (tm *TxManager) RepeatableRead(ctx context.Context, f Handler) error {
+func (s *TxManager) RepeatableRead(ctx context.Context, f Handler) error {
 	txOpts := sql.TxOptions{Isolation: sql.LevelRepeatableRead}
-	return tm.transaction(ctx, txOpts, f)
+	return s.transaction(ctx, txOpts, f)
 }
 
-func (tm *TxManager) Serializable(ctx context.Context, numAttempts int, f Handler) error {
+func (s *TxManager) Serializable(ctx context.Context, numAttempts int, f Handler) error {
 	txOpts := sql.TxOptions{Isolation: sql.LevelSerializable}
 
 	for i := 0; i < numAttempts; i++ {
-		if err := tm.transaction(ctx, txOpts, f); err != nil {
+		err := s.transaction(ctx, txOpts, f)
+		if err != nil {
 			continue
 		}
 

+ 2 - 0
go.mod

@@ -1,3 +1,5 @@
 module git.dmitriygnatenko.ru/dima/go-common
 
 go 1.21.4
+
+require github.com/jmoiron/sqlx v1.4.0

+ 10 - 0
go.sum

@@ -0,0 +1,10 @@
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
+github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
+github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
+github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
+github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=