Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing #5884 #5885

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 88 additions & 27 deletions lib/impure/db_mysql.nim
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ type
## column text on demand
row: cstringArray
len: int

SQLInputKind = enum
SQLSIntKind, SQLUIntKind, SQLFloatKind, SQLStringKind, SQLNullKind

SQLInput = object
case kind*: SQLInputKind
of SQLSIntKind:
num*: BiggestInt
of SQLUIntKind:
unum*: BiggestUInt
of SQLFloatKind:
fnum*: BiggestFloat
of SQLStringKind:
str*: string
of SQLNullKind:
nil

{.deprecated: [TRow: Row, TDbConn: DbConn].}

proc dbError*(db: DbConn) {.noreturn.} =
Expand All @@ -106,7 +123,7 @@ proc dbError*(db: DbConn) {.noreturn.} =
raise e

when false:
proc dbQueryOpt*(db: DbConn, query: string, args: varargs[string, `$`]) =
proc dbQueryOpt*(db: DbConn, query: string, args: varargs[SQLInput, toSQLInput]) =
var stmt = mysql_stmt_init(db)
if stmt == nil: dbError(db)
if mysql_stmt_prepare(stmt, query, len(query)) != 0:
Expand All @@ -115,38 +132,82 @@ when false:
binding: seq[MYSQL_BIND]
discard mysql_stmt_close(stmt)

proc dbQuote*(s: string): string =
## DB quotes the string.
result = "'"
for c in items(s):
if c == '\'': add(result, "''")
else: add(result, c)
add(result, '\'')
proc dbQuote*(s: SQLInput): string =
## Database sanitizes the SQLInput
case s.kind:
of SQLStringKind:
result = newStringOfCap(s.str.len + 2)
result.add "'"
for c in items(s.str):
# Substitution rules from here "https://www.owasp.org/index.php/SQL_Injection_Prevention_Cheat_Sheet"
case c:
of '\0': result.add "\\0"
of '\b': result.add "\\b"
of '\t': result.add "\\t"
of '\l': result.add "\\n"
of '\r': result.add "\\r"
of '\x1a': result.add "\\Z"
of '"': result.add "\\\""
of '%': result.add "\\%"
of '\'': result.add "\\'"
of '\\': result.add "\\\\"
of '_': result.add "\\_"
of Letters+Digits: result.add c
else: result.add "\\" & c
add(result, '\'')
of SQLSIntKind:
result = $s.num
of SQLUIntKind:
result = $s.unum
of SQLFloatKind:
result = $s.fnum
of SQLNullKind:
result = "NULL"

proc toSQLInput*(x: SomeInteger): SQLInput =
when x is SomeSignedInt:
result.kind = SQLSIntKind
result.num = x
when x is SomeUnsignedInt:
result.kind = SQLUIntKind
result.unum = x

proc toSQLInput*(x: bool): SQLInput =
result.kind = SQLUIntKind
result.unum = if x: 1 else: 0

proc toSQLInput*(x: SomeReal): SQLInput =
result.kind = SQLFloatKind
result.fnum = x

proc toSQLInput*(x: string): SQLInput =
if x==nil:
result.kind = SQLNullKind
else:
result.kind = SQLStringKind
result.str = x

proc dbFormat(formatstr: SqlQuery, args: varargs[string]): string =
proc dbFormat(formatstr: SqlQuery, args: varargs[SQLInput, toSQLInput]): string =
result = ""
var a = 0
for c in items(string(formatstr)):
if c == '?':
if args[a] == nil:
add(result, "NULL")
else:
add(result, dbQuote(args[a]))
add(result, dbQuote(args[a]))
inc(a)
else:
add(result, c)

proc tryExec*(db: DbConn, query: SqlQuery, args: varargs[string, `$`]): bool {.
proc tryExec*(db: DbConn, query: SqlQuery, args: varargs[SQLInput, toSQLInput]): bool {.
tags: [ReadDbEffect, WriteDbEffect].} =
## tries to execute the query and returns true if successful, false otherwise.
var q = dbFormat(query, args)
return mysql.realQuery(db, q, q.len) == 0'i32

proc rawExec(db: DbConn, query: SqlQuery, args: varargs[string, `$`]) =
proc rawExec(db: DbConn, query: SqlQuery, args: varargs[SQLInput, toSQLInput]) =
var q = dbFormat(query, args)
if mysql.realQuery(db, q, q.len) != 0'i32: dbError(db)

proc exec*(db: DbConn, query: SqlQuery, args: varargs[string, `$`]) {.
proc exec*(db: DbConn, query: SqlQuery, args: varargs[SQLInput, toSQLInput]) {.
tags: [ReadDbEffect, WriteDbEffect].} =
## executes the query and raises EDB if not successful.
var q = dbFormat(query, args)
Expand All @@ -162,7 +223,7 @@ proc properFreeResult(sqlres: mysql.PRES, row: cstringArray) =
mysql.freeResult(sqlres)

iterator fastRows*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
args: varargs[SQLInput, toSQLInput]): Row {.tags: [ReadDbEffect].} =
## executes the query and iterates over the result dataset.
##
## This is very fast, but potentially dangerous. Use this iterator only
Expand Down Expand Up @@ -205,7 +266,7 @@ iterator fastRows*(db: DbConn, query: SqlQuery,
properFreeResult(sqlres, row)

iterator instantRows*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): InstantRow
args: varargs[SQLInput, toSQLInput]): InstantRow
{.tags: [ReadDbEffect].} =
## Same as fastRows but returns a handle that can be used to get column text
## on demand using []. Returned handle is valid only within the iterator body.
Expand Down Expand Up @@ -286,7 +347,7 @@ proc setColumnInfo(columns: var DbColumns; res: PRES; L: int) =
#columns[i].foreignKey = there is no such thing in mysql

iterator instantRows*(db: DbConn; columns: var DbColumns; query: SqlQuery;
args: varargs[string, `$`]): InstantRow =
args: varargs[SQLInput, toSQLInput]): InstantRow =
## Same as fastRows but returns a handle that can be used to get column text
## on demand using []. Returned handle is valid only within the iterator body.
rawExec(db, query, args)
Expand All @@ -311,7 +372,7 @@ proc len*(row: InstantRow): int {.inline.} =
row.len

proc getRow*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
args: varargs[SQLInput, toSQLInput]): Row {.tags: [ReadDbEffect].} =
## Retrieves a single row. If the query doesn't return any rows, this proc
## will return a Row with empty strings for each column.
rawExec(db, query, args)
Expand All @@ -330,7 +391,7 @@ proc getRow*(db: DbConn, query: SqlQuery,
properFreeResult(sqlres, row)

proc getAllRows*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): seq[Row] {.tags: [ReadDbEffect].} =
args: varargs[SQLInput, toSQLInput]): seq[Row] {.tags: [ReadDbEffect].} =
## executes the query and returns the whole result dataset.
result = @[]
rawExec(db, query, args)
Expand All @@ -353,19 +414,19 @@ proc getAllRows*(db: DbConn, query: SqlQuery,
mysql.freeResult(sqlres)

iterator rows*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
args: varargs[SQLInput, toSQLInput]): Row {.tags: [ReadDbEffect].} =
## same as `fastRows`, but slower and safe.
for r in items(getAllRows(db, query, args)): yield r

proc getValue*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): string {.tags: [ReadDbEffect].} =
args: varargs[SQLInput, toSQLInput]): string {.tags: [ReadDbEffect].} =
## executes the query and returns the first column of the first row of the
## result dataset. Returns "" if the dataset contains no rows or the database
## value is NULL.
result = getRow(db, query, args)[0]

proc tryInsertId*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): int64 {.tags: [WriteDbEffect].} =
args: varargs[SQLInput, toSQLInput]): int64 {.tags: [WriteDbEffect].} =
## executes the query (typically "INSERT") and returns the
## generated ID for the row or -1 in case of an error.
var q = dbFormat(query, args)
Expand All @@ -375,14 +436,14 @@ proc tryInsertId*(db: DbConn, query: SqlQuery,
result = mysql.insertId(db)

proc insertId*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): int64 {.tags: [WriteDbEffect].} =
args: varargs[SQLInput, toSQLInput]): int64 {.tags: [WriteDbEffect].} =
## executes the query (typically "INSERT") and returns the
## generated ID for the row.
result = tryInsertID(db, query, args)
if result < 0: dbError(db)

proc execAffectedRows*(db: DbConn, query: SqlQuery,
args: varargs[string, `$`]): int64 {.
args: varargs[SQLInput, toSQLInput]): int64 {.
tags: [ReadDbEffect, WriteDbEffect].} =
## runs the query (typically "UPDATE") and returns the
## number of affected rows
Expand All @@ -408,7 +469,7 @@ proc open*(connection, user, password, database: string): DbConn {.
if mysql.realConnect(result, host, user, password, database,
port, nil, 0) == nil:
var errmsg = $mysql.error(result)
db_mysql.close(result)
mydb_mysql.close(result)
dbError(errmsg)

proc setEncoding*(connection: DbConn, encoding: string): bool {.
Expand Down