db.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "github.com/jmoiron/sqlx"
  9. )
  10. type TxKey struct{}
  11. type DB struct {
  12. db *sqlx.DB
  13. }
  14. func NewDB(c Config) (*DB, error) {
  15. if len(c.username) == 0 {
  16. return nil, errors.New("empty username")
  17. }
  18. if len(c.password) == 0 {
  19. return nil, errors.New("empty password")
  20. }
  21. if len(c.dbname) == 0 {
  22. return nil, errors.New("empty database name")
  23. }
  24. if len(c.driver) == 0 {
  25. c.driver = defaultDriver
  26. }
  27. if len(c.host) == 0 {
  28. c.host = defaultHost
  29. }
  30. if c.port == 0 {
  31. c.port = defaultPort
  32. }
  33. var source string
  34. switch c.driver {
  35. case "mysql":
  36. source = c.username +
  37. ":" + c.password +
  38. "@tcp(" + c.host + ":" + strconv.Itoa(int(c.port)) + ")/" +
  39. c.dbname +
  40. "?parseTime=true"
  41. case "postgres":
  42. source = "user=" + c.username +
  43. " password=" + c.password +
  44. " dbname=/" + c.dbname +
  45. " host=" + c.host +
  46. " port=" + strconv.Itoa(int(c.port))
  47. }
  48. sqlConn, err := sql.Open(c.driver, source)
  49. if err != nil {
  50. return nil, fmt.Errorf("open DB connection error: %w", err)
  51. }
  52. db := sqlx.NewDb(sqlConn, c.driver)
  53. if c.maxOpenConns > 0 {
  54. db.SetMaxOpenConns(int(c.maxOpenConns))
  55. }
  56. if c.maxOpenConns > 0 {
  57. db.SetMaxIdleConns(int(c.maxIdleConns))
  58. }
  59. if c.maxOpenConnLifetime != nil {
  60. db.SetConnMaxLifetime(*c.maxOpenConnLifetime)
  61. }
  62. if c.maxIdleConnLifetime != nil {
  63. db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
  64. }
  65. if err = db.Ping(); err != nil {
  66. return nil, fmt.Errorf("DB ping error: %w", err)
  67. }
  68. return &DB{db: db}, nil
  69. }
  70. func (s *DB) Close() error {
  71. return s.db.Close()
  72. }
  73. func (s *DB) Ping() error {
  74. return s.db.Ping()
  75. }
  76. func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
  77. return s.db.BeginTxx(ctx, opts)
  78. }
  79. func (s *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
  80. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  81. if ok {
  82. return tx.SelectContext(ctx, dest, query, args...)
  83. }
  84. return s.db.SelectContext(ctx, dest, query, args...)
  85. }
  86. func (s *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
  87. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  88. if ok {
  89. return tx.GetContext(ctx, dest, query, args...)
  90. }
  91. return s.db.GetContext(ctx, dest, query, args...)
  92. }
  93. func (s *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
  94. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  95. if ok {
  96. return tx.ExecContext(ctx, query, args...)
  97. }
  98. return s.db.ExecContext(ctx, query, args...)
  99. }
  100. func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
  101. tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
  102. if ok {
  103. return tx.QueryRowContext(ctx, query, args...)
  104. }
  105. return s.db.QueryRowContext(ctx, query, args...)
  106. }