1
0

db.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "time"
  7. )
  type (
  	// TxKey is the context key under which QueryContext, QueryRowContext and
  	// ExecContext look for a transaction. When ctx.Value(TxKey{}) holds a
  	// *sql.Tx, those methods run the statement on it; a value of any other
  	// type is ignored and the statement runs on the pool instead.
  	TxKey struct{}

  	// Config supplies the connection parameters and pool settings consumed
  	// by NewDB.
  	Config interface {
  		// Driver is the database/sql driver name handed to sql.Open.
  		Driver() string

  		// Connection-string parameters.
  		User() string
  		Password() string
  		Name() string // database name, emitted as the "dbname" key
  		Host() string
  		Port() string
  		SSLMode() string

  		// Pool settings; NewDB leaves a setting at its database/sql default
  		// when the value is zero or negative.
  		MaxOpenConns() int
  		MaxIdleConns() int
  		MaxConnLifetime() int     // in seconds
  		MaxIdleConnLifetime() int // in seconds
  	}

  	// DB is a thin wrapper around a *sql.DB pool whose query methods
  	// transparently use a transaction carried in the context (see TxKey).
  	DB struct {
  		db *sql.DB
  	}
  )
  25. func NewDB(c Config) (*DB, error) {
  26. source := "user=" + c.User() +
  27. " password=" + c.Password() +
  28. " dbname=" + c.Name() +
  29. " host=" + c.Host() +
  30. " port=" + c.Port() +
  31. " sslmode=" + c.SSLMode()
  32. db, err := sql.Open(c.Driver(), source)
  33. if err != nil {
  34. return nil, fmt.Errorf("open DB connection error: %w", err)
  35. }
  36. if c.MaxOpenConns() > 0 {
  37. db.SetMaxOpenConns(c.MaxOpenConns())
  38. }
  39. if c.MaxIdleConns() > 0 {
  40. db.SetMaxIdleConns(c.MaxIdleConns())
  41. }
  42. if c.MaxConnLifetime() > 0 {
  43. db.SetConnMaxLifetime(time.Second * time.Duration(c.MaxConnLifetime()))
  44. }
  45. if c.MaxIdleConnLifetime() > 0 {
  46. db.SetConnMaxIdleTime(time.Second * time.Duration(c.MaxIdleConnLifetime()))
  47. }
  48. if err = db.Ping(); err != nil {
  49. return nil, fmt.Errorf("DB ping error: %w", err)
  50. }
  51. return &DB{db: db}, nil
  52. }
  53. func (db *DB) Ping() error {
  54. return db.db.Ping()
  55. }
  56. func (db *DB) Close() error {
  57. return db.db.Close()
  58. }
  59. func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
  60. return db.db.BeginTx(ctx, opts)
  61. }
  62. func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  63. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  64. if ok {
  65. return tx.QueryContext(ctx, query, args...)
  66. }
  67. return db.db.QueryContext(ctx, query, args...)
  68. }
  69. func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  70. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  71. if ok {
  72. return tx.QueryRowContext(ctx, query, args...)
  73. }
  74. return db.db.QueryRowContext(ctx, query, args...)
  75. }
  76. func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
  77. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  78. if ok {
  79. return tx.ExecContext(ctx, query, args...)
  80. }
  81. return db.db.ExecContext(ctx, query, args...)
  82. }