123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- package db
- import (
- "context"
- "database/sql"
- "fmt"
- "time"
- )
// TxKey is the context key under which a *sql.Tx may be stored.
// DB.QueryContext, DB.QueryRowContext and DB.ExecContext look the key up in
// the context they receive and, when a *sql.Tx is found, run the statement
// on that transaction.
type TxKey struct{}
// Config supplies everything NewDB needs to open and tune a connection pool.
type Config interface {
	// Driver is the driver name handed to sql.Open.
	Driver() string

	// The values below are concatenated into the key/value connection
	// string as user, password, dbname, host, port and sslmode.
	User() string
	Password() string
	Name() string
	Host() string
	Port() string
	SSLMode() string

	// Pool limits. NewDB applies a value only when it is greater than zero.
	MaxOpenConns() int
	MaxIdleConns() int
	MaxConnLifetime() int     // in seconds
	MaxIdleConnLifetime() int // in seconds
}
// DB wraps a *sql.DB connection pool. Its query and exec methods run on the
// pool unless the context carries a *sql.Tx under TxKey, in which case they
// run on that transaction.
type DB struct {
	// db is the underlying pool; NewDB sets it.
	db *sql.DB
}
- func NewDB(c Config) (*DB, error) {
- source := "user=" + c.User() +
- " password=" + c.Password() +
- " dbname=" + c.Name() +
- " host=" + c.Host() +
- " port=" + c.Port() +
- " sslmode=" + c.SSLMode()
- db, err := sql.Open(c.Driver(), source)
- if err != nil {
- return nil, fmt.Errorf("open DB connection error: %w", err)
- }
- if c.MaxOpenConns() > 0 {
- db.SetMaxOpenConns(c.MaxOpenConns())
- }
- if c.MaxIdleConns() > 0 {
- db.SetMaxIdleConns(c.MaxIdleConns())
- }
- if c.MaxConnLifetime() > 0 {
- db.SetConnMaxLifetime(time.Second * time.Duration(c.MaxConnLifetime()))
- }
- if c.MaxIdleConnLifetime() > 0 {
- db.SetConnMaxIdleTime(time.Second * time.Duration(c.MaxIdleConnLifetime()))
- }
- if err = db.Ping(); err != nil {
- return nil, fmt.Errorf("DB ping error: %w", err)
- }
- return &DB{db: db}, nil
- }
- func (db *DB) Ping() error {
- return db.db.Ping()
- }
- func (db *DB) Close() error {
- return db.db.Close()
- }
- func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
- return db.db.BeginTx(ctx, opts)
- }
- func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
- if ok {
- return tx.QueryContext(ctx, query, args...)
- }
- return db.db.QueryContext(ctx, query, args...)
- }
- func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
- if ok {
- return tx.QueryRowContext(ctx, query, args...)
- }
- return db.db.QueryRowContext(ctx, query, args...)
- }
- func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
- tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
- if ok {
- return tx.ExecContext(ctx, query, args...)
- }
- return db.db.ExecContext(ctx, query, args...)
- }
|