ソースを参照

Update DB library

Dmitriy Gnatenko 7 ヶ月 前
コミット
85f437340c
2 ファイル変更136 行追加45 行削除
  1. 45 45
      db/db.go
  2. 91 0
      db/db_config.go

+ 45 - 45
db/db.go

@@ -4,58 +4,58 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"time"
+	"strconv"
 )
 
 type TxKey struct{}
 
-type Config interface {
-	Driver() string
-
-	User() string
-	Password() string
-	Name() string
-	Host() string
-	Port() string
-	SSLMode() string
-
-	MaxOpenConns() int
-	MaxIdleConns() int
-	MaxConnLifetime() int     // in seconds
-	MaxIdleConnLifetime() int // in seconds
-}
-
 type DB struct {
 	db *sql.DB
 }
 
 func NewDB(c Config) (*DB, error) {
-	source := "user=" + c.User() +
-		" password=" + c.Password() +
-		" dbname=" + c.Name() +
-		" host=" + c.Host() +
-		" port=" + c.Port() +
-		" sslmode=" + c.SSLMode()
-
-	db, err := sql.Open(c.Driver(), source)
+	if len(c.driver) == 0 {
+		c.driver = defaultDriver
+	}
+
+	if len(c.host) == 0 {
+		c.host = defaultHost
+	}
+
+	if c.port == 0 {
+		c.port = defaultPort
+	}
+
+	if len(c.sslMode) == 0 {
+		c.sslMode = defaultSslMode
+	}
+
+	source := "user=" + c.username +
+		" password=" + c.password +
+		" dbname=" + c.dbname +
+		" host=" + c.host +
+		" port=" + strconv.Itoa(int(c.port)) +
+		" sslmode=" + c.sslMode
+
+	db, err := sql.Open(c.driver, source)
 	if err != nil {
 		return nil, fmt.Errorf("open DB connection error: %w", err)
 	}
 
-	if c.MaxOpenConns() > 0 {
-		db.SetMaxOpenConns(c.MaxOpenConns())
+	if c.maxOpenConns > 0 {
+		db.SetMaxOpenConns(int(c.maxOpenConns))
 	}
 
-	if c.MaxIdleConns() > 0 {
-		db.SetMaxIdleConns(c.MaxIdleConns())
+	if c.maxOpenConns > 0 {
+		db.SetMaxIdleConns(int(c.maxIdleConns))
 	}
 
-	if c.MaxConnLifetime() > 0 {
-		db.SetConnMaxLifetime(time.Second * time.Duration(c.MaxConnLifetime()))
+	if c.maxConnLifetime != nil {
+		db.SetConnMaxLifetime(*c.maxConnLifetime)
 	}
 
-	if c.MaxIdleConnLifetime() > 0 {
-		db.SetConnMaxIdleTime(time.Second * time.Duration(c.MaxIdleConnLifetime()))
+	if c.maxIdleConnLifetime != nil {
+		db.SetConnMaxIdleTime(*c.maxIdleConnLifetime)
 	}
 
 	if err = db.Ping(); err != nil {
@@ -65,41 +65,41 @@ func NewDB(c Config) (*DB, error) {
 	return &DB{db: db}, nil
 }
 
-func (db *DB) Ping() error {
-	return db.db.Ping()
+func (s *DB) Ping() error {
+	return s.db.Ping()
 }
 
-func (db *DB) Close() error {
-	return db.db.Close()
+func (s *DB) Close() error {
+	return s.db.Close()
 }
 
-func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
-	return db.db.BeginTx(ctx, opts)
+func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+	return s.db.BeginTx(ctx, opts)
 }
 
-func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+func (s *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
 	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
 	if ok {
 		return tx.QueryContext(ctx, query, args...)
 	}
 
-	return db.db.QueryContext(ctx, query, args...)
+	return s.db.QueryContext(ctx, query, args...)
 }
 
-func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+func (s *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
 	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
 	if ok {
 		return tx.QueryRowContext(ctx, query, args...)
 	}
 
-	return db.db.QueryRowContext(ctx, query, args...)
+	return s.db.QueryRowContext(ctx, query, args...)
 }
 
-func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+func (s *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
 	tx, ok := ctx.Value(TxKey{}).(*sql.Tx)
 	if ok {
 		return tx.ExecContext(ctx, query, args...)
 	}
 
-	return db.db.ExecContext(ctx, query, args...)
+	return s.db.ExecContext(ctx, query, args...)
 }

+ 91 - 0
db/db_config.go

@@ -0,0 +1,91 @@
+package db
+
+import (
+	"time"
+)
+
+const (
+	defaultDriver  = "mysql"
+	defaultHost    = "localhost"
+	defaultPort    = 3306
+	defaultSslMode = "disabled"
+)
+
+type Config struct {
+	driver string
+
+	username string
+	password string
+	dbname   string
+	host     string
+	port     uint16
+	sslMode  string
+
+	maxOpenConns uint16
+	maxIdleConns uint16
+
+	maxConnLifetime     *time.Duration
+	maxIdleConnLifetime *time.Duration
+}
+
+type ConfigOption func(*Config)
+
+type ConfigOptions []ConfigOption
+
+func (s *ConfigOptions) Add(option ConfigOption) {
+	*s = append(*s, option)
+}
+
+func WithDriver(driver string) ConfigOption {
+	return func(s *Config) {
+		s.driver = driver
+	}
+}
+
+func WithUsername(username string) ConfigOption {
+	return func(s *Config) {
+		s.username = username
+	}
+}
+
+func WithDatabase(dbname string) ConfigOption {
+	return func(s *Config) {
+		s.dbname = dbname
+	}
+}
+
+func WithPassword(password string) ConfigOption {
+	return func(s *Config) {
+		s.password = password
+	}
+}
+
+func WithHost(host string) ConfigOption {
+	return func(s *Config) {
+		s.host = host
+	}
+}
+
+func WithPort(port uint16) ConfigOption {
+	return func(s *Config) {
+		s.port = port
+	}
+}
+
+func WithSSLMode(sslMode string) ConfigOption {
+	return func(s *Config) {
+		s.sslMode = sslMode
+	}
+}
+
+func WithMaxOpenConns(maxOpenConns uint16) ConfigOption {
+	return func(s *Config) {
+		s.maxOpenConns = maxOpenConns
+	}
+}
+
+func WithMaxIdleConns(maxIdleConns uint16) ConfigOption {
+	return func(s *Config) {
+		s.maxIdleConns = maxIdleConns
+	}
+}