123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- package db
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "strconv"
- "github.com/jmoiron/sqlx"
- )
- type TxKey struct{}
- type DB struct {
- db *sqlx.DB
- }
- func NewDB(c Config) (*DB, error) {
- if len(c.username) == 0 {
- return nil, errors.New("empty username")
- }
- if len(c.password) == 0 {
- return nil, errors.New("empty password")
- }
- if len(c.dbname) == 0 {
- return nil, errors.New("empty database name")
- }
- if len(c.driver) == 0 {
- c.driver = defaultDriver
- }
- if len(c.host) == 0 {
- c.host = defaultHost
- }
- if c.port == 0 {
- c.port = defaultPort
- }
- var source string
- switch c.driver {
- case "mysql":
- source = c.username +
- ":" + c.password +
- "@tcp(" + c.host + ":" + strconv.Itoa(int(c.port)) + ")/" +
- c.dbname +
- "?parseTime=true"
- case "postgres":
- source = "user=" + c.username +
- " password=" + c.password +
- " dbname=/" + c.dbname +
- " host=" + c.host +
- " port=" + strconv.Itoa(int(c.port))
- }
- sqlConn, err := sql.Open(c.driver, source)
- if err != nil {
- return nil, fmt.Errorf("open DB connection error: %w", err)
- }
- db := sqlx.NewDb(sqlConn, c.driver)
- if c.maxOpenConns > 0 {
- db.SetMaxOpenConns(int(c.maxOpenConns))
- }
- if c.maxOpenConns > 0 {
- db.SetMaxIdleConns(int(c.maxIdleConns))
- }
- if c.maxOpenConnLifetime != nil {
- db.SetConnMaxLifetime(*c.maxOpenConnLifetime)
- }
- if c.maxIdleConnLifetime != nil {
- db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
- }
- if err = db.Ping(); err != nil {
- return nil, fmt.Errorf("DB ping error: %w", err)
- }
- return &DB{db: db}, nil
- }
- func (s *DB) Close() error {
- return s.db.Close()
- }
- func (s *DB) Ping() error {
- return s.db.Ping()
- }
- func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) {
- return s.db.BeginTxx(ctx, opts)
- }
- func (s *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
- tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
- if ok {
- return tx.SelectContext(ctx, dest, query, args...)
- }
- return s.db.SelectContext(ctx, dest, query, args...)
- }
- func (s *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
- tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
- if ok {
- return tx.GetContext(ctx, dest, query, args...)
- }
- return s.db.GetContext(ctx, dest, query, args...)
- }
- func (s *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
- tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
- if ok {
- return tx.ExecContext(ctx, query, args...)
- }
- return s.db.ExecContext(ctx, query, args...)
- }
- func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
- tx, ok := ctx.Value(TxKey{}).(*sqlx.Tx)
- if ok {
- return tx.QueryRowContext(ctx, query, args...)
- }
- return s.db.QueryRowContext(ctx, query, args...)
- }
|