diff --git a/.travis.yml b/.travis.yml index fe76471..98050de 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,11 +3,10 @@ notifications: email: false language: go go: - - '1.11' - - '1.12' - '1.13' - '1.14' - + - '1.15' + - '1.16' services: - docker diff --git a/README.md b/README.md index 2170d0e..607b991 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ Yet another Golang SQL database driver for [Yandex ClickHouse](https://clickhous * Uses official http interface * Compatibility with [dbr](https://github.com/mailru/dbr) +* Compatibility with [chproxy](https://github.com/Vertamedia/chproxy) ## DSN ``` diff --git a/config.go b/config.go index 5924d74..e7d50d0 100644 --- a/config.go +++ b/config.go @@ -101,10 +101,8 @@ func (cfg *Config) url(extra map[string]string, dsn bool) *url.URL { for k, v := range cfg.Params { query.Set(k, v) } - if extra != nil { - for k, v := range extra { - query.Set(k, v) - } + for k, v := range extra { + query.Set(k, v) } u.RawQuery = query.Encode() diff --git a/conn.go b/conn.go index 363f4e4..e3a05a0 100644 --- a/conn.go +++ b/conn.go @@ -282,7 +282,6 @@ func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser, } return nil, err } - return resp.Body, nil } @@ -291,65 +290,83 @@ func (c *conn) buildRequest(ctx context.Context, query string, params []driver.V method string err error ) - if params != nil && len(params) > 0 { + if len(params) > 0 { if query, err = interpolateParams(query, params); err != nil { return nil, err } } + + var ( + bodyReader io.Reader + bodyWriter io.WriteCloser + ) if readonly { method = http.MethodGet } else { method = http.MethodPost + bodyReader, bodyWriter = io.Pipe() + go func() { + if c.useGzipCompression { + gz := gzip.NewWriter(bodyWriter) + gz.Write([]byte(query)) + gz.Close() + bodyWriter.Close() + } else { + bodyWriter.Write([]byte(query)) + bodyWriter.Close() + } + }() } c.log("query: ", query) - bodyReader, bodyWriter := io.Pipe() - go func() { - if c.useGzipCompression { - gz := gzip.NewWriter(bodyWriter) - _, err := gz.Write([]byte(query)) - gz.Close() - bodyWriter.CloseWithError(err) - } else { - bodyWriter.Write([]byte(query)) - bodyWriter.Close() - } - }() - req, err := http.NewRequest(method, c.url.String(), bodyReader) if err != nil { return nil, err } + // http.Transport ignores url.User argument, handle it here if c.user != nil { p, _ := c.user.Password() req.SetBasicAuth(c.user.Username(), p) } - var queryID, quotaKey string - if ctx != nil { - quotaKey, _ = ctx.Value(QuotaKey).(string) - queryID, _ = ctx.Value(QueryID).(string) - } - if c.killQueryOnErr && queryID == "" { - queryUUID, err := uuid.NewV4() - if err != nil { - c.log("can't generate query_id: ", err) - } else { - queryID = queryUUID.String() + var reqQuery url.Values + if ctx != nil { + quotaKey, quotaOk := ctx.Value(QuotaKey).(string) + if quotaOk && quotaKey != "" { + if reqQuery == nil { + reqQuery = req.URL.Query() + } + reqQuery.Add(quotaKeyParamName, quotaKey) + } + queryID, queryOk := ctx.Value(QueryID).(string) + if c.killQueryOnErr && (!queryOk || queryID == "") { + queryUUID, err := uuid.NewV4() + if err != nil { + c.log("can't generate query_id: ", err) + } else { + queryID = queryUUID.String() + } + } + if queryID != "" { + if reqQuery == nil { + reqQuery = req.URL.Query() + } + reqQuery.Add(queryIDParamName, queryID) } - } - reqQuery := req.URL.Query() - if quotaKey != "" { - reqQuery.Add(quotaKeyParamName, quotaKey) } - if queryID != "" { - reqQuery.Add(queryIDParamName, queryID) + if method == http.MethodGet { + if reqQuery == nil { + reqQuery = req.URL.Query() + } + reqQuery.Add("query", query) + } + if reqQuery != nil { + req.URL.RawQuery = reqQuery.Encode() } - req.URL.RawQuery = reqQuery.Encode() - if c.useGzipCompression { + if method == http.MethodPost && c.useGzipCompression { req.Header.Set("Content-Encoding", "gzip") } diff --git a/conn_go18.go b/conn_go18.go index 5b8cdc9..16fa33e 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -6,8 +6,6 @@ import ( "context" "database/sql/driver" "io/ioutil" - "net/http" - "net/url" "strings" ) @@ -16,12 +14,12 @@ func (c *conn) Ping(ctx context.Context) error { if c.transport == nil { return ErrTransportNil } - // make request with empty body, response must be "Ok.\n" - u := &url.URL{Scheme: c.url.Scheme, User: c.url.User, Host: c.url.Host, Path: "/ping"} - req, err := http.NewRequest(http.MethodGet, u.String(), nil) + + req, err := c.buildRequest(ctx, "select 1", nil, true) if err != nil { return err } + respBody, err := c.doRequest(ctx, req) defer func() { c.cancel = nil @@ -29,14 +27,12 @@ func (c *conn) Ping(ctx context.Context) error { if err != nil { return err } + // Close response body to enable connection reuse defer respBody.Close() resp, err := ioutil.ReadAll(respBody) - if err != nil { - return err - } - if len(resp) != 4 || !strings.HasPrefix(string(resp), "Ok.") { - return ErrIncorrectResponse + if err != nil || !strings.HasPrefix(string(resp), "1") { + return driver.ErrBadConn } return nil } diff --git a/conn_test.go b/conn_test.go index 54a637e..10df519 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,6 +7,7 @@ import ( "database/sql/driver" "io/ioutil" "net/http" + "net/url" "testing" "time" @@ -218,7 +219,7 @@ func (s *connSuite) TestBuildRequestReadonlyWithAuth() { s.Equal("user", user) s.Equal("password", password) s.Equal(http.MethodGet, req.Method) - s.Equal(cn.url.String(), req.URL.String()) + s.Equal(cn.url.String()+"&query="+url.QueryEscape("SELECT 1"), req.URL.String()) s.Nil(req.URL.User) } } diff --git a/dataparser.go b/dataparser.go index 547ddd8..56ec1e2 100644 --- a/dataparser.go +++ b/dataparser.go @@ -268,7 +268,7 @@ type tupleParser struct { } func (p *tupleParser) Type() reflect.Type { - fields := make([]reflect.StructField, len(p.args), len(p.args)) + fields := make([]reflect.StructField, len(p.args)) for i, arg := range p.args { fields[i].Name = "Field" + strconv.Itoa(i) fields[i].Type = arg.Type() @@ -573,7 +573,7 @@ func newDataParser(t *TypeDesc, unquote bool, opt *DataParserOptions) (DataParse if len(t.Args) < 1 { return nil, fmt.Errorf("element types not specified for Tuple") } - subParsers := make([]DataParser, len(t.Args), len(t.Args)) + subParsers := make([]DataParser, len(t.Args)) for i, arg := range t.Args { subParser, err := newDataParser(arg, true, opt) if err != nil { diff --git a/helpers.go b/helpers.go index 3b4cf25..fdfee89 100644 --- a/helpers.go +++ b/helpers.go @@ -39,33 +39,3 @@ func readResponse(response *http.Response) (result []byte, err error) { result = buf.Bytes() return } - -func numOfColumns(data []byte) int { - var cnt int - for _, ch := range data { - switch ch { - case '\t': - cnt++ - case '\n': - return cnt + 1 - } - } - return -1 -} - -// splitTSV splits one row of tab separated values, returns begin of next row -func splitTSV(data []byte, out []string) int { - var i, k int - for j, ch := range data { - switch ch { - case '\t': - out[k] = string(data[i:j]) - k++ - i = j + 1 - case '\n': - out[k] = string(data[i:j]) - return j + 1 - } - } - return -1 -} diff --git a/rows.go b/rows.go index a29c708..9afb36c 100644 --- a/rows.go +++ b/rows.go @@ -31,7 +31,7 @@ func newTextRows(c *conn, body io.ReadCloser, location *time.Location, useDBLoca } } - parsers := make([]DataParser, len(types), len(types)) + parsers := make([]DataParser, len(types)) for i, typ := range types { desc, err := ParseTypeDesc(typ) if err != nil { diff --git a/stmt_test.go b/stmt_test.go index dd8697f..a03f276 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -145,9 +145,12 @@ func (s *stmtSuite) TestExecMulti() { require.NoError(err) st, err := tx.Prepare(tc.insertQuery) require.NoError(err) - st.Exec(tc.exec1) - st.Exec(tc.exec2) + _, err = st.Exec(tc.exec1) + require.NoError(err) + _, err = st.Exec(tc.exec2) + require.NoError(err) rows, err := s.conn.Query(tc.query1) + require.NoError(err) s.False(rows.Next()) s.NoError(rows.Close()) require.NoError(tx.Commit()) @@ -167,9 +170,12 @@ func (s *stmtSuite) TestExecMultiRollback() { require.NoError(err) st, err := tx.Prepare("INSERT INTO data (i64) VALUES (?)") require.NoError(err) - st.Exec(31) - st.Exec(32) + _, err = st.Exec(31) + require.NoError(err) + _, err = st.Exec(32) + require.NoError(err) rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31") + s.NoError(err) s.False(rows.Next()) s.NoError(rows.Close()) require.NoError(tx.Rollback()) @@ -188,9 +194,12 @@ func (s *stmtSuite) TestExecMultiInterrupt() { require.NoError(err) st2, err := tx.Prepare("INSERT INTO data (i64) VALUES (?)") require.NoError(err) - st.Exec(31) - st.Exec(32) + _, err = st.Exec(31) + require.NoError(err) + _, err = st.Exec(32) + require.NoError(err) rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31") + s.NoError(err) s.False(rows.Next()) s.NoError(rows.Close()) require.NoError(st.Close()) @@ -209,7 +218,8 @@ func (s *stmtSuite) TestFixDoubleInterpolateInStmt() { st, err := tx.Prepare("INSERT INTO data (s, s2) VALUES (?, ?)") require.NoError(err) args := []interface{}{"'", "?"} - st.Exec(args...) + _, err = st.Exec(args...) + require.NoError(err) require.NoError(tx.Commit()) require.NoError(st.Close()) rows, err := s.conn.Query("SELECT s, s2 FROM data WHERE s='\\'' AND s2='?'")