|
@@ -6,12 +6,14 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"strconv"
|
|
|
+
|
|
|
+ "github.com/jmoiron/sqlx"
|
|
|
)
|
|
|
|
|
|
type TxKey struct{}
|
|
|
|
|
|
type DB struct {
|
|
|
- db *sql.DB
|
|
|
+ db *sqlx.DB
|
|
|
}
|
|
|
|
|
|
func NewDB(c Config) (*DB, error) {
|
|
@@ -50,11 +52,13 @@ func NewDB(c Config) (*DB, error) {
|
|
|
" port=" + strconv.Itoa(int(c.port)) +
|
|
|
" sslmode=" + c.sslMode
|
|
|
|
|
|
- db, err := sql.Open(c.driver, source)
|
|
|
+ sqlConn, err := sql.Open(c.driver, source)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("open DB connection error: %w", err)
|
|
|
}
|
|
|
|
|
|
+ db := sqlx.NewDb(sqlConn, c.driver)
|
|
|
+
|
|
|
if c.maxOpenConns > 0 {
|
|
|
db.SetMaxOpenConns(int(c.maxOpenConns))
|
|
|
}
|
|
@@ -78,41 +82,64 @@ func NewDB(c Config) (*DB, error) {
|
|
|
return &DB{db: db}, nil
|
|
|
}
|
|
|
|
|
|
-func (s *DB) Ping() error {
|
|
|
- return s.db.Ping()
|
|
|
-}
|
|
|
-
|
|
|
func (s *DB) Close() error {
|
|
|
return s.db.Close()
|
|
|
}
|
|
|
|
|
|
-func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
|
|
- return s.db.BeginTx(ctx, opts)
|
|
|
+func (s *DB) Ping() error {
|
|
|
+ return s.db.Ping()
|
|
|
}
|
|
|
|
|
|
-func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
|
|
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
|
+func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
|
|
|
+ return s.db.BeginTxx(ctx, opts)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *DB) SelectContext(
|
|
|
+ ctx context.Context,
|
|
|
+ dest interface{},
|
|
|
+ query string,
|
|
|
+ args ...interface{},
|
|
|
+) error {
|
|
|
+ tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
|
|
|
if ok {
|
|
|
- return tx.QueryContext(ctx, query, args...)
|
|
|
+ return tx.SelectContext(ctx, dest, query, args...)
|
|
|
}
|
|
|
|
|
|
- return s.db.QueryContext(ctx, query, args...)
|
|
|
+ return s.db.SelectContext(ctx, dest, query, args...)
|
|
|
}
|
|
|
|
|
|
-func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
|
|
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
|
+func (s *DB) GetContext(
|
|
|
+ ctx context.Context,
|
|
|
+ dest interface{},
|
|
|
+ query string,
|
|
|
+ args ...interface{},
|
|
|
+) error {
|
|
|
+ tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
|
|
|
if ok {
|
|
|
- return tx.QueryRowContext(ctx, query, args...)
|
|
|
+ return tx.GetContext(ctx, dest, query, args...)
|
|
|
}
|
|
|
|
|
|
- return s.db.QueryRowContext(ctx, query, args...)
|
|
|
+ return s.db.GetContext(ctx, dest, query, args...)
|
|
|
}
|
|
|
|
|
|
-func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
|
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
|
+func (s *DB) ExecContext(
|
|
|
+ ctx context.Context,
|
|
|
+ query string,
|
|
|
+ args ...interface{},
|
|
|
+) (sql.Result, error) {
|
|
|
+ tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
|
|
|
if ok {
|
|
|
return tx.ExecContext(ctx, query, args...)
|
|
|
}
|
|
|
|
|
|
return s.db.ExecContext(ctx, query, args...)
|
|
|
}
|
|
|
+
|
|
|
+func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
|
|
+ tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
|
|
|
+ if ok {
|
|
|
+ return tx.QueryRowContext(ctx, query, args...)
|
|
|
+ }
|
|
|
+
|
|
|
+ return s.db.QueryRowContext(ctx, query, args...)
|
|
|
+}
|