db.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strconv"
  7. )
  // Types shared by the transaction-aware query helpers below.
type (
	// TxKey is the context key under which a *sql.Tx may be stored. When
	// ctx.Value(TxKey{}) holds a *sql.Tx, DB.QueryContext, DB.QueryRowContext
	// and DB.ExecContext run on that transaction instead of on the pool.
	TxKey struct{}

	// DB wraps a *sql.DB connection pool and routes queries either to the
	// pool or to a transaction carried in the request context.
	DB struct {
		db *sql.DB
	}
)
  12. func NewDB(c Config) (*DB, error) {
  13. if len(c.driver) == 0 {
  14. c.driver = defaultDriver
  15. }
  16. if len(c.host) == 0 {
  17. c.host = defaultHost
  18. }
  19. if c.port == 0 {
  20. c.port = defaultPort
  21. }
  22. if len(c.sslMode) == 0 {
  23. c.sslMode = defaultSslMode
  24. }
  25. source := "user=" + c.username +
  26. " password=" + c.password +
  27. " dbname=" + c.dbname +
  28. " host=" + c.host +
  29. " port=" + strconv.Itoa(int(c.port)) +
  30. " sslmode=" + c.sslMode
  31. db, err := sql.Open(c.driver, source)
  32. if err != nil {
  33. return nil, fmt.Errorf("open DB connection error: %w", err)
  34. }
  35. if c.maxOpenConns > 0 {
  36. db.SetMaxOpenConns(int(c.maxOpenConns))
  37. }
  38. if c.maxOpenConns > 0 {
  39. db.SetMaxIdleConns(int(c.maxIdleConns))
  40. }
  41. if c.maxConnLifetime != nil {
  42. db.SetConnMaxLifetime(*c.maxConnLifetime)
  43. }
  44. if c.maxIdleConnLifetime != nil {
  45. db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
  46. }
  47. if err = db.Ping(); err != nil {
  48. return nil, fmt.Errorf("DB ping error: %w", err)
  49. }
  50. return &DB{db: db}, nil
  51. }
  52. func (s *DB) Ping() error {
  53. return s.db.Ping()
  54. }
  55. func (s *DB) Close() error {
  56. return s.db.Close()
  57. }
  58. func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
  59. return s.db.BeginTx(ctx, opts)
  60. }
  61. func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
  62. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  63. if ok {
  64. return tx.QueryContext(ctx, query, args...)
  65. }
  66. return s.db.QueryContext(ctx, query, args...)
  67. }
  68. func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  69. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  70. if ok {
  71. return tx.QueryRowContext(ctx, query, args...)
  72. }
  73. return s.db.QueryRowContext(ctx, query, args...)
  74. }
  75. func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
  76. tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
  77. if ok {
  78. return tx.ExecContext(ctx, query, args...)
  79. }
  80. return s.db.ExecContext(ctx, query, args...)
  81. }