-
Notifications
You must be signed in to change notification settings - Fork 413
/
repository.go
136 lines (113 loc) · 2.77 KB
/
repository.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package main
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"time"
)
// migration creates the posts table if it does not exist and seeds it with
// two demo posts. Re-running it is harmless: the table is only created when
// missing, seed rows that already exist are skipped via ON CONFLICT, and the
// id sequence is moved to the highest id present.
//
// The seed rows use explicit ids, which does not advance the serial sequence
// behind posts.id. Without the setval call, the first INSERT relying on the
// column default would be handed id 1 and fail with a duplicate-key error.
const migration = `
CREATE TABLE IF NOT EXISTS posts (
id serial PRIMARY KEY,
author VARCHAR NOT NULL,
content TEXT NOT NULL,
views INT NOT NULL DEFAULT 0,
reactions JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
INSERT INTO posts (id, author, content) VALUES
(1, 'Miłosz', 'Oh, I remember the days when we used to write code in PHP!'),
(2, 'Robert', 'Back in my days, we used to write code in assembly!')
ON CONFLICT (id) DO NOTHING;
SELECT setval(pg_get_serial_sequence('posts', 'id'), (SELECT MAX(id) FROM posts));
`
// MigrateDB executes the migration SQL against db and reports any error the
// driver returns. The SQL uses CREATE TABLE IF NOT EXISTS and
// ON CONFLICT DO NOTHING, so running it again against an already-migrated
// database does not fail on the existing table or seed rows.
func MigrateDB(db *sql.DB) error {
	if _, execErr := db.Exec(migration); execErr != nil {
		return execErr
	}
	return nil
}
// Repository reads and updates rows of the posts table through a shared
// *sql.DB handle.
type Repository struct {
	db *sql.DB
}

// NewRepository returns a Repository whose queries run against db. The handle
// is stored as-is; the Repository does not open or close it.
func NewRepository(db *sql.DB) *Repository {
	repo := new(Repository)
	repo.db = db
	return repo
}
// PostByID loads the post whose id column equals id. Any error from the query
// or from decoding the row is returned unchanged together with a zero Post;
// in particular, a missing row surfaces as sql.ErrNoRows from Row.Scan.
func (s *Repository) PostByID(ctx context.Context, id int) (Post, error) {
	const query = `SELECT id, author, content, views, reactions, created_at FROM posts WHERE id = $1`
	return scanPost(s.db.QueryRowContext(ctx, query, id))
}
// AllPosts returns every post ordered by ascending id. The slice is nil when
// the table is empty. If any row fails to scan or decode, or the result set
// reports an iteration error, no posts are returned and the error is passed
// through unchanged.
func (s *Repository) AllPosts(ctx context.Context) ([]Post, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT id, author, content, views, reactions, created_at FROM posts ORDER BY id ASC`)
	if err != nil {
		return nil, err
	}
	// Rows holds a pooled connection until closed. Next closes it on normal
	// exhaustion, but the early return on a scan error below would otherwise
	// leak it.
	defer rows.Close()

	var posts []Post
	for rows.Next() {
		post, err := scanPost(rows)
		if err != nil {
			return nil, err
		}
		posts = append(posts, post)
	}
	// Next returns false both at the end of the result set and when iteration
	// fails; Err tells the two apart so a broken read is not reported as a
	// short but successful listing.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return posts, nil
}
// UpdatePost loads the post with the given id, passes it to updateFn for
// modification, and writes the resulting Views and Reactions back, all
// within a single transaction. Other fields changed by updateFn (including
// ID) are not persisted.
//
// The row is read with SELECT ... FOR UPDATE on the same transaction that
// performs the UPDATE, so concurrent UpdatePost calls for the same id wait
// for each other instead of overwriting each other's changes.
//
// The transaction is committed when every step succeeds and rolled back
// otherwise. A failed rollback is reported alongside the original error,
// which stays reachable via errors.Is/As. If updateFn panics, the
// transaction is rolled back and the panic continues.
func (s *Repository) UpdatePost(ctx context.Context, id int, updateFn func(post *Post)) (err error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort: the panic is the failure being reported
			panic(p)
		}
		if err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				err = fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
			}
			return
		}
		err = tx.Commit()
	}()

	// The locking read must run on tx. Running it on s.db would take the row
	// lock in a separate, immediately committed transaction, leaving the
	// UPDATE below unprotected against concurrent writers.
	row := tx.QueryRowContext(ctx, `SELECT id, author, content, views, reactions, created_at FROM posts WHERE id = $1 FOR UPDATE`, id)
	post, err := scanPost(row)
	if err != nil {
		return err
	}

	updateFn(&post)

	reactionsJSON, err := json.Marshal(post.Reactions)
	if err != nil {
		return err
	}

	// Target the row that was locked (id), not post.ID, which updateFn could
	// have modified.
	_, err = tx.ExecContext(ctx, `UPDATE posts SET views = $1, reactions = $2 WHERE id = $3`, post.Views, reactionsJSON, id)
	return err
}
// scanner is the single method shared by *sql.Row and *sql.Rows. It lets
// scanPost decode a row from either a single-row or a multi-row query.
type scanner interface {
	Scan(dst ...any) error
}
// scanPost decodes one posts row from s. The row must contain the columns
// id, author, content, views, reactions, created_at in that order. The
// reactions column is decoded as a JSON object mapping names to counts. On
// any scan or JSON error a zero Post and that error are returned.
//
// The returned Post always has a non-nil Reactions map. A JSON null (which
// json.Marshal writes for a nil map) would otherwise decode to nil, and a
// caller adding a reaction would then panic on the nil map.
func scanPost(s scanner) (Post, error) {
	var (
		id, views       int
		author, content string
		reactionsJSON   []byte
		createdAt       time.Time
	)
	if err := s.Scan(&id, &author, &content, &views, &reactionsJSON, &createdAt); err != nil {
		return Post{}, err
	}

	var reactions map[string]int
	if err := json.Unmarshal(reactionsJSON, &reactions); err != nil {
		return Post{}, err
	}
	if reactions == nil {
		reactions = make(map[string]int)
	}

	return Post{
		ID:        id,
		Author:    author,
		Content:   content,
		CreatedAt: createdAt,
		Views:     views,
		Reactions: reactions,
	}, nil
}