forked from advania/pass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb.go
151 lines (120 loc) · 3.14 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package pass
import (
"database/sql"
"net/http"
"sync"
)
// database is the storage backend used by this package. get and create each
// take a sitePassword alongside their main argument; how it is used is up to
// the implementation.
type database interface {
	// ping reports whether the backend is reachable.
	ping() (err error)
	// get returns the password stored under uuid.
	get(uuid string, sitePassword string) (password string, err error)
	// create stores password and returns the uuid it was stored under.
	create(password string, sitePassword string) (uuid string, err error)
	// close releases the backend's resources.
	close()
}

// mustDatabaseConnect is a database that opens its underlying connection
// lazily, on first use, by calling connect. A failed connect is never cached,
// so the next call simply tries again; close drops the current handle so the
// call after it reconnects.
//
// dbHandle is only read or written while mutex is held.
type mustDatabaseConnect struct {
	dbHandle database
	mutex    sync.Mutex
	connect  func() (db database, err error)
}

// currentHandle returns the open connection, calling connect first if there
// is none. The handle is fetched under the mutex and returned as a local
// copy, so callers never dereference dbHandle while a concurrent close is
// resetting it.
func (nfp *mustDatabaseConnect) currentHandle() (database, error) {
	nfp.mutex.Lock()
	defer nfp.mutex.Unlock()
	if nfp.dbHandle == nil {
		db, err := nfp.connect()
		if err != nil {
			// Do not keep a handle that came back together with an error;
			// caching it would make every later call skip reconnecting.
			return nil, err
		}
		nfp.dbHandle = db
	}
	return nfp.dbHandle, nil
}

// doConnect makes sure a connection is open, connecting if necessary, and
// returns any error reported by connect.
func (nfp *mustDatabaseConnect) doConnect() (err error) {
	_, err = nfp.currentHandle()
	return err
}

// ping connects if needed and then pings the underlying database.
func (nfp *mustDatabaseConnect) ping() (err error) {
	db, err := nfp.currentHandle()
	if err != nil {
		return err
	}
	return db.ping()
}

// get connects if needed and then delegates to the underlying database.
func (nfp *mustDatabaseConnect) get(uuid string, sitePassword string) (password string, err error) {
	db, err := nfp.currentHandle()
	if err != nil {
		return "", err
	}
	return db.get(uuid, sitePassword)
}

// create connects if needed and then delegates to the underlying database.
func (nfp *mustDatabaseConnect) create(password string, sitePassword string) (uuid string, err error) {
	db, err := nfp.currentHandle()
	if err != nil {
		return "", err
	}
	return db.create(password, sitePassword)
}

// close closes the current connection, if any, and forgets it so that the
// next call reconnects. Calling close with no open connection does nothing.
func (nfp *mustDatabaseConnect) close() {
	nfp.mutex.Lock()
	defer nfp.mutex.Unlock()
	if nfp.dbHandle == nil {
		return
	}
	nfp.dbHandle.close()
	nfp.dbHandle = nil
}

// newMustDatabaseConnect returns a lazily connecting database; connect is not
// called until the first operation.
func newMustDatabaseConnect(connect func() (db database, err error)) (dbc *mustDatabaseConnect) {
	return &mustDatabaseConnect{
		dbHandle: nil,
		connect:  connect,
	}
}
// databaseConnection is the database/sql-backed implementation of database.
// It holds the connection pool plus one prepared statement per operation.
type databaseConnection struct {
	db         *sql.DB
	createStmt *sql.Stmt
	getStmt    *sql.Stmt
}

// ping checks that the database answers.
func (pp *databaseConnection) ping() (err error) {
	return pp.db.Ping()
}

// get runs the prepared get statement and returns its single string result.
// An empty result or a NULL value yields "" with a nil error.
func (pp *databaseConnection) get(uuid string, sitePassword string) (password string, err error) {
	return pp.queryString(pp.getStmt, uuid, sitePassword)
}

// create runs the prepared create statement and returns its single string
// result. An empty result or a NULL value yields "" with a nil error.
func (pp *databaseConnection) create(password string, sitePassword string) (uuid string, err error) {
	return pp.queryString(pp.createStmt, password, sitePassword)
}

// queryString runs stmt with args and returns the first column of the first
// row as a string. No row, or a NULL in that column, gives "" and a nil
// error. Failures while executing the query, scanning the value, or
// iterating the rows are returned as HTTP 500 errors rather than dropped.
func (pp *databaseConnection) queryString(stmt *sql.Stmt, args ...interface{}) (string, error) {
	rows, err := stmt.Query(args...)
	if err != nil {
		return "", NewHTTPError(http.StatusInternalServerError, err)
	}
	defer rows.Close()

	// NullString keeps a NULL result mapping to "" instead of a scan error.
	var value sql.NullString
	if rows.Next() {
		if err := rows.Scan(&value); err != nil {
			return "", NewHTTPError(http.StatusInternalServerError, err)
		}
	}
	// A false Next alone cannot distinguish "no rows" from "the query failed
	// mid-stream"; Err reports the latter.
	if err := rows.Err(); err != nil {
		return "", NewHTTPError(http.StatusInternalServerError, err)
	}
	return value.String, nil
}

// close releases both prepared statements and then the pool. It does nothing
// on a nil receiver. Close errors are discarded because close has no error
// result.
func (pp *databaseConnection) close() {
	if pp == nil {
		return
	}
	pp.getStmt.Close()
	pp.createStmt.Close()
	pp.db.Close()
}
// newDatabaseConnectionVariables opens the "postgres" database named by
// sv.cfg.PDOString, pings it, sets the session isolation level, and prepares
// the create_password/get_password statements. If any step fails, everything
// opened so far is closed again and a nil database is returned with the
// error.
func newDatabaseConnectionVariables(sv *serverVariables) (dbConn database, err error) {
	db, err := sql.Open("postgres", sv.cfg.PDOString)
	if err != nil {
		return nil, err
	}
	if err = db.Ping(); err != nil {
		db.Close()
		return nil, err
	}
	// NOTE(review): *sql.DB is a connection pool and Exec runs on a single
	// pooled connection, so this SET SESSION only changes the default for
	// that one connection; other connections in the pool keep the server
	// default. If serializable isolation is required for every query,
	// consider per-call BeginTx with sql.LevelSerializable instead.
	if _, err = db.Exec("set session characteristics as transaction isolation level serializable"); err != nil {
		db.Close()
		return nil, err
	}
	createStmt, err := db.Prepare("select * from create_password($1, $2)")
	if err != nil {
		db.Close()
		return nil, err
	}
	getStmt, err := db.Prepare("select * from get_password($1, $2)")
	if err != nil {
		createStmt.Close()
		db.Close()
		return nil, err
	}
	return &databaseConnection{
		db:         db,
		createStmt: createStmt,
		getStmt:    getStmt,
	}, nil
}