thing.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. package repositories
  2. //go:generate mkdir -p mocks
  3. //go:generate rm -rf ./mocks/*_minimock.go
  4. //go:generate minimock -i git.dmitriygnatenko.ru/dima/homethings/internal/interfaces.ThingRepository -o ./mocks/ -s "_minimock.go"
  5. import (
  6. "context"
  7. "database/sql"
  8. "errors"
  9. "fmt"
  10. "git.dmitriygnatenko.ru/dima/homethings/internal/interfaces"
  11. "git.dmitriygnatenko.ru/dima/homethings/internal/models"
  12. sq "github.com/Masterminds/squirrel"
  13. )
  14. const (
  15. thingTableName = "thing"
  16. )
  17. type thingRepository struct {
  18. db *sql.DB
  19. }
  20. func InitThingRepository(db *sql.DB) interfaces.ThingRepository {
  21. return thingRepository{db: db}
  22. }
  23. func (r thingRepository) BeginTx(ctx context.Context, level sql.IsolationLevel) (*sql.Tx, error) {
  24. return r.db.BeginTx(ctx, &sql.TxOptions{Isolation: level})
  25. }
  26. func (r thingRepository) CommitTx(tx *sql.Tx) error {
  27. if tx == nil {
  28. return errors.New("empty transaction")
  29. }
  30. return tx.Commit()
  31. }
  32. func (r thingRepository) Get(ctx context.Context, thingID int) (*models.Thing, error) {
  33. query, args, err := sq.Select("t.id", "t.title", "t.description", "t.created_at", "t.updated_at", "p.place_id").
  34. From(thingTableName + " t").
  35. Join(placeThingTableName + " p ON p.thing_id = t.id").
  36. PlaceholderFormat(sq.Dollar).
  37. Where(sq.Eq{"id": thingID}).
  38. ToSql()
  39. if err != nil {
  40. return nil, err
  41. }
  42. var res models.Thing
  43. err = r.db.QueryRowContext(ctx, query, args...).
  44. Scan(&res.ID, &res.Title, &res.Description, &res.CreatedAt, &res.UpdatedAt, &res.PlaceID)
  45. if err != nil {
  46. return nil, err
  47. }
  48. return &res, nil
  49. }
  50. func (r thingRepository) Search(ctx context.Context, search string) ([]models.Thing, error) {
  51. var res []models.Thing
  52. s := fmt.Sprint("%", search, "%")
  53. query, args, err := sq.Select("t.id", "t.title", "t.description", "t.created_at", "t.updated_at", "p.place_id").
  54. From(thingTableName+" t").
  55. Join(placeThingTableName+" p ON p.thing_id = t.id").
  56. PlaceholderFormat(sq.Dollar).
  57. Where("t.title ILIKE ? OR t.description ILIKE ?", s, s).
  58. OrderBy("t.updated_at DESC").
  59. ToSql()
  60. if err != nil {
  61. return nil, err
  62. }
  63. rows, err := r.db.QueryContext(ctx, query, args...)
  64. if err != nil {
  65. return nil, err
  66. }
  67. defer rows.Close()
  68. for rows.Next() {
  69. resRow := models.Thing{}
  70. err = rows.Scan(
  71. &resRow.ID,
  72. &resRow.Title,
  73. &resRow.Description,
  74. &resRow.CreatedAt,
  75. &resRow.UpdatedAt,
  76. &resRow.PlaceID,
  77. )
  78. if err != nil {
  79. return nil, err
  80. }
  81. res = append(res, resRow)
  82. }
  83. if err = rows.Err(); err != nil {
  84. return nil, err
  85. }
  86. return res, nil
  87. }
  88. func (r thingRepository) GetByPlaceID(ctx context.Context, placeID int) ([]models.Thing, error) {
  89. var res []models.Thing
  90. query, args, err := sq.Select("t.id", "t.title", "t.description", "t.created_at", "t.updated_at", "p.place_id").
  91. From(thingTableName + " t").
  92. Join(placeThingTableName + " p ON p.thing_id = t.id").
  93. PlaceholderFormat(sq.Dollar).
  94. Where(sq.Eq{"p.place_id": placeID}).
  95. OrderBy("t.updated_at DESC").
  96. ToSql()
  97. if err != nil {
  98. return nil, err
  99. }
  100. rows, err := r.db.QueryContext(ctx, query, args...)
  101. if err != nil {
  102. return nil, err
  103. }
  104. defer rows.Close()
  105. for rows.Next() {
  106. resRow := models.Thing{}
  107. err = rows.Scan(
  108. &resRow.ID,
  109. &resRow.Title,
  110. &resRow.Description,
  111. &resRow.CreatedAt,
  112. &resRow.UpdatedAt,
  113. &resRow.PlaceID,
  114. )
  115. if err != nil {
  116. return nil, err
  117. }
  118. res = append(res, resRow)
  119. }
  120. if err = rows.Err(); err != nil {
  121. return nil, err
  122. }
  123. return res, nil
  124. }
  125. // GetAllByPlaceID return things by place ID and all child places
  126. func (r thingRepository) GetAllByPlaceID(ctx context.Context, placeID int) ([]models.Thing, error) {
  127. var res []models.Thing
  128. query := "WITH RECURSIVE cte (id, parent_id) AS (" +
  129. "SELECT id, parent_id " +
  130. "FROM " + placeTableName + " " +
  131. "WHERE id = $1 " +
  132. "UNION ALL " +
  133. "SELECT p.id, p.parent_id " +
  134. "FROM " + placeTableName + " p " +
  135. "INNER JOIN cte ON p.parent_id = cte.id " +
  136. ")" +
  137. "SELECT t.id, t.title, t.description, t.created_at, t.updated_at, pt.place_id " +
  138. "FROM cte, " + placeThingTableName + " pt, " + thingTableName + " t " +
  139. "WHERE pt.place_id = cte.id and t.id = pt.thing_id " +
  140. "ORDER BY t.updated_at DESC"
  141. rows, err := r.db.QueryContext(ctx, query, placeID)
  142. if err != nil {
  143. return nil, err
  144. }
  145. defer rows.Close()
  146. for rows.Next() {
  147. resRow := models.Thing{}
  148. err = rows.Scan(
  149. &resRow.ID,
  150. &resRow.Title,
  151. &resRow.Description,
  152. &resRow.CreatedAt,
  153. &resRow.UpdatedAt,
  154. &resRow.PlaceID,
  155. )
  156. if err != nil {
  157. return nil, err
  158. }
  159. res = append(res, resRow)
  160. }
  161. if err = rows.Err(); err != nil {
  162. return nil, err
  163. }
  164. return res, nil
  165. }
  166. func (r thingRepository) Add(ctx context.Context, req models.AddThingRequest, tx *sql.Tx) (int, error) {
  167. query, args, err := sq.Insert(thingTableName).
  168. PlaceholderFormat(sq.Dollar).
  169. Columns("title", "description").
  170. Values(req.Title, req.Description).
  171. Suffix("RETURNING id").
  172. ToSql()
  173. if err != nil {
  174. return 0, err
  175. }
  176. var id int
  177. if tx == nil {
  178. err = r.db.QueryRowContext(ctx, query, args...).Scan(&id)
  179. } else {
  180. err = tx.QueryRowContext(ctx, query, args...).Scan(&id)
  181. }
  182. if err != nil {
  183. return 0, err
  184. }
  185. return id, nil
  186. }
  187. func (r thingRepository) Update(ctx context.Context, req models.UpdateThingRequest, tx *sql.Tx) error {
  188. query, args, err := sq.Update(thingTableName).
  189. PlaceholderFormat(sq.Dollar).
  190. Set("title", req.Title).
  191. Set("description", req.Description).
  192. Set("updated_at", "NOW()").
  193. Where(sq.Eq{"id": req.ID}).
  194. ToSql()
  195. if err != nil {
  196. return err
  197. }
  198. if tx == nil {
  199. _, err = r.db.ExecContext(ctx, query, args...)
  200. } else {
  201. _, err = tx.ExecContext(ctx, query, args...)
  202. }
  203. return err
  204. }
  205. func (r thingRepository) Delete(ctx context.Context, thingID int, tx *sql.Tx) error {
  206. query, args, err := sq.Delete(thingTableName).
  207. PlaceholderFormat(sq.Dollar).
  208. Where(sq.Eq{"id": thingID}).
  209. ToSql()
  210. if err != nil {
  211. return err
  212. }
  213. if tx == nil {
  214. _, err = r.db.ExecContext(ctx, query, args...)
  215. } else {
  216. _, err = tx.ExecContext(ctx, query, args...)
  217. }
  218. return err
  219. }