|
@@ -4,58 +4,58 @@ import (
|
|
"context"
|
|
"context"
|
|
"database/sql"
|
|
"database/sql"
|
|
"fmt"
|
|
"fmt"
|
|
- "time"
|
|
|
|
|
|
+ "strconv"
|
|
)
|
|
)
|
|
|
|
|
|
type TxKey struct{}
|
|
type TxKey struct{}
|
|
|
|
|
|
-type Config interface {
|
|
|
|
- Driver() string
|
|
|
|
-
|
|
|
|
- User() string
|
|
|
|
- Password() string
|
|
|
|
- Name() string
|
|
|
|
- Host() string
|
|
|
|
- Port() string
|
|
|
|
- SSLMode() string
|
|
|
|
-
|
|
|
|
- MaxOpenConns() int
|
|
|
|
- MaxIdleConns() int
|
|
|
|
- MaxConnLifetime() int // in seconds
|
|
|
|
- MaxIdleConnLifetime() int // in seconds
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
type DB struct {
|
|
type DB struct {
|
|
db *sql.DB
|
|
db *sql.DB
|
|
}
|
|
}
|
|
|
|
|
|
func NewDB(c Config) (*DB, error) {
|
|
func NewDB(c Config) (*DB, error) {
|
|
- source := "user=" + c.User() +
|
|
|
|
- " password=" + c.Password() +
|
|
|
|
- " dbname=" + c.Name() +
|
|
|
|
- " host=" + c.Host() +
|
|
|
|
- " port=" + c.Port() +
|
|
|
|
- " sslmode=" + c.SSLMode()
|
|
|
|
-
|
|
|
|
- db, err := sql.Open(c.Driver(), source)
|
|
|
|
|
|
+ if len(c.driver) == 0 {
|
|
|
|
+ c.driver = defaultDriver
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(c.host) == 0 {
|
|
|
|
+ c.host = defaultHost
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if c.port == 0 {
|
|
|
|
+ c.port = defaultPort
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(c.sslMode) == 0 {
|
|
|
|
+ c.sslMode = defaultSslMode
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ source := "user=" + c.username +
|
|
|
|
+ " password=" + c.password +
|
|
|
|
+ " dbname=" + c.dbname +
|
|
|
|
+ " host=" + c.host +
|
|
|
|
+ " port=" + strconv.Itoa(int(c.port)) +
|
|
|
|
+ " sslmode=" + c.sslMode
|
|
|
|
+
|
|
|
|
+ db, err := sql.Open(c.driver, source)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open DB connection error: %w", err)
|
|
return nil, fmt.Errorf("open DB connection error: %w", err)
|
|
}
|
|
}
|
|
|
|
|
|
- if c.MaxOpenConns() > 0 {
|
|
|
|
- db.SetMaxOpenConns(c.MaxOpenConns())
|
|
|
|
|
|
+ if c.maxOpenConns > 0 {
|
|
|
|
+ db.SetMaxOpenConns(int(c.maxOpenConns))
|
|
}
|
|
}
|
|
|
|
|
|
- if c.MaxIdleConns() > 0 {
|
|
|
|
- db.SetMaxIdleConns(c.MaxIdleConns())
|
|
|
|
|
|
+ if c.maxOpenConns > 0 {
|
|
|
|
+ db.SetMaxIdleConns(int(c.maxIdleConns))
|
|
}
|
|
}
|
|
|
|
|
|
- if c.MaxConnLifetime() > 0 {
|
|
|
|
- db.SetConnMaxLifetime(time.Second * time.Duration(c.MaxConnLifetime()))
|
|
|
|
|
|
+ if c.maxConnLifetime != nil {
|
|
|
|
+ db.SetConnMaxLifetime(*c.maxConnLifetime)
|
|
}
|
|
}
|
|
|
|
|
|
- if c.MaxIdleConnLifetime() > 0 {
|
|
|
|
- db.SetConnMaxIdleTime(time.Second * time.Duration(c.MaxIdleConnLifetime()))
|
|
|
|
|
|
+ if c.maxIdleConnLifetime != nil {
|
|
|
|
+ db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
|
|
}
|
|
}
|
|
|
|
|
|
if err = db.Ping(); err != nil {
|
|
if err = db.Ping(); err != nil {
|
|
@@ -65,41 +65,41 @@ func NewDB(c Config) (*DB, error) {
|
|
return &DB{db: db}, nil
|
|
return &DB{db: db}, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) Ping() error {
|
|
|
|
- return db.db.Ping()
|
|
|
|
|
|
+func (s *DB) Ping() error {
|
|
|
|
+ return s.db.Ping()
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) Close() error {
|
|
|
|
- return db.db.Close()
|
|
|
|
|
|
+func (s *DB) Close() error {
|
|
|
|
+ return s.db.Close()
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
|
|
|
- return db.db.BeginTx(ctx, opts)
|
|
|
|
|
|
+func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
|
|
|
+ return s.db.BeginTx(ctx, opts)
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
|
|
|
|
|
+func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
if ok {
|
|
if ok {
|
|
return tx.QueryContext(ctx, query, args...)
|
|
return tx.QueryContext(ctx, query, args...)
|
|
}
|
|
}
|
|
|
|
|
|
- return db.db.QueryContext(ctx, query, args...)
|
|
|
|
|
|
+ return s.db.QueryContext(ctx, query, args...)
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
|
|
|
|
|
+func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
if ok {
|
|
if ok {
|
|
return tx.QueryRowContext(ctx, query, args...)
|
|
return tx.QueryRowContext(ctx, query, args...)
|
|
}
|
|
}
|
|
|
|
|
|
- return db.db.QueryRowContext(ctx, query, args...)
|
|
|
|
|
|
+ return s.db.QueryRowContext(ctx, query, args...)
|
|
}
|
|
}
|
|
|
|
|
|
-func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
|
|
|
|
+func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
|
|
if ok {
|
|
if ok {
|
|
return tx.ExecContext(ctx, query, args...)
|
|
return tx.ExecContext(ctx, query, args...)
|
|
}
|
|
}
|
|
|
|
|
|
- return db.db.ExecContext(ctx, query, args...)
|
|
|
|
|
|
+ return s.db.ExecContext(ctx, query, args...)
|
|
}
|
|
}
|