1
0

db.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. )
  type (
	// TxKey is the context key under which QueryContext, QueryRowContext and
	// ExecContext look for a *sql.Tx. Storing a transaction with
	// context.WithValue(ctx, TxKey{}, tx) makes those calls run on it.
	TxKey struct{}

	// DB wraps a *sql.DB and, for its query/exec helpers, prefers a
	// transaction carried in the context under TxKey{} when one is present.
	DB struct {
		db *sql.DB
	}
)
  13. func NewDB(c Config) (*DB, error) {
  14. if len(c.username) == 0 {
  15. return nil, errors.New("empty username")
  16. }
  17. if len(c.password) == 0 {
  18. return nil, errors.New("empty password")
  19. }
  20. if len(c.dbname) == 0 {
  21. return nil, errors.New("empty database name")
  22. }
  23. if len(c.driver) == 0 {
  24. c.driver = defaultDriver
  25. }
  26. if len(c.host) == 0 {
  27. c.host = defaultHost
  28. }
  29. if c.port == 0 {
  30. c.port = defaultPort
  31. }
  32. if len(c.sslMode) == 0 {
  33. c.sslMode = defaultSslMode
  34. }
  35. source := "user=" + c.username +
  36. " password=" + c.password +
  37. " dbname=" + c.dbname +
  38. " host=" + c.host +
  39. " port=" + strconv.Itoa(int(c.port)) +
  40. " sslmode=" + c.sslMode
  41. db, err := sql.Open(c.driver, source)
  42. if err != nil {
  43. return nil, fmt.Errorf("open DB connection error: %w", err)
  44. }
  45. if c.maxOpenConns > 0 {
  46. db.SetMaxOpenConns(int(c.maxOpenConns))
  47. }
  48. if c.maxOpenConns > 0 {
  49. db.SetMaxIdleConns(int(c.maxIdleConns))
  50. }
  51. if c.maxConnLifetime != nil {
  52. db.SetConnMaxLifetime(*c.maxConnLifetime)
  53. }
  54. if c.maxIdleConnLifetime != nil {
  55. db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
  56. }
  57. if err = db.Ping(); err != nil {
  58. return nil, fmt.Errorf("DB ping error: %w", err)
  59. }
  60. return &DB{db: db}, nil
  61. }
  62. func (s *DB) Ping() error {
  63. return s.db.Ping()
  64. }
  65. func (s *DB) Close() error {
  66. return s.db.Close()
  67. }
  68. func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
  69. return s.db.BeginTx(ctx, opts)
  70. }
  71. func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  72. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  73. if ok {
  74. return tx.QueryContext(ctx, query, args...)
  75. }
  76. return s.db.QueryContext(ctx, query, args...)
  77. }
  78. func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  79. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  80. if ok {
  81. return tx.QueryRowContext(ctx, query, args...)
  82. }
  83. return s.db.QueryRowContext(ctx, query, args...)
  84. }
  85. func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
  86. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  87. if ok {
  88. return tx.ExecContext(ctx, query, args...)
  89. }
  90. return s.db.ExecContext(ctx, query, args...)
  91. }