db.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "github.com/jmoiron/sqlx"
  9. )
  // TxKey is the context key type under which an active *sqlx.Tx is looked up.
  // The query methods on DB call ctx.Value(TxKey{}) and, when the stored value
  // is a *sqlx.Tx, run on that transaction instead of the underlying handle.
  10. type TxKey struct{}
  // DB wraps a *sqlx.DB and exposes query helpers that transparently use a
  // transaction found in the context (keyed by TxKey{}) when one is present.
  11. type DB struct {
  12. db *sqlx.DB // underlying sqlx database handle
  13. }
  14. func NewDB(c Config) (*DB, error) {
  15. if len(c.username) == 0 {
  16. return nil, errors.New("empty username")
  17. }
  18. if len(c.password) == 0 {
  19. return nil, errors.New("empty password")
  20. }
  21. if len(c.dbname) == 0 {
  22. return nil, errors.New("empty database name")
  23. }
  24. if len(c.driver) == 0 {
  25. c.driver = defaultDriver
  26. }
  27. if len(c.host) == 0 {
  28. c.host = defaultHost
  29. }
  30. if c.port == 0 {
  31. c.port = defaultPort
  32. }
  33. if len(c.sslMode) == 0 {
  34. c.sslMode = defaultSslMode
  35. }
  36. source := "user=" + c.username +
  37. " password=" + c.password +
  38. " dbname=" + c.dbname +
  39. " host=" + c.host +
  40. " port=" + strconv.Itoa(int(c.port)) +
  41. " sslmode=" + c.sslMode
  42. sqlConn, err := sql.Open(c.driver, source)
  43. if err != nil {
  44. return nil, fmt.Errorf("open DB connection error: %w", err)
  45. }
  46. db := sqlx.NewDb(sqlConn, c.driver)
  47. if c.maxOpenConns > 0 {
  48. db.SetMaxOpenConns(int(c.maxOpenConns))
  49. }
  50. if c.maxOpenConns > 0 {
  51. db.SetMaxIdleConns(int(c.maxIdleConns))
  52. }
  53. if c.maxOpenConnLifetime != nil {
  54. db.SetConnMaxLifetime(*c.maxOpenConnLifetime)
  55. }
  56. if c.maxIdleConnLifetime != nil {
  57. db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
  58. }
  59. if err = db.Ping(); err != nil {
  60. return nil, fmt.Errorf("DB ping error: %w", err)
  61. }
  62. return &DB{db: db}, nil
  63. }
  64. func (s *DB) Close() error {
  65. return s.db.Close()
  66. }
  67. func (s *DB) Ping() error {
  68. return s.db.Ping()
  69. }
  70. func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
  71. return s.db.BeginTxx(ctx, opts)
  72. }
  73. func (s *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
  74. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  75. if ok {
  76. return tx.SelectContext(ctx, dest, query, args...)
  77. }
  78. return s.db.SelectContext(ctx, dest, query, args...)
  79. }
  80. func (s *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
  81. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  82. if ok {
  83. return tx.GetContext(ctx, dest, query, args...)
  84. }
  85. return s.db.GetContext(ctx, dest, query, args...)
  86. }
  87. func (s *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
  88. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  89. if ok {
  90. return tx.ExecContext(ctx, query, args...)
  91. }
  92. return s.db.ExecContext(ctx, query, args...)
  93. }
  94. func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  95. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  96. if ok {
  97. return tx.QueryRowContext(ctx, query, args...)
  98. }
  99. return s.db.QueryRowContext(ctx, query, args...)
  100. }