db.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "github.com/jmoiron/sqlx"
  9. )
  10. type TxKey struct{}
  11. type DB struct {
  12. db *sqlx.DB
  13. }
  14. func NewDB(c Config) (*DB, error) {
  15. if len(c.username) == 0 {
  16. return nil, errors.New("empty username")
  17. }
  18. if len(c.password) == 0 {
  19. return nil, errors.New("empty password")
  20. }
  21. if len(c.dbname) == 0 {
  22. return nil, errors.New("empty database name")
  23. }
  24. if len(c.driver) == 0 {
  25. c.driver = defaultDriver
  26. }
  27. if len(c.host) == 0 {
  28. c.host = defaultHost
  29. }
  30. if c.port == 0 {
  31. c.port = defaultPort
  32. }
  33. if len(c.sslMode) == 0 {
  34. c.sslMode = defaultSslMode
  35. }
  36. source := "user=" + c.username +
  37. " password=" + c.password +
  38. " dbname=" + c.dbname +
  39. " host=" + c.host +
  40. " port=" + strconv.Itoa(int(c.port)) +
  41. " sslmode=" + c.sslMode
  42. sqlConn, err := sql.Open(c.driver, source)
  43. if err != nil {
  44. return nil, fmt.Errorf("open DB connection error: %w", err)
  45. }
  46. db := sqlx.NewDb(sqlConn, c.driver)
  47. if c.maxOpenConns > 0 {
  48. db.SetMaxOpenConns(int(c.maxOpenConns))
  49. }
  50. if c.maxOpenConns > 0 {
  51. db.SetMaxIdleConns(int(c.maxIdleConns))
  52. }
  53. if c.maxOpenConnLifetime != nil {
  54. db.SetConnMaxLifetime(*c.maxOpenConnLifetime)
  55. }
  56. if c.maxIdleConnLifetime != nil {
  57. db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
  58. }
  59. if err = db.Ping(); err != nil {
  60. return nil, fmt.Errorf("DB ping error: %w", err)
  61. }
  62. return &DB{db: db}, nil
  63. }
  64. func (s *DB) Close() error {
  65. return s.db.Close()
  66. }
  67. func (s *DB) Ping() error {
  68. return s.db.Ping()
  69. }
  70. func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
  71. return s.db.BeginTxx(ctx, opts)
  72. }
  73. func (s *DB) SelectContext(
  74. ctx context.Context,
  75. dest interface{},
  76. query string,
  77. args ...interface{},
  78. ) error {
  79. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  80. if ok {
  81. return tx.SelectContext(ctx, dest, query, args...)
  82. }
  83. return s.db.SelectContext(ctx, dest, query, args...)
  84. }
  85. func (s *DB) GetContext(
  86. ctx context.Context,
  87. dest interface{},
  88. query string,
  89. args ...interface{},
  90. ) error {
  91. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  92. if ok {
  93. return tx.GetContext(ctx, dest, query, args...)
  94. }
  95. return s.db.GetContext(ctx, dest, query, args...)
  96. }
  97. func (s *DB) ExecContext(
  98. ctx context.Context,
  99. query string,
  100. args ...interface{},
  101. ) (sql.Result, error) {
  102. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  103. if ok {
  104. return tx.ExecContext(ctx, query, args...)
  105. }
  106. return s.db.ExecContext(ctx, query, args...)
  107. }
  108. func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  109. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  110. if ok {
  111. return tx.QueryRowContext(ctx, query, args...)
  112. }
  113. return s.db.QueryRowContext(ctx, query, args...)
  114. }