Skip to content

Commit

Permalink
Merge pull request #99 from mailru/evilaffliction-master
Browse files Browse the repository at this point in the history
make driver compatible to chproxy
  • Loading branch information
DoubleDi authored Mar 28, 2021
2 parents e108514 + 24c3508 commit 1430cf0
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 92 deletions.
5 changes: 2 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ notifications:
email: false
language: go
go:
- '1.11'
- '1.12'
- '1.13'
- '1.14'

- '1.15'
- '1.16'
services:
- docker

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Yet another Golang SQL database driver for [Yandex ClickHouse](https://clickhous

* Uses official http interface
* Compatibility with [dbr](https://github.com/mailru/dbr)
* Compatibility with [chproxy](https://github.com/Vertamedia/chproxy)

## DSN
```
Expand Down
6 changes: 2 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ func (cfg *Config) url(extra map[string]string, dsn bool) *url.URL {
for k, v := range cfg.Params {
query.Set(k, v)
}
if extra != nil {
for k, v := range extra {
query.Set(k, v)
}
for k, v := range extra {
query.Set(k, v)
}

u.RawQuery = query.Encode()
Expand Down
85 changes: 51 additions & 34 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser,
}
return nil, err
}

return resp.Body, nil
}

Expand All @@ -291,65 +290,83 @@ func (c *conn) buildRequest(ctx context.Context, query string, params []driver.V
method string
err error
)
if params != nil && len(params) > 0 {
if len(params) > 0 {
if query, err = interpolateParams(query, params); err != nil {
return nil, err
}
}

var (
bodyReader io.Reader
bodyWriter io.WriteCloser
)
if readonly {
method = http.MethodGet
} else {
method = http.MethodPost
bodyReader, bodyWriter = io.Pipe()
go func() {
if c.useGzipCompression {
gz := gzip.NewWriter(bodyWriter)
gz.Write([]byte(query))
gz.Close()
bodyWriter.Close()
} else {
bodyWriter.Write([]byte(query))
bodyWriter.Close()
}
}()
}
c.log("query: ", query)

bodyReader, bodyWriter := io.Pipe()
go func() {
if c.useGzipCompression {
gz := gzip.NewWriter(bodyWriter)
_, err := gz.Write([]byte(query))
gz.Close()
bodyWriter.CloseWithError(err)
} else {
bodyWriter.Write([]byte(query))
bodyWriter.Close()
}
}()

req, err := http.NewRequest(method, c.url.String(), bodyReader)
if err != nil {
return nil, err
}

// http.Transport ignores url.User argument, handle it here
if c.user != nil {
p, _ := c.user.Password()
req.SetBasicAuth(c.user.Username(), p)
}
var queryID, quotaKey string
if ctx != nil {
quotaKey, _ = ctx.Value(QuotaKey).(string)
queryID, _ = ctx.Value(QueryID).(string)
}

if c.killQueryOnErr && queryID == "" {
queryUUID, err := uuid.NewV4()
if err != nil {
c.log("can't generate query_id: ", err)
} else {
queryID = queryUUID.String()
var reqQuery url.Values
if ctx != nil {
quotaKey, quotaOk := ctx.Value(QuotaKey).(string)
if quotaOk && quotaKey != "" {
if reqQuery == nil {
reqQuery = req.URL.Query()
}
reqQuery.Add(quotaKeyParamName, quotaKey)
}
queryID, queryOk := ctx.Value(QueryID).(string)
if c.killQueryOnErr && (!queryOk || queryID == "") {
queryUUID, err := uuid.NewV4()
if err != nil {
c.log("can't generate query_id: ", err)
} else {
queryID = queryUUID.String()
}
}
if queryID != "" {
if reqQuery == nil {
reqQuery = req.URL.Query()
}
reqQuery.Add(queryIDParamName, queryID)
}
}

reqQuery := req.URL.Query()
if quotaKey != "" {
reqQuery.Add(quotaKeyParamName, quotaKey)
}
if queryID != "" {
reqQuery.Add(queryIDParamName, queryID)
if method == http.MethodGet {
if reqQuery == nil {
reqQuery = req.URL.Query()
}
reqQuery.Add("query", query)
}
if reqQuery != nil {
req.URL.RawQuery = reqQuery.Encode()
}
req.URL.RawQuery = reqQuery.Encode()

if c.useGzipCompression {
if method == http.MethodPost && c.useGzipCompression {
req.Header.Set("Content-Encoding", "gzip")
}

Expand Down
16 changes: 6 additions & 10 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"context"
"database/sql/driver"
"io/ioutil"
"net/http"
"net/url"
"strings"
)

Expand All @@ -16,27 +14,25 @@ func (c *conn) Ping(ctx context.Context) error {
if c.transport == nil {
return ErrTransportNil
}
// make request with empty body, response must be "Ok.\n"
u := &url.URL{Scheme: c.url.Scheme, User: c.url.User, Host: c.url.Host, Path: "/ping"}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)

req, err := c.buildRequest(ctx, "select 1", nil, true)
if err != nil {
return err
}

respBody, err := c.doRequest(ctx, req)
defer func() {
c.cancel = nil
}()
if err != nil {
return err
}

// Close response body to enable connection reuse
defer respBody.Close()
resp, err := ioutil.ReadAll(respBody)
if err != nil {
return err
}
if len(resp) != 4 || !strings.HasPrefix(string(resp), "Ok.") {
return ErrIncorrectResponse
if err != nil || !strings.HasPrefix(string(resp), "1") {
return driver.ErrBadConn
}
return nil
}
Expand Down
3 changes: 2 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql/driver"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -218,7 +219,7 @@ func (s *connSuite) TestBuildRequestReadonlyWithAuth() {
s.Equal("user", user)
s.Equal("password", password)
s.Equal(http.MethodGet, req.Method)
s.Equal(cn.url.String(), req.URL.String())
s.Equal(cn.url.String()+"&query="+url.QueryEscape("SELECT 1"), req.URL.String())
s.Nil(req.URL.User)
}
}
Expand Down
4 changes: 2 additions & 2 deletions dataparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ type tupleParser struct {
}

func (p *tupleParser) Type() reflect.Type {
fields := make([]reflect.StructField, len(p.args), len(p.args))
fields := make([]reflect.StructField, len(p.args))
for i, arg := range p.args {
fields[i].Name = "Field" + strconv.Itoa(i)
fields[i].Type = arg.Type()
Expand Down Expand Up @@ -573,7 +573,7 @@ func newDataParser(t *TypeDesc, unquote bool, opt *DataParserOptions) (DataParse
if len(t.Args) < 1 {
return nil, fmt.Errorf("element types not specified for Tuple")
}
subParsers := make([]DataParser, len(t.Args), len(t.Args))
subParsers := make([]DataParser, len(t.Args))
for i, arg := range t.Args {
subParser, err := newDataParser(arg, true, opt)
if err != nil {
Expand Down
30 changes: 0 additions & 30 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,33 +39,3 @@ func readResponse(response *http.Response) (result []byte, err error) {
result = buf.Bytes()
return
}

func numOfColumns(data []byte) int {
var cnt int
for _, ch := range data {
switch ch {
case '\t':
cnt++
case '\n':
return cnt + 1
}
}
return -1
}

// splitTSV splits one row of tab separated values, returns begin of next row
func splitTSV(data []byte, out []string) int {
var i, k int
for j, ch := range data {
switch ch {
case '\t':
out[k] = string(data[i:j])
k++
i = j + 1
case '\n':
out[k] = string(data[i:j])
return j + 1
}
}
return -1
}
2 changes: 1 addition & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func newTextRows(c *conn, body io.ReadCloser, location *time.Location, useDBLoca
}
}

parsers := make([]DataParser, len(types), len(types))
parsers := make([]DataParser, len(types))
for i, typ := range types {
desc, err := ParseTypeDesc(typ)
if err != nil {
Expand Down
24 changes: 17 additions & 7 deletions stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ func (s *stmtSuite) TestExecMulti() {
require.NoError(err)
st, err := tx.Prepare(tc.insertQuery)
require.NoError(err)
st.Exec(tc.exec1)
st.Exec(tc.exec2)
_, err = st.Exec(tc.exec1)
require.NoError(err)
_, err = st.Exec(tc.exec2)
require.NoError(err)
rows, err := s.conn.Query(tc.query1)
require.NoError(err)
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(tx.Commit())
Expand All @@ -167,9 +170,12 @@ func (s *stmtSuite) TestExecMultiRollback() {
require.NoError(err)
st, err := tx.Prepare("INSERT INTO data (i64) VALUES (?)")
require.NoError(err)
st.Exec(31)
st.Exec(32)
_, err = st.Exec(31)
require.NoError(err)
_, err = st.Exec(32)
require.NoError(err)
rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31")
s.NoError(err)
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(tx.Rollback())
Expand All @@ -188,9 +194,12 @@ func (s *stmtSuite) TestExecMultiInterrupt() {
require.NoError(err)
st2, err := tx.Prepare("INSERT INTO data (i64) VALUES (?)")
require.NoError(err)
st.Exec(31)
st.Exec(32)
_, err = st.Exec(31)
require.NoError(err)
_, err = st.Exec(32)
require.NoError(err)
rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31")
s.NoError(err)
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(st.Close())
Expand All @@ -209,7 +218,8 @@ func (s *stmtSuite) TestFixDoubleInterpolateInStmt() {
st, err := tx.Prepare("INSERT INTO data (s, s2) VALUES (?, ?)")
require.NoError(err)
args := []interface{}{"'", "?"}
st.Exec(args...)
_, err = st.Exec(args...)
require.NoError(err)
require.NoError(tx.Commit())
require.NoError(st.Close())
rows, err := s.conn.Query("SELECT s, s2 FROM data WHERE s='\\'' AND s2='?'")
Expand Down

0 comments on commit 1430cf0

Please sign in to comment.