diff --git a/db.go b/db.go index b785c98c..3137c95e 100644 --- a/db.go +++ b/db.go @@ -378,7 +378,7 @@ func (r *Result) Iterate(fn func(r *Row) error) error { var row Row if r.ctx == nil { return r.result.Iterate(func(dr database.Row) error { - row.row = dr + row.Row = dr return fn(&row) }) } @@ -388,7 +388,7 @@ func (r *Result) Iterate(fn func(r *Row) error) error { return err } - row.row = dr + row.Row = dr return fn(&row) }) } @@ -488,26 +488,26 @@ func newQueryContext(conn *Connection, params []environment.Param) *query.Contex } type Row struct { - row database.Row + Row database.Row } func (r *Row) Clone() *Row { var rr Row cb := row.NewColumnBuffer() - err := cb.Copy(r.row) + err := cb.Copy(r.Row) if err != nil { panic(err) } var br database.BasicRow - br.ResetWith(r.row.TableName(), r.row.Key(), cb) - rr.row = &br + br.ResetWith(r.Row.TableName(), r.Row.Key(), cb) + rr.Row = &br return &rr } func (r *Row) Columns() ([]string, error) { var cols []string - err := r.row.Iterate(func(column string, value types.Value) error { + err := r.Row.Iterate(func(column string, value types.Value) error { cols = append(cols, column) return nil }) @@ -518,7 +518,7 @@ func (r *Row) Columns() ([]string, error) { return cols, nil } func (r *Row) GetColumnType(column string) (string, error) { - v, err := r.row.Get(column) + v, err := r.Row.Get(column) if errors.Is(err, types.ErrColumnNotFound) { return "", err } @@ -527,21 +527,21 @@ func (r *Row) GetColumnType(column string) (string, error) { } func (r *Row) ScanColumn(column string, dest any) error { - return row.ScanColumn(r.row, column, dest) + return row.ScanColumn(r.Row, column, dest) } func (r *Row) Scan(dest ...any) error { - return row.Scan(r.row, dest...) + return row.Scan(r.Row, dest...) } func (r *Row) StructScan(dest any) error { - return row.StructScan(r.row, dest) + return row.StructScan(r.Row, dest) } func (r *Row) MapScan(dest map[string]any) error { - return row.MapScan(r.row, dest) + return row.MapScan(r.Row, dest) } func (r *Row) MarshalJSON() ([]byte, error) { - return r.row.MarshalJSON() + return r.Row.MarshalJSON() } diff --git a/driver/driver.go b/driver/driver.go index 0cc7afc7..7e349528 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -10,6 +10,7 @@ import ( "github.com/chaisql/chai" "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -301,81 +302,84 @@ func (rs *Rows) Close() error { func (rs *Rows) Next(dest []driver.Value) error { rs.c <- Row{} - row, ok := <-rs.c + r, ok := <-rs.c if !ok { return io.EOF } - if row.err != nil { - return row.err + if r.err != nil { + return r.err } - for i := range rs.columns { - if rs.columns[i] == "*" { - dest[i] = row.r + var i int + err := r.r.Row.Iterate(func(column string, v types.Value) error { + var err error - continue - } - - tp, err := row.r.GetColumnType(rs.columns[i]) - if err != nil { - return err - } - switch tp { - case types.TypeBoolean.String(): + switch v.Type() { + case types.TypeNull: + dest[i] = nil + case types.TypeBoolean: var b bool - err = row.r.ScanColumn(rs.columns[i], &b) + err = row.ScanValue(v, &b) if err != nil { return err } dest[i] = b - case types.TypeInteger.String(): + case types.TypeInteger: var ii int32 - err = row.r.ScanColumn(rs.columns[i], &ii) + err = row.ScanValue(v, &ii) if err != nil { return err } dest[i] = ii - case types.TypeBigint.String(): + case types.TypeBigint: var bi int64 - err = row.r.ScanColumn(rs.columns[i], &bi) + err = row.ScanValue(v, &bi) if err != nil { return err } - case types.TypeDouble.String(): + dest[i] = bi + case types.TypeDouble: var d float64 - err = row.r.ScanColumn(rs.columns[i], &d) + err = row.ScanValue(v, &d) if err != nil { return err } dest[i] = d - case types.TypeTimestamp.String(): + case types.TypeTimestamp: var t time.Time - err = row.r.ScanColumn(rs.columns[i], &t) + err = row.ScanValue(v, &t) if err != nil { return err } dest[i] = t - case types.TypeText.String(): + case types.TypeText: var s string - err = row.r.ScanColumn(rs.columns[i], &s) + err = row.ScanValue(v, &s) if err != nil { return err } dest[i] = s - case types.TypeBlob.String(): + case types.TypeBlob: var b []byte - err = row.r.ScanColumn(rs.columns[i], &b) + err = row.ScanValue(v, &b) if err != nil { return err } dest[i] = b default: - err = row.r.ScanColumn(rs.columns[i], dest[i]) + err = row.ScanValue(v, dest[i]) if err != nil { return err } } + + i++ + + return nil + }) + if err != nil { + return err } return nil diff --git a/internal/environment/env.go b/internal/environment/env.go index 2df88212..5713dad5 100644 --- a/internal/environment/env.go +++ b/internal/environment/env.go @@ -21,7 +21,6 @@ type Param struct { // the expression is evaluated. type Environment struct { Params []Param - Vars *row.ColumnBuffer Row row.Row DB *database.Database Tx *database.Transaction @@ -46,29 +45,6 @@ func (e *Environment) SetOuter(env *Environment) { e.Outer = env } -func (e *Environment) Get(column string) (v types.Value, ok bool) { - if e.Vars != nil { - v, err := e.Vars.Get(column) - if err == nil { - return v, true - } - } - - if e.Outer != nil { - return e.Outer.Get(column) - } - - return types.NewNullValue(), false -} - -func (e *Environment) Set(column string, v types.Value) { - if e.Vars == nil { - e.Vars = row.NewColumnBuffer() - } - - e.Vars.Set(column, v) -} - func (e *Environment) GetRow() (row.Row, bool) { if e.Row != nil { return e.Row, true diff --git a/internal/expr/column.go b/internal/expr/column.go index 04529a1a..708eb998 100644 --- a/internal/expr/column.go +++ b/internal/expr/column.go @@ -6,19 +6,30 @@ import ( "github.com/cockroachdb/errors" ) -type Column string +type Column struct { + Name string + Table string +} + +func (c *Column) String() string { + return c.Name +} + +func (c *Column) IsEqual(other Expr) bool { + if o, ok := other.(*Column); ok { + return c.Name == o.Name && c.Table == o.Table + } -func (c Column) String() string { - return string(c) + return false } -func (c Column) Eval(env *environment.Environment) (types.Value, error) { +func (c *Column) Eval(env *environment.Environment) (types.Value, error) { r, ok := env.GetRow() if !ok { return NullLiteral, errors.New("no table specified") } - v, err := r.Get(string(c)) + v, err := r.Get(c.Name) if err != nil { return NullLiteral, err } diff --git a/internal/expr/comparison.go b/internal/expr/comparison.go index 01c64780..f548beba 100644 --- a/internal/expr/comparison.go +++ b/internal/expr/comparison.go @@ -240,7 +240,7 @@ func (op *InOperator) validateLeftExpression(a Expr) (Expr, error) { switch t := a.(type) { case Parentheses: return op.validateLeftExpression(t.E) - case Column: + case *Column: return a, nil case LiteralValue: return a, nil diff --git a/internal/expr/constraint.go b/internal/expr/constraint.go index 1ce0afa6..8ed76aba 100644 --- a/internal/expr/constraint.go +++ b/internal/expr/constraint.go @@ -33,8 +33,8 @@ func (t *ConstraintExpr) Eval(tx *database.Transaction, r row.Row) (types.Value, func (t *ConstraintExpr) Validate(info *database.TableInfo) (err error) { Walk(t.Expr, func(e Expr) bool { switch e := e.(type) { - case Column: - if info.GetColumnConstraint(string(e)) == nil { + case *Column: + if info.GetColumnConstraint(e.Name) == nil { err = errors.Newf("column %q does not exist", e) return false } diff --git a/internal/expr/expr.go b/internal/expr/expr.go index c3a7c1d4..2e403c37 100644 --- a/internal/expr/expr.go +++ b/internal/expr/expr.go @@ -97,13 +97,8 @@ type NamedExpr struct { ExprName string } -// Name returns ExprName. -func (e *NamedExpr) Name() string { - return e.ExprName -} - func (e *NamedExpr) String() string { - return e.Expr.String() + return e.ExprName } // A Function is an expression whose evaluation calls a function previously defined. @@ -261,7 +256,7 @@ func Clone(e Expr) Expr { CastAs: e.CastAs, } case LiteralValue, - Column, + *Column, NamedParam, PositionalParam, NextValueFor, diff --git a/internal/expr/functions/definition_test.go b/internal/expr/functions/definition_test.go index 86ddcb4b..249d0798 100644 --- a/internal/expr/functions/definition_test.go +++ b/internal/expr/functions/definition_test.go @@ -17,7 +17,9 @@ func TestDefinitions(t *testing.T) { }) t.Run("Function()", func(t *testing.T) { - fexpr, err := def.Function(expr.Column("a")) + fexpr, err := def.Function(&expr.Column{ + Name: "a", + }) require.NoError(t, err) require.NotNil(t, fexpr) }) diff --git a/internal/expr/functions/scalar_definition_test.go b/internal/expr/functions/scalar_definition_test.go index 543a76f6..5403226d 100644 --- a/internal/expr/functions/scalar_definition_test.go +++ b/internal/expr/functions/scalar_definition_test.go @@ -39,7 +39,7 @@ func TestScalarFunctionDef(t *testing.T) { r := database.NewBasicRow(fb) env := environment.New(r) expr1 := expr.Add(expr.LiteralValue{Value: types.NewIntegerValue(1)}, expr.LiteralValue{Value: types.NewIntegerValue(0)}) - expr2 := expr.Column("a") + expr2 := &expr.Column{Name: "a"} expr3 := expr.Div(expr.LiteralValue{Value: types.NewIntegerValue(6)}, expr.LiteralValue{Value: types.NewIntegerValue(2)}) t.Run("OK", func(t *testing.T) { diff --git a/internal/planner/index_selection.go b/internal/planner/index_selection.go index 77699959..6e106c35 100644 --- a/internal/planner/index_selection.go +++ b/internal/planner/index_selection.go @@ -261,14 +261,14 @@ func (i *indexSelector) isFilterIndexable(f *rows.FilterOperator) (*indexableNod func (i *indexSelector) isTempTreeSortIndexable(n *rows.TempTreeSortOperator) *indexableNode { // only columns can be associated with an index - col, ok := n.Expr.(expr.Column) + col, ok := n.Expr.(*expr.Column) if !ok { return nil } return &indexableNode{ node: n, - col: string(col), + col: col.Name, desc: n.Desc, operator: scanner.ORDER, } @@ -671,14 +671,14 @@ func (i *indexSelector) operatorCanUseIndex(op expr.Operator) (bool, string, exp lh := op.LeftHand() rh := op.RightHand() - lc, leftIsCol := lh.(expr.Column) - rc, rightIsCol := rh.(expr.Column) + lc, leftIsCol := lh.(*expr.Column) + rc, rightIsCol := rh.(*expr.Column) var cc *database.ColumnConstraint if leftIsCol { - cc = i.info.ColumnConstraints.GetColumnConstraint(string(lc)) + cc = i.info.ColumnConstraints.GetColumnConstraint(lc.Name) } else if rightIsCol { - cc = i.info.ColumnConstraints.GetColumnConstraint(string(rc)) + cc = i.info.ColumnConstraints.GetColumnConstraint(rc.Name) } if cc == nil { return false, "", nil, nil @@ -691,7 +691,7 @@ func (i *indexSelector) operatorCanUseIndex(op expr.Operator) (bool, string, exp return false, "", nil, err } - return true, string(lc), v, nil + return true, lc.Name, v, nil } // literal OP column @@ -701,7 +701,7 @@ func (i *indexSelector) operatorCanUseIndex(op expr.Operator) (bool, string, exp return false, "", nil, err } - return true, string(rc), v, nil + return true, rc.Name, v, nil } return false, "", nil, nil @@ -713,13 +713,13 @@ func (i *indexSelector) operatorCanUseIndex(op expr.Operator) (bool, string, exp // invalid: a IN (b + 1, 2) func (i *indexSelector) inOperatorCanUseIndex(op expr.Operator) (bool, string, expr.Expr, error) { rh := op.RightHand() - _, rightIsCol := rh.(expr.Column) + _, rightIsCol := rh.(*expr.Column) if rightIsCol { return false, "", nil, nil } lh := op.LeftHand() - lc, leftIsCol := lh.(expr.Column) + lc, leftIsCol := lh.(*expr.Column) if !leftIsCol { return false, "", nil, nil @@ -734,7 +734,7 @@ func (i *indexSelector) inOperatorCanUseIndex(op expr.Operator) (bool, string, e return false, "", nil, nil } - cc := i.info.ColumnConstraints.GetColumnConstraint(string(lc)) + cc := i.info.ColumnConstraints.GetColumnConstraint(lc.Name) if cc == nil { return false, "", nil, nil } @@ -750,7 +750,7 @@ func (i *indexSelector) inOperatorCanUseIndex(op expr.Operator) (bool, string, e rlist[i] = v } - return true, string(lc), rlist, nil + return true, lc.Name, rlist, nil } // Special case for BETWEEN operator: Given this expression (x BETWEEN a AND b), @@ -760,12 +760,12 @@ func (i *indexSelector) betweenOperatorCanUseIndex(op expr.Operator) (bool, stri rh := op.RightHand() bt := op.(*expr.BetweenOperator) - x, xIsCol := bt.X.(expr.Column) + x, xIsCol := bt.X.(*expr.Column) if !xIsCol { return false, "", nil, nil } - cc := i.info.ColumnConstraints.GetColumnConstraint(string(x)) + cc := i.info.ColumnConstraints.GetColumnConstraint(x.Name) if cc == nil { return false, "", nil, nil } @@ -782,7 +782,7 @@ func (i *indexSelector) betweenOperatorCanUseIndex(op expr.Operator) (bool, stri return false, "", nil, nil } - return true, string(x), expr.LiteralExprList{lv, rv}, nil + return true, x.Name, expr.LiteralExprList{lv, rv}, nil } func exprIsCompatibleLiteral(e expr.Expr, tp types.Type) (bool, expr.LiteralValue, error) { diff --git a/internal/planner/optimizer.go b/internal/planner/optimizer.go index 4a0f6c67..8f79c0a7 100644 --- a/internal/planner/optimizer.go +++ b/internal/planner/optimizer.go @@ -335,11 +335,11 @@ func precalculateExpr(sctx *StreamContext, e expr.Expr) (expr.Expr, error) { // if one operand is a column and the other is a literal // we can check if the types are compatible - lc, leftIsCol := lh.(expr.Column) - rc, rightIsCol := rh.(expr.Column) + lc, leftIsCol := lh.(*expr.Column) + rc, rightIsCol := rh.(*expr.Column) if leftIsCol && rightIsLit { - tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(lc)).Type + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(lc.Name).Type if !tp.Def().IsComparableWith(rv.Value.Type()) { return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, rh) } @@ -354,7 +354,7 @@ func precalculateExpr(sctx *StreamContext, e expr.Expr) (expr.Expr, error) { } if leftIsLit && rightIsCol { - tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(rc)).Type + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(rc.Name).Type if !tp.Def().IsComparableWith(lv.Value.Type()) { return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, lh) } @@ -421,8 +421,8 @@ func checkExprType(sctx *StreamContext, e expr.Expr) (err error) { lh := op.LeftHand() rh := op.RightHand() - lc, leftIsCol := lh.(expr.Column) - rc, rightIsCol := rh.(expr.Column) + lc, leftIsCol := lh.(*expr.Column) + rc, rightIsCol := rh.(*expr.Column) lv, leftIsLit := lh.(expr.LiteralValue) rv, rightIsLit := rh.(expr.LiteralValue) @@ -432,7 +432,7 @@ func checkExprType(sctx *StreamContext, e expr.Expr) (err error) { } if leftIsCol && rightIsLit { - tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(lc)).Type + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(lc.Name).Type _, err := rv.Value.CastAs(tp) if err != nil { return errors.Errorf("invalid input syntax for type %s: %s", tp, rh) @@ -442,7 +442,7 @@ func checkExprType(sctx *StreamContext, e expr.Expr) (err error) { } if leftIsLit && rightIsCol { - tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(rc)).Type + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(rc.Name).Type _, err := lv.Value.CastAs(tp) if err != nil { return errors.Errorf("invalid input syntax for type %s: %s", tp, lh) @@ -514,17 +514,17 @@ func RemoveUnnecessaryTempSortNodesRule(sctx *StreamContext) error { return nil } - lcol, ok := sctx.TempTreeSorts[0].Expr.(expr.Column) + lcol, ok := sctx.TempTreeSorts[0].Expr.(*expr.Column) if !ok { return nil } - rcol, ok := sctx.TempTreeSorts[1].Expr.(expr.Column) + rcol, ok := sctx.TempTreeSorts[1].Expr.(*expr.Column) if !ok { return nil } - if lcol != rcol { + if lcol.Name != rcol.Name { return nil } diff --git a/internal/planner/optimizer_test.go b/internal/planner/optimizer_test.go index ffff202e..bd98222a 100644 --- a/internal/planner/optimizer_test.go +++ b/internal/planner/optimizer_test.go @@ -109,17 +109,17 @@ func TestPrecalculateExprRule(t *testing.T) { }, { "constant sub-expr: a > 1 - 40 -> a > -39", - expr.Gt(expr.Column("a"), expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))), - expr.Gt(expr.Column("a"), testutil.DoubleValue(-39)), + expr.Gt(&expr.Column{Name: "a"}, expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))), + expr.Gt(&expr.Column{Name: "a"}, testutil.DoubleValue(-39)), }, { "non-constant expr list: (a, 1 - 40) -> (a, -39)", expr.LiteralExprList{ - expr.Column("a"), + &expr.Column{Name: "a"}, expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40)), }, expr.LiteralExprList{ - expr.Column("a"), + &expr.Column{Name: "a"}, testutil.DoubleValue(-39), }, }, diff --git a/internal/query/query.go b/internal/query/query.go index eb9488cc..88e153bf 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -50,11 +50,6 @@ func (q *Query) Prepare(context *Context) error { } } - p, ok := stmt.(statement.Preparer) - if !ok { - return nil - } - if tx == nil { tx = context.GetTx() if tx == nil { @@ -68,11 +63,23 @@ func (q *Query) Prepare(context *Context) error { } } - stmt, err := p.Prepare(&statement.Context{ + sctx := &statement.Context{ DB: context.DB, Conn: context.Conn, Tx: tx, - }) + } + + err = stmt.Bind(sctx) + if err != nil { + return err + } + + p, ok := stmt.(statement.Preparer) + if !ok { + return nil + } + + stmt, err := p.Prepare(sctx) if err != nil { return err } diff --git a/internal/query/statement/alter.go b/internal/query/statement/alter.go index 68879fd7..63f093e8 100644 --- a/internal/query/statement/alter.go +++ b/internal/query/statement/alter.go @@ -9,20 +9,27 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = (*AlterTableRenameStmt)(nil) +var _ Statement = (*AlterTableAddColumnStmt)(nil) + // AlterTableRenameStmt is a DSL that allows creating a full ALTER TABLE query. type AlterTableRenameStmt struct { TableName string NewTableName string } +func (stmt *AlterTableRenameStmt) Bind(ctx *Context) error { + return nil +} + // IsReadOnly always returns false. It implements the Statement interface. -func (stmt AlterTableRenameStmt) IsReadOnly() bool { +func (stmt *AlterTableRenameStmt) IsReadOnly() bool { return false } // Run runs the ALTER TABLE statement in the given transaction. // It implements the Statement interface. -func (stmt AlterTableRenameStmt) Run(ctx *Context) (Result, error) { +func (stmt *AlterTableRenameStmt) Run(ctx *Context) (Result, error) { var res Result if stmt.TableName == "" { @@ -52,6 +59,10 @@ func (stmt *AlterTableAddColumnStmt) IsReadOnly() bool { return false } +func (stmt *AlterTableAddColumnStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the ALTER TABLE ADD COLUMN statement in the given transaction. // It implements the Statement interface. // The statement rebuilds the table. diff --git a/internal/query/statement/create.go b/internal/query/statement/create.go index 9cd661a8..fddaebf2 100644 --- a/internal/query/statement/create.go +++ b/internal/query/statement/create.go @@ -10,6 +10,10 @@ import ( "github.com/chaisql/chai/internal/stream/table" ) +var _ Statement = (*CreateTableStmt)(nil) +var _ Statement = (*CreateIndexStmt)(nil) +var _ Statement = (*CreateSequenceStmt)(nil) + // CreateTableStmt represents a parsed CREATE TABLE statement. type CreateTableStmt struct { IfNotExists bool @@ -21,6 +25,10 @@ func (stmt *CreateTableStmt) IsReadOnly() bool { return false } +func (stmt *CreateTableStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the Create table statement in the given transaction. // It implements the Statement interface. func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) { @@ -84,6 +92,10 @@ func (stmt *CreateIndexStmt) IsReadOnly() bool { return false } +func (stmt *CreateIndexStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the Create index statement in the given transaction. // It implements the Statement interface. func (stmt *CreateIndexStmt) Run(ctx *Context) (Result, error) { @@ -122,6 +134,10 @@ func (stmt *CreateSequenceStmt) IsReadOnly() bool { return false } +func (stmt *CreateSequenceStmt) Bind(ctx *Context) error { + return nil +} + // Run the statement in the given transaction. // It implements the Statement interface. func (stmt *CreateSequenceStmt) Run(ctx *Context) (Result, error) { diff --git a/internal/query/statement/delete.go b/internal/query/statement/delete.go index ab6702bf..95c502ed 100644 --- a/internal/query/statement/delete.go +++ b/internal/query/statement/delete.go @@ -9,6 +9,8 @@ import ( "github.com/chaisql/chai/internal/stream/table" ) +var _ Statement = (*DeleteStmt)(nil) + // DeleteConfig holds DELETE configuration. type DeleteStmt struct { basePreparedStatement @@ -16,7 +18,7 @@ type DeleteStmt struct { TableName string WhereExpr expr.Expr OffsetExpr expr.Expr - OrderBy expr.Column + OrderBy *expr.Column LimitExpr expr.Expr OrderByDirection scanner.Token } @@ -32,19 +34,38 @@ func NewDeleteStatement() *DeleteStmt { return &p } +func (stmt *DeleteStmt) Bind(ctx *Context) error { + err := BindExpr(ctx, stmt.TableName, stmt.WhereExpr) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.TableName, stmt.OffsetExpr) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.TableName, stmt.OrderBy) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.TableName, stmt.LimitExpr) + if err != nil { + return err + } + + return nil +} + func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) { s := stream.New(table.Scan(stmt.TableName)) if stmt.WhereExpr != nil { - err := ensureExprColumnsExist(c, stmt.TableName, stmt.WhereExpr) - if err != nil { - return nil, err - } - s = s.Pipe(rows.Filter(stmt.WhereExpr)) } - if stmt.OrderBy != "" { + if stmt.OrderBy != nil { if stmt.OrderByDirection == scanner.DESC { s = s.Pipe(rows.TempTreeSortReverse(stmt.OrderBy)) } else { @@ -53,18 +74,10 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) { } if stmt.OffsetExpr != nil { - err := ensureExprColumnsExist(c, stmt.TableName, stmt.OffsetExpr) - if err != nil { - return nil, err - } s = s.Pipe(rows.Skip(stmt.OffsetExpr)) } if stmt.LimitExpr != nil { - err := ensureExprColumnsExist(c, stmt.TableName, stmt.LimitExpr) - if err != nil { - return nil, err - } s = s.Pipe(rows.Take(stmt.LimitExpr)) } diff --git a/internal/query/statement/drop.go b/internal/query/statement/drop.go index 6b83d00f..14d7e6ae 100644 --- a/internal/query/statement/drop.go +++ b/internal/query/statement/drop.go @@ -7,6 +7,10 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = (*DropTableStmt)(nil) +var _ Statement = (*DropIndexStmt)(nil) +var _ Statement = (*DropSequenceStmt)(nil) + // DropTableStmt is a DSL that allows creating a DROP TABLE query. type DropTableStmt struct { TableName string @@ -14,13 +18,17 @@ type DropTableStmt struct { } // IsReadOnly always returns false. It implements the Statement interface. -func (stmt DropTableStmt) IsReadOnly() bool { +func (stmt *DropTableStmt) IsReadOnly() bool { return false } +func (stmt *DropTableStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the DropTable statement in the given transaction. // It implements the Statement interface. -func (stmt DropTableStmt) Run(ctx *Context) (Result, error) { +func (stmt *DropTableStmt) Run(ctx *Context) (Result, error) { var res Result if stmt.TableName == "" { @@ -59,13 +67,17 @@ type DropIndexStmt struct { } // IsReadOnly always returns false. It implements the Statement interface. -func (stmt DropIndexStmt) IsReadOnly() bool { +func (stmt *DropIndexStmt) IsReadOnly() bool { return false } +func (stmt *DropIndexStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the DropIndex statement in the given transaction. // It implements the Statement interface. -func (stmt DropIndexStmt) Run(ctx *Context) (Result, error) { +func (stmt *DropIndexStmt) Run(ctx *Context) (Result, error) { var res Result if stmt.IndexName == "" { @@ -87,13 +99,17 @@ type DropSequenceStmt struct { } // IsReadOnly always returns false. It implements the Statement interface. -func (stmt DropSequenceStmt) IsReadOnly() bool { +func (stmt *DropSequenceStmt) IsReadOnly() bool { return false } +func (stmt *DropSequenceStmt) Bind(ctx *Context) error { + return nil +} + // Run runs the DropSequence statement in the given transaction. // It implements the Statement interface. -func (stmt DropSequenceStmt) Run(ctx *Context) (Result, error) { +func (stmt *DropSequenceStmt) Run(ctx *Context) (Result, error) { var res Result if stmt.SequenceName == "" { diff --git a/internal/query/statement/explain.go b/internal/query/statement/explain.go index ac4bf467..c0a18249 100644 --- a/internal/query/statement/explain.go +++ b/internal/query/statement/explain.go @@ -9,6 +9,8 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = &ExplainStmt{} + // ExplainStmt is a Statement that // displays information about how a statement // is going to be executed, without executing it. @@ -16,6 +18,14 @@ type ExplainStmt struct { Statement Preparer } +func (stmt *ExplainStmt) Bind(ctx *Context) error { + if s, ok := stmt.Statement.(Statement); ok { + return s.Bind(ctx) + } + + return nil +} + // Run analyses the inner statement and displays its execution plan. // If the statement is a stream, Optimize will be called prior to // displaying all the operations. diff --git a/internal/query/statement/insert.go b/internal/query/statement/insert.go index 45270b51..029d9b86 100644 --- a/internal/query/statement/insert.go +++ b/internal/query/statement/insert.go @@ -11,6 +11,8 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = (*InsertStmt)(nil) + // InsertStmt holds INSERT configuration. type InsertStmt struct { basePreparedStatement @@ -34,6 +36,33 @@ func NewInsertStatement() *InsertStmt { return &p } +func (stmt *InsertStmt) Bind(ctx *Context) error { + for i := range stmt.Values { + err := BindExpr(ctx, stmt.TableName, stmt.Values[i]) + if err != nil { + return err + } + } + + if stmt.SelectStmt != nil { + if s, ok := stmt.SelectStmt.(Statement); ok { + err := s.Bind(ctx) + if err != nil { + return err + } + } + } + + for i := range stmt.Returning { + err := BindExpr(ctx, stmt.TableName, stmt.Returning[i]) + if err != nil { + return err + } + } + + return nil +} + func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) { var s *stream.Stream diff --git a/internal/query/statement/reindex.go b/internal/query/statement/reindex.go index 3c4dc252..28b6bfd1 100644 --- a/internal/query/statement/reindex.go +++ b/internal/query/statement/reindex.go @@ -8,6 +8,8 @@ import ( "github.com/chaisql/chai/internal/stream/table" ) +var _ Statement = (*ReIndexStmt)(nil) + // ReIndexStmt is a DSL that allows creating a full REINDEX statement. type ReIndexStmt struct { basePreparedStatement @@ -26,8 +28,12 @@ func NewReIndexStatement() *ReIndexStmt { return &p } +func (stmt *ReIndexStmt) Bind(ctx *Context) error { + return nil +} + // Prepare implements the Preparer interface. -func (stmt ReIndexStmt) Prepare(ctx *Context) (Statement, error) { +func (stmt *ReIndexStmt) Prepare(ctx *Context) (Statement, error) { var indexNames []string if stmt.TableOrIndexName == "" { diff --git a/internal/query/statement/select.go b/internal/query/statement/select.go index 4f331585..4164640f 100644 --- a/internal/query/statement/select.go +++ b/internal/query/statement/select.go @@ -11,6 +11,8 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = (*SelectStmt)(nil) + type SelectCoreStmt struct { TableName string Distinct bool @@ -19,6 +21,27 @@ type SelectCoreStmt struct { ProjectionExprs []expr.Expr } +func (stmt *SelectCoreStmt) Bind(ctx *Context) error { + err := BindExpr(ctx, stmt.TableName, stmt.WhereExpr) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.TableName, stmt.GroupByExpr) + if err != nil { + return err + } + + for i := range stmt.ProjectionExprs { + err = BindExpr(ctx, stmt.TableName, stmt.ProjectionExprs[i]) + if err != nil { + return err + } + } + + return nil +} + func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { isReadOnly := true @@ -34,20 +57,11 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { } if stmt.WhereExpr != nil { - err := ensureExprColumnsExist(ctx, stmt.TableName, stmt.WhereExpr) - if err != nil { - return nil, err - } - s = s.Pipe(rows.Filter(stmt.WhereExpr)) } // when using GROUP BY, only aggregation functions or GroupByExpr can be selected if stmt.GroupByExpr != nil { - err := ensureExprColumnsExist(ctx, stmt.TableName, stmt.GroupByExpr) - if err != nil { - return nil, err - } var invalidProjectedField expr.Expr var aggregators []expr.AggregatorBuilder @@ -70,7 +84,10 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { // if so, replace the expression with a column expression stmt.ProjectionExprs[i] = &expr.NamedExpr{ ExprName: ne.ExprName, - Expr: expr.Column(e.String()), + Expr: &expr.Column{ + Name: e.String(), + Table: stmt.TableName, + }, } continue } @@ -99,14 +116,6 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { return true } - if c, ok := e.(expr.Column); ok { - // check if the projected expression is a column - err := ensureExprColumnsExist(ctx, stmt.TableName, c) - if err != nil { - return false - } - } - return true }) } @@ -124,7 +133,7 @@ func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { for _, e := range stmt.ProjectionExprs { expr.Walk(e, func(e expr.Expr) bool { switch e.(type) { - case expr.Column, expr.Wildcard: + case *expr.Column, expr.Wildcard: err = errors.New("no tables specified") return false default: @@ -168,7 +177,7 @@ type SelectStmt struct { CompoundSelect []*SelectCoreStmt CompoundOperators []scanner.Token - OrderBy expr.Column + OrderBy *expr.Column OrderByDirection scanner.Token OffsetExpr expr.Expr LimitExpr expr.Expr @@ -185,6 +194,32 @@ func NewSelectStatement() *SelectStmt { return &p } +func (stmt *SelectStmt) Bind(ctx *Context) error { + for i := range stmt.CompoundSelect { + err := stmt.CompoundSelect[i].Bind(ctx) + if err != nil { + return err + } + } + + err := BindExpr(ctx, stmt.CompoundSelect[0].TableName, stmt.OrderBy) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.CompoundSelect[0].TableName, stmt.OffsetExpr) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.CompoundSelect[0].TableName, stmt.LimitExpr) + if err != nil { + return err + } + + return nil +} + // Prepare implements the Preparer interface. func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) { var s *stream.Stream @@ -231,7 +266,7 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) { prev = tok } - if stmt.OrderBy != "" { + if stmt.OrderBy != nil { if stmt.OrderByDirection == scanner.DESC { s = s.Pipe(rows.TempTreeSortReverse(stmt.OrderBy)) } else { diff --git a/internal/query/statement/statement.go b/internal/query/statement/statement.go index ee7455d7..3c48f9d3 100644 --- a/internal/query/statement/statement.go +++ b/internal/query/statement/statement.go @@ -9,6 +9,7 @@ import ( // A Statement represents a unique action that can be executed against the database. type Statement interface { + Bind(*Context) error Run(*Context) (Result, error) IsReadOnly() bool } @@ -84,19 +85,37 @@ func (r *Result) Close() (err error) { return err } -func ensureExprColumnsExist(ctx *Context, tableName string, e expr.Expr) (err error) { - info, err := ctx.Tx.Catalog.GetTableInfo(tableName) - if err != nil { - return err +func BindExpr(ctx *Context, tableName string, e expr.Expr) (err error) { + if e == nil { + return nil + } + + var info *database.TableInfo + if tableName != "" { + info, err = ctx.Tx.Catalog.GetTableInfo(tableName) + if err != nil { + return err + } } + expr.Walk(e, func(e expr.Expr) bool { switch t := e.(type) { - case expr.Column: - cc := info.ColumnConstraints.GetColumnConstraint(string(t)) + case *expr.Column: + if t == nil { + return true + } + + if info == nil { + err = errors.New("no table specified") + return false + } + + cc := info.ColumnConstraints.GetColumnConstraint(t.Name) if cc == nil { err = errors.Newf("column %s does not exist", t) return false } + t.Table = tableName } return true diff --git a/internal/query/statement/stream.go b/internal/query/statement/stream.go index bfd363e1..95da6854 100644 --- a/internal/query/statement/stream.go +++ b/internal/query/statement/stream.go @@ -8,6 +8,8 @@ import ( "github.com/cockroachdb/errors" ) +var _ Statement = (*PreparedStreamStmt)(nil) + // StreamStmt is a StreamStmt using a Stream. type StreamStmt struct { Stream *stream.Stream @@ -28,6 +30,10 @@ type PreparedStreamStmt struct { ReadOnly bool } +func (s *PreparedStreamStmt) Bind(ctx *Context) error { + return nil +} + // Run returns a result containing the stream. The stream will be executed by calling the Iterate method of // the result. func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) { diff --git a/internal/query/statement/update.go b/internal/query/statement/update.go index f3a0bee1..b5007ef1 100644 --- a/internal/query/statement/update.go +++ b/internal/query/statement/update.go @@ -9,6 +9,8 @@ import ( "github.com/chaisql/chai/internal/stream/table" ) +var _ Statement = (*UpdateStmt)(nil) + // UpdateConfig holds UPDATE configuration. type UpdateStmt struct { basePreparedStatement @@ -35,10 +37,31 @@ func NewUpdateStatement() *UpdateStmt { } type UpdateSetPair struct { - Column expr.Column + Column *expr.Column E expr.Expr } +func (stmt *UpdateStmt) Bind(ctx *Context) error { + err := BindExpr(ctx, stmt.TableName, stmt.WhereExpr) + if err != nil { + return err + } + + for i := range stmt.SetPairs { + err = BindExpr(ctx, stmt.TableName, stmt.SetPairs[i].Column) + if err != nil { + return err + } + + err = BindExpr(ctx, stmt.TableName, stmt.SetPairs[i].E) + if err != nil { + return err + } + } + + return nil +} + // Prepare implements the Preparer interface. func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { ti, err := c.Tx.Catalog.GetTableInfo(stmt.TableName) @@ -50,33 +73,23 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { s := stream.New(table.Scan(stmt.TableName)) if stmt.WhereExpr != nil { - err := ensureExprColumnsExist(c, stmt.TableName, stmt.WhereExpr) - if err != nil { - return nil, err - } - s = s.Pipe(rows.Filter(stmt.WhereExpr)) } var pkModified bool if stmt.SetPairs != nil { for _, pair := range stmt.SetPairs { - err := ensureExprColumnsExist(c, stmt.TableName, pair.Column) - if err != nil { - return nil, err - } - // if we modify the primary key, // we must remove the old row and create an new one if pk != nil && !pkModified { for _, c := range pk.Columns { - if c == string(pair.Column) { + if c == pair.Column.Name { pkModified = true break } } } - s = s.Pipe(path.Set(string(pair.Column), pair.E)) + s = s.Pipe(path.Set(pair.Column.Name, pair.E)) } } diff --git a/internal/query/transaction.go b/internal/query/transaction.go index 66ff059d..65d91952 100644 --- a/internal/query/transaction.go +++ b/internal/query/transaction.go @@ -15,6 +15,10 @@ type BeginStmt struct { Writable bool } +func (stmt BeginStmt) Bind(ctx *statement.Context) error { + return nil +} + // Prepare implements the Preparer interface. func (stmt BeginStmt) Prepare(*statement.Context) (statement.Statement, error) { return stmt, nil @@ -44,6 +48,10 @@ func (stmt BeginStmt) Run(ctx *statement.Context) (statement.Result, error) { // RollbackStmt is a statement that rollbacks the current active transaction. type RollbackStmt struct{} +func (stmt RollbackStmt) Bind(ctx *statement.Context) error { + return nil +} + // Prepare implements the Preparer interface. func (stmt RollbackStmt) Prepare(*statement.Context) (statement.Statement, error) { return stmt, nil @@ -74,6 +82,10 @@ func (stmt RollbackStmt) Run(ctx *statement.Context) (statement.Result, error) { // CommitStmt is a statement that commits the current active transaction. type CommitStmt struct{} +func (stmt CommitStmt) Bind(ctx *statement.Context) error { + return nil +} + // Prepare implements the Preparer interface. func (stmt CommitStmt) Prepare(*statement.Context) (statement.Statement, error) { return stmt, nil diff --git a/internal/sql/parser/alter.go b/internal/sql/parser/alter.go index 463ebbae..8733f229 100644 --- a/internal/sql/parser/alter.go +++ b/internal/sql/parser/alter.go @@ -7,22 +7,22 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.AlterTableRenameStmt, err error) { +func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ *statement.AlterTableRenameStmt, err error) { var stmt statement.AlterTableRenameStmt stmt.TableName = tableName // Parse "TO". if err := p.ParseTokens(scanner.TO); err != nil { - return stmt, err + return nil, err } // Parse new table name. stmt.NewTableName, err = p.parseIdent() if err != nil { - return stmt, err + return nil, err } - return stmt, nil + return &stmt, nil } func (p *Parser) parseAlterTableAddColumnStatement(tableName string) (*statement.AlterTableAddColumnStmt, error) { diff --git a/internal/sql/parser/alter_test.go b/internal/sql/parser/alter_test.go index 256da3b1..3f39aa0a 100644 --- a/internal/sql/parser/alter_test.go +++ b/internal/sql/parser/alter_test.go @@ -18,10 +18,10 @@ func TestParserAlterTable(t *testing.T) { expected statement.Statement errored bool }{ - {"Basic", "ALTER TABLE foo RENAME TO bar", statement.AlterTableRenameStmt{TableName: "foo", NewTableName: "bar"}, false}, - {"With error / missing TABLE keyword", "ALTER foo RENAME TO bar", statement.AlterTableRenameStmt{}, true}, - {"With error / two identifiers for table name", "ALTER TABLE foo baz RENAME TO bar", statement.AlterTableRenameStmt{}, true}, - {"With error / two identifiers for new table name", "ALTER TABLE foo RENAME TO bar baz", statement.AlterTableRenameStmt{}, true}, + {"Basic", "ALTER TABLE foo RENAME TO bar", &statement.AlterTableRenameStmt{TableName: "foo", NewTableName: "bar"}, false}, + {"With error / missing TABLE keyword", "ALTER foo RENAME TO bar", nil, true}, + {"With error / two identifiers for table name", "ALTER TABLE foo baz RENAME TO bar", nil, true}, + {"With error / two identifiers for new table name", "ALTER TABLE foo RENAME TO bar baz", nil, true}, } for _, test := range tests { diff --git a/internal/sql/parser/create.go b/internal/sql/parser/create.go index 8ed6351e..b3d8a4cb 100644 --- a/internal/sql/parser/create.go +++ b/internal/sql/parser/create.go @@ -637,8 +637,8 @@ func (p *Parser) parseCheckConstraint() (expr.Expr, []string, error) { // extract all the paths from the expression expr.Walk(e, func(e expr.Expr) bool { switch t := e.(type) { - case expr.Column: - scol := string(t) + case *expr.Column: + scol := t.Name // ensure that the path is not already in the list found := false for _, c := range columns { diff --git a/internal/sql/parser/delete_test.go b/internal/sql/parser/delete_test.go index 993439a0..28d74a1b 100644 --- a/internal/sql/parser/delete_test.go +++ b/internal/sql/parser/delete_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" @@ -15,6 +16,18 @@ import ( ) func TestParserDelete(t *testing.T) { + db, tx, cleanup := testutil.NewTestTx(t) + defer cleanup() + + testutil.MustExec(t, db, tx, "CREATE TABLE test(age int)") + + parseExpr := func(s string) expr.Expr { + e := parser.MustParseExpr(s) + err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, "test", e) + require.NoError(t, err) + return e + } + tests := []struct { name string s string @@ -24,37 +37,37 @@ func TestParserDelete(t *testing.T) { Pipe(stream.Discard())}, {"WithCond", "DELETE FROM test WHERE age = 10", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(table.Delete("test")). Pipe(stream.Discard()), }, {"WithOffset", "DELETE FROM test WHERE age = 10 OFFSET 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.Skip(parser.MustParseExpr("20"))). + Pipe(rows.Filter(parseExpr("age = 10"))). + Pipe(rows.Skip(parseExpr("20"))). Pipe(table.Delete("test")). Pipe(stream.Discard()), }, {"WithLimit", "DELETE FROM test LIMIT 10", stream.New(table.Scan("test")). - Pipe(rows.Take(parser.MustParseExpr("10"))). + Pipe(rows.Take(parseExpr("10"))). Pipe(table.Delete("test")). Pipe(stream.Discard()), }, {"WithOrderByThenOffset", "DELETE FROM test WHERE age = 10 ORDER BY age OFFSET 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(parser.MustParseExpr("age"))). - Pipe(rows.Skip(parser.MustParseExpr("20"))). + Pipe(rows.Filter(parseExpr("age = 10"))). + Pipe(rows.TempTreeSort(parseExpr("age"))). + Pipe(rows.Skip(parseExpr("20"))). Pipe(table.Delete("test")). Pipe(stream.Discard()), }, {"WithOrderByThenLimitThenOffset", "DELETE FROM test WHERE age = 10 ORDER BY age LIMIT 10 OFFSET 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(parser.MustParseExpr("age"))). - Pipe(rows.Skip(parser.MustParseExpr("20"))). - Pipe(rows.Take(parser.MustParseExpr("10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). + Pipe(rows.TempTreeSort(parseExpr("age"))). + Pipe(rows.Skip(parseExpr("20"))). + Pipe(rows.Take(parseExpr("10"))). Pipe(table.Delete("test")). Pipe(stream.Discard()), }, @@ -62,12 +75,6 @@ func TestParserDelete(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - - db, tx, cleanup := testutil.NewTestTx(t) - defer cleanup() - - testutil.MustExec(t, db, tx, "CREATE TABLE test(age int)") - q, err := parser.ParseQuery(test.s) require.NoError(t, err) diff --git a/internal/sql/parser/drop.go b/internal/sql/parser/drop.go index d063a436..aa815623 100644 --- a/internal/sql/parser/drop.go +++ b/internal/sql/parser/drop.go @@ -29,13 +29,13 @@ func (p *Parser) parseDropStatement() (statement.Statement, error) { // parseDropTableStatement parses a drop table string and returns a Statement AST row. // This function assumes the DROP TABLE tokens have already been consumed. -func (p *Parser) parseDropTableStatement() (statement.DropTableStmt, error) { +func (p *Parser) parseDropTableStatement() (*statement.DropTableStmt, error) { var stmt statement.DropTableStmt var err error stmt.IfExists, err = p.parseOptional(scanner.IF, scanner.EXISTS) if err != nil { - return stmt, err + return nil, err } // Parse table name @@ -43,21 +43,21 @@ func (p *Parser) parseDropTableStatement() (statement.DropTableStmt, error) { if err != nil { pErr := errors.Unwrap(err).(*ParseError) pErr.Expected = []string{"table_name"} - return stmt, pErr + return nil, pErr } - return stmt, nil + return &stmt, nil } // parseDropIndexStatement parses a drop index string and returns a Statement AST row. // This function assumes the DROP INDEX tokens have already been consumed. -func (p *Parser) parseDropIndexStatement() (statement.DropIndexStmt, error) { +func (p *Parser) parseDropIndexStatement() (*statement.DropIndexStmt, error) { var stmt statement.DropIndexStmt var err error stmt.IfExists, err = p.parseOptional(scanner.IF, scanner.EXISTS) if err != nil { - return stmt, err + return nil, err } // Parse index name @@ -65,21 +65,21 @@ func (p *Parser) parseDropIndexStatement() (statement.DropIndexStmt, error) { if err != nil { pErr := errors.Unwrap(err).(*ParseError) pErr.Expected = []string{"index_name"} - return stmt, pErr + return nil, pErr } - return stmt, nil + return &stmt, nil } // parseDropSequenceStatement parses a drop sequence string and returns a Statement AST row. // This function assumes the DROP SEQUENCE tokens have already been consumed. -func (p *Parser) parseDropSequenceStatement() (statement.DropSequenceStmt, error) { +func (p *Parser) parseDropSequenceStatement() (*statement.DropSequenceStmt, error) { var stmt statement.DropSequenceStmt var err error stmt.IfExists, err = p.parseOptional(scanner.IF, scanner.EXISTS) if err != nil { - return stmt, err + return nil, err } // Parse sequence name @@ -87,8 +87,8 @@ func (p *Parser) parseDropSequenceStatement() (statement.DropSequenceStmt, error if err != nil { pErr := errors.Unwrap(err).(*ParseError) pErr.Expected = []string{"sequence_name"} - return stmt, pErr + return nil, pErr } - return stmt, nil + return &stmt, nil } diff --git a/internal/sql/parser/drop_test.go b/internal/sql/parser/drop_test.go index 19268c29..2fdbef43 100644 --- a/internal/sql/parser/drop_test.go +++ b/internal/sql/parser/drop_test.go @@ -15,12 +15,12 @@ func TestParserDrop(t *testing.T) { expected statement.Statement errored bool }{ - {"Drop table", "DROP TABLE test", statement.DropTableStmt{TableName: "test"}, false}, - {"Drop table If not exists", "DROP TABLE IF EXISTS test", statement.DropTableStmt{TableName: "test", IfExists: true}, false}, - {"Drop index", "DROP INDEX test", statement.DropIndexStmt{IndexName: "test"}, false}, - {"Drop index if exists", "DROP INDEX IF EXISTS test", statement.DropIndexStmt{IndexName: "test", IfExists: true}, false}, - {"Drop index", "DROP SEQUENCE test", statement.DropSequenceStmt{SequenceName: "test"}, false}, - {"Drop index if exists", "DROP SEQUENCE IF EXISTS test", statement.DropSequenceStmt{SequenceName: "test", IfExists: true}, false}, + {"Drop table", "DROP TABLE test", &statement.DropTableStmt{TableName: "test"}, false}, + {"Drop table If not exists", "DROP TABLE IF EXISTS test", &statement.DropTableStmt{TableName: "test", IfExists: true}, false}, + {"Drop index", "DROP INDEX test", &statement.DropIndexStmt{IndexName: "test"}, false}, + {"Drop index if exists", "DROP INDEX IF EXISTS test", &statement.DropIndexStmt{IndexName: "test", IfExists: true}, false}, + {"Drop index", "DROP SEQUENCE test", &statement.DropSequenceStmt{SequenceName: "test"}, false}, + {"Drop index if exists", "DROP SEQUENCE IF EXISTS test", &statement.DropSequenceStmt{SequenceName: "test", IfExists: true}, false}, } for _, test := range tests { diff --git a/internal/sql/parser/expr.go b/internal/sql/parser/expr.go index 24d6bdff..320876b2 100644 --- a/internal/sql/parser/expr.go +++ b/internal/sql/parser/expr.go @@ -378,30 +378,6 @@ func (p *Parser) parseIdentList() ([]string, error) { } } -// parseParam parses a positional or named param. -func (p *Parser) parseParam() (expr.Expr, error) { - tok, _, lit := p.ScanIgnoreWhitespace() - switch tok { - case scanner.NAMEDPARAM: - if len(lit) == 1 { - return nil, errors.WithStack(&ParseError{Message: "missing param name"}) - } - if p.orderedParams > 0 { - return nil, errors.WithStack(&ParseError{Message: "cannot mix positional arguments with named arguments"}) - } - p.namedParams++ - return expr.NamedParam(lit[1:]), nil - case scanner.POSITIONALPARAM: - if p.namedParams > 0 { - return nil, errors.WithStack(&ParseError{Message: "cannot mix positional arguments with named arguments"}) - } - p.orderedParams++ - return expr.PositionalParam(p.orderedParams), nil - default: - return nil, nil - } -} - func (p *Parser) parseType() (types.Type, error) { tok, pos, lit := p.ScanIgnoreWhitespace() switch tok { @@ -448,14 +424,14 @@ func (p *Parser) parseType() (types.Type, error) { } // parsePath parses a path to a specific value. -func (p *Parser) parseColumn() (expr.Column, error) { +func (p *Parser) parseColumn() (*expr.Column, error) { // parse first mandatory ident col, err := p.parseIdent() if err != nil { - return "", err + return nil, err } - return expr.Column(col), nil + return &expr.Column{Name: col}, nil } func (p *Parser) parseExprListUntil(rightToken scanner.Token) (expr.LiteralExprList, error) { diff --git a/internal/sql/parser/expr_test.go b/internal/sql/parser/expr_test.go index 56133455..22f5c0e5 100644 --- a/internal/sql/parser/expr_test.go +++ b/internal/sql/parser/expr_test.go @@ -62,26 +62,26 @@ func TestParserExpr(t *testing.T) { }, false}, // operators - {"=", "age = 10", expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"!=", "age != 10", expr.Neq(expr.Column("age"), testutil.IntegerValue(10)), false}, - {">", "age > 10", expr.Gt(expr.Column("age"), testutil.IntegerValue(10)), false}, - {">=", "age >= 10", expr.Gte(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"<", "age < 10", expr.Lt(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"<=", "age <= 10", expr.Lte(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"=", "age = 10", expr.Eq(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"!=", "age != 10", expr.Neq(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {">", "age > 10", expr.Gt(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {">=", "age >= 10", expr.Gte(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"<", "age < 10", expr.Lt(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"<=", "age <= 10", expr.Lte(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, {"BETWEEN", "1 BETWEEN 10 AND 11", expr.Between(testutil.IntegerValue(10))(testutil.IntegerValue(1), testutil.IntegerValue(11)), false}, - {"+", "age + 10", expr.Add(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"-", "age - 10", expr.Sub(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"*", "age * 10", expr.Mul(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"/", "age / 10", expr.Div(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"%", "age % 10", expr.Mod(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"&", "age & 10", expr.BitwiseAnd(expr.Column("age"), testutil.IntegerValue(10)), false}, - {"||", "name || 'foo'", expr.Concat(expr.Column("name"), testutil.TextValue("foo")), false}, - {"IN", "age IN ages", expr.In(expr.Column("age"), expr.Column("ages")), false}, - {"NOT IN", "age NOT IN ages", expr.NotIn(expr.Column("age"), expr.Column("ages")), false}, - {"IS", "age IS NULL", expr.Is(expr.Column("age"), testutil.NullValue()), false}, - {"IS NOT", "age IS NOT NULL", expr.IsNot(expr.Column("age"), testutil.NullValue()), false}, - {"LIKE", "name LIKE 'foo'", expr.Like(expr.Column("name"), testutil.TextValue("foo")), false}, - {"NOT LIKE", "name NOT LIKE 'foo'", expr.NotLike(expr.Column("name"), testutil.TextValue("foo")), false}, + {"+", "age + 10", expr.Add(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"-", "age - 10", expr.Sub(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"*", "age * 10", expr.Mul(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"/", "age / 10", expr.Div(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"%", "age % 10", expr.Mod(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"&", "age & 10", expr.BitwiseAnd(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), false}, + {"||", "name || 'foo'", expr.Concat(&expr.Column{Name: "name"}, testutil.TextValue("foo")), false}, + {"IN", "age IN ages", expr.In(&expr.Column{Name: "age"}, &expr.Column{Name: "ages"}), false}, + {"NOT IN", "age NOT IN ages", expr.NotIn(&expr.Column{Name: "age"}, &expr.Column{Name: "ages"}), false}, + {"IS", "age IS NULL", expr.Is(&expr.Column{Name: "age"}, testutil.NullValue()), false}, + {"IS NOT", "age IS NOT NULL", expr.IsNot(&expr.Column{Name: "age"}, testutil.NullValue()), false}, + {"LIKE", "name LIKE 'foo'", expr.Like(&expr.Column{Name: "name"}, testutil.TextValue("foo")), false}, + {"NOT LIKE", "name NOT LIKE 'foo'", expr.NotLike(&expr.Column{Name: "name"}, testutil.TextValue("foo")), false}, {"NOT =", "name NOT = 'foo'", nil, true}, {"precedence", "4 > 1 + 2", expr.Gt( testutil.IntegerValue(4), @@ -92,26 +92,26 @@ func TestParserExpr(t *testing.T) { ), false}, {"AND", "age = 10 AND age <= 11", expr.And( - expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), - expr.Lte(expr.Column("age"), testutil.IntegerValue(11)), + expr.Eq(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), + expr.Lte(&expr.Column{Name: "age"}, testutil.IntegerValue(11)), ), false}, {"OR", "age = 10 OR age = 11", expr.Or( - expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), - expr.Eq(expr.Column("age"), testutil.IntegerValue(11)), + expr.Eq(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), + expr.Eq(&expr.Column{Name: "age"}, testutil.IntegerValue(11)), ), false}, {"AND then OR", "age >= 10 AND age > $age OR age < 10.4", expr.Or( expr.And( - expr.Gte(expr.Column("age"), testutil.IntegerValue(10)), - expr.Gt(expr.Column("age"), expr.NamedParam("age")), + expr.Gte(&expr.Column{Name: "age"}, testutil.IntegerValue(10)), + expr.Gt(&expr.Column{Name: "age"}, expr.NamedParam("age")), ), - expr.Lt(expr.Column("age"), testutil.DoubleValue(10.4)), + expr.Lt(&expr.Column{Name: "age"}, testutil.DoubleValue(10.4)), ), false}, - {"with NULL", "age > NULL", expr.Gt(expr.Column("age"), testutil.NullValue()), false}, + {"with NULL", "age > NULL", expr.Gt(&expr.Column{Name: "age"}, testutil.NullValue()), false}, // unary operators - {"CAST", "CAST(a AS TEXT)", &expr.Cast{Expr: expr.Column("a"), CastAs: types.TypeText}, false}, + {"CAST", "CAST(a AS TEXT)", &expr.Cast{Expr: &expr.Column{Name: "a"}, CastAs: types.TypeText}, false}, {"NOT", "NOT 10", expr.Not(testutil.IntegerValue(10)), false}, {"NOT", "NOT NOT", nil, true}, {"NOT", "NOT NOT 10", expr.Not(expr.Not(testutil.IntegerValue(10))), false}, @@ -120,7 +120,7 @@ func TestParserExpr(t *testing.T) { {"NEXT VALUE FOR", "NEXT VALUE FOR 10", nil, true}, // functions - {"count(expr) function", "count(a)", &functions.Count{Expr: expr.Column("a")}, false}, + {"count(expr) function", "count(a)", &functions.Count{Expr: &expr.Column{Name: "a"}}, false}, {"count(*) function", "count(*)", functions.NewCount(expr.Wildcard{}), false}, {"count (*) function with spaces", "count (*)", functions.NewCount(expr.Wildcard{}), false}, {"packaged function", "floor(1.2)", testutil.FunctionExpr(t, "floor", testutil.DoubleValue(1.2)), false}, @@ -148,17 +148,17 @@ func TestParserParams(t *testing.T) { expected expr.Expr errored bool }{ - {"one positional", "age = ?", expr.Eq(expr.Column("age"), expr.PositionalParam(1)), false}, + {"one positional", "age = ?", expr.Eq(&expr.Column{Name: "age"}, expr.PositionalParam(1)), false}, {"multiple positional", "age = ? AND age <= ?", expr.And( - expr.Eq(expr.Column("age"), expr.PositionalParam(1)), - expr.Lte(expr.Column("age"), expr.PositionalParam(2)), + expr.Eq(&expr.Column{Name: "age"}, expr.PositionalParam(1)), + expr.Lte(&expr.Column{Name: "age"}, expr.PositionalParam(2)), ), false}, - {"one named", "age = $age", expr.Eq(expr.Column("age"), expr.NamedParam("age")), false}, + {"one named", "age = $age", expr.Eq(&expr.Column{Name: "age"}, expr.NamedParam("age")), false}, {"multiple named", "age = $foo OR age = $bar", expr.Or( - expr.Eq(expr.Column("age"), expr.NamedParam("foo")), - expr.Eq(expr.Column("age"), expr.NamedParam("bar")), + expr.Eq(&expr.Column{Name: "age"}, expr.NamedParam("foo")), + expr.Eq(&expr.Column{Name: "age"}, expr.NamedParam("bar")), ), false}, {"mixed", "age >= ? AND age > $foo OR age < ?", nil, true}, } diff --git a/internal/sql/parser/insert_test.go b/internal/sql/parser/insert_test.go index 151f3745..73b05c7e 100644 --- a/internal/sql/parser/insert_test.go +++ b/internal/sql/parser/insert_test.go @@ -62,7 +62,7 @@ func TestParserInsert(t *testing.T) { Pipe(table.Insert("test")). Pipe(stream.Discard()), false}, - {"Values / Returning", "INSERT INTO test (a, b) VALUES ('c', 'd') RETURNING *, a, b as B, c", + {"Values / Returning", "INSERT INTO test (a, b) VALUES ('c', 'd') RETURNING *, a, b as B", stream.New(rows.Emit( []string{"a", "b"}, expr.Row{ @@ -75,7 +75,7 @@ func TestParserInsert(t *testing.T) { )). Pipe(table.Validate("test")). Pipe(table.Insert("test")). - Pipe(rows.Project(expr.Wildcard{}, testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b", "B"), testutil.ParseNamedExpr(t, "c"))), + Pipe(rows.Project(expr.Wildcard{}, testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b", "B"))), false}, {"Values / With fields / Wrong values", "INSERT INTO test (a, b) VALUES {a: 1}, ('e', 'f')", nil, true}, @@ -155,9 +155,9 @@ func TestParserInsert(t *testing.T) { Pipe(table.Insert("test")). Pipe(stream.Discard()), false}, - {"Select / Without fields / With projection", "INSERT INTO test SELECT a, b FROM foo", + {"Select / Without fields / With projection", "INSERT INTO test SELECT c, d FROM foo", stream.New(table.Scan("foo")). - Pipe(rows.Project(testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b"))). + Pipe(rows.Project(testutil.ParseNamedExpr(t, "c"), testutil.ParseNamedExpr(t, "d"))). Pipe(table.Validate("test")). Pipe(table.Insert("test")). Pipe(stream.Discard()), @@ -170,9 +170,9 @@ func TestParserInsert(t *testing.T) { Pipe(table.Insert("test")). Pipe(stream.Discard()), false}, - {"Select / With fields / With projection", "INSERT INTO test (a, b) SELECT a, b FROM foo", + {"Select / With fields / With projection", "INSERT INTO test (a, b) SELECT c, d FROM foo", stream.New(table.Scan("foo")). - Pipe(rows.Project(testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b"))). + Pipe(rows.Project(testutil.ParseNamedExpr(t, "c"), testutil.ParseNamedExpr(t, "d"))). Pipe(path.PathsRename("a", "b")). Pipe(table.Validate("test")). Pipe(table.Insert("test")). diff --git a/internal/sql/parser/order_by.go b/internal/sql/parser/order_by.go index 42c2e286..15e71a02 100644 --- a/internal/sql/parser/order_by.go +++ b/internal/sql/parser/order_by.go @@ -7,17 +7,17 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -func (p *Parser) parseOrderBy() (expr.Column, scanner.Token, error) { +func (p *Parser) parseOrderBy() (*expr.Column, scanner.Token, error) { // parse ORDER token ok, err := p.parseOptional(scanner.ORDER, scanner.BY) if err != nil || !ok { - return "", 0, err + return nil, 0, err } // parse col col, err := p.parseColumn() if err != nil { - return "", 0, err + return nil, 0, err } // parse optional ASC or DESC diff --git a/internal/sql/parser/select_test.go b/internal/sql/parser/select_test.go index 82562e61..f6facca2 100644 --- a/internal/sql/parser/select_test.go +++ b/internal/sql/parser/select_test.go @@ -17,6 +17,44 @@ import ( ) func TestParserSelect(t *testing.T) { + db, tx, cleanup := testutil.NewTestTx(t) + defer cleanup() + + testutil.MustExec(t, db, tx, ` + CREATE TABLE test(a TEXT, b TEXT, age int); + CREATE TABLE test1(age INT, a INT); + CREATE TABLE test2(age INT, a INT); + CREATE TABLE a(age INT, a INT); + CREATE TABLE b(age INT, a INT); + CREATE TABLE c(age INT, a INT); + CREATE TABLE d(age INT, a INT); + `, + ) + + parseExpr := func(s string, table ...string) expr.Expr { + e := parser.MustParseExpr(s) + tb := "test" + if len(table) > 0 { + tb = table[0] + } + err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e) + require.NoError(t, err) + return e + } + + parseNamedExpr := func(t *testing.T, s string, name ...string) *expr.NamedExpr { + ne := expr.NamedExpr{ + Expr: parseExpr(s), + ExprName: s, + } + + if len(name) > 0 { + ne.ExprName = name[0] + } + + return &ne + } + tests := []struct { name string s string @@ -25,13 +63,13 @@ func TestParserSelect(t *testing.T) { mustFail bool }{ {"NoTable", "SELECT 1", - stream.New(rows.Project(testutil.ParseNamedExpr(t, "1"))), + stream.New(rows.Project(parseNamedExpr(t, "1"))), true, false, }, {"NoTableWithINOperator", "SELECT 1 in (1, 2), 3", stream.New(rows.Project( - testutil.ParseNamedExpr(t, "1 IN (1, 2)"), - testutil.ParseNamedExpr(t, "3"), + parseNamedExpr(t, "1 IN (1, 2)"), + parseNamedExpr(t, "3"), )), true, false, }, @@ -45,87 +83,87 @@ func TestParserSelect(t *testing.T) { true, false, }, {"WithFields", "SELECT a, b FROM test", - stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b"))), + stream.New(table.Scan("test")).Pipe(rows.Project(parseNamedExpr(t, "a"), parseNamedExpr(t, "b"))), true, false, }, {"WithAlias", "SELECT a AS A, b FROM test", - stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a", "A"), testutil.ParseNamedExpr(t, "b"))), + stream.New(table.Scan("test")).Pipe(rows.Project(parseNamedExpr(t, "a", "A"), parseNamedExpr(t, "b"))), true, false, }, {"WithFields and wildcard", "SELECT a, b, * FROM test", - stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b"), expr.Wildcard{})), + stream.New(table.Scan("test")).Pipe(rows.Project(parseNamedExpr(t, "a"), parseNamedExpr(t, "b"), expr.Wildcard{})), true, false, }, {"WithExpr", "SELECT a > 1 FROM test", - stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a > 1"))), + stream.New(table.Scan("test")).Pipe(rows.Project(parseNamedExpr(t, "a > 1"))), true, false, }, {"WithCond", "SELECT * FROM test WHERE age = 10", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})), true, false, }, {"WithGroupBy", "SELECT a FROM test WHERE age = 10 GROUP BY a", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(parser.MustParseExpr("a"))). - Pipe(rows.GroupAggregate(parser.MustParseExpr("a"))). - Pipe(rows.Project(&expr.NamedExpr{ExprName: "a", Expr: expr.Column("a")})), + Pipe(rows.Filter(parseExpr("age = 10"))). + Pipe(rows.TempTreeSort(parseExpr("a"))). + Pipe(rows.GroupAggregate(parseExpr("a"))). + Pipe(rows.Project(&expr.NamedExpr{ExprName: "a", Expr: parseExpr("a")})), true, false, }, {"WithOrderBy", "SELECT * FROM test WHERE age = 10 ORDER BY a", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.TempTreeSort(expr.Column("a"))), + Pipe(rows.TempTreeSort(parseExpr("a"))), true, false, }, {"WithOrderBy ASC", "SELECT * FROM test WHERE age = 10 ORDER BY a ASC", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.TempTreeSort(expr.Column("a"))), + Pipe(rows.TempTreeSort(parseExpr("a"))), true, false, }, {"WithOrderBy DESC", "SELECT * FROM test WHERE age = 10 ORDER BY a DESC", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.TempTreeSortReverse(expr.Column("a"))), + Pipe(rows.TempTreeSortReverse(parseExpr("a"))), true, false, }, {"WithLimit", "SELECT * FROM test WHERE age = 10 LIMIT 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.Take(parser.MustParseExpr("20"))), + Pipe(rows.Take(parseExpr("20"))), true, false, }, {"WithOffset", "SELECT * FROM test WHERE age = 10 OFFSET 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.Skip(parser.MustParseExpr("20"))), + Pipe(rows.Skip(parseExpr("20"))), true, false, }, {"WithLimitThenOffset", "SELECT * FROM test WHERE age = 10 LIMIT 10 OFFSET 20", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10"))). Pipe(rows.Project(expr.Wildcard{})). - Pipe(rows.Skip(parser.MustParseExpr("20"))). - Pipe(rows.Take(parser.MustParseExpr("10"))), + Pipe(rows.Skip(parseExpr("20"))). + Pipe(rows.Take(parseExpr("10"))), true, false, }, {"WithOffsetThenLimit", "SELECT * FROM test WHERE age = 10 OFFSET 20 LIMIT 10", nil, true, true}, {"With aggregation function", "SELECT COUNT(*) FROM test", stream.New(table.Scan("test")). Pipe(rows.GroupAggregate(nil, functions.NewCount(expr.Wildcard{}))). - Pipe(rows.Project(testutil.ParseNamedExpr(t, "COUNT(*)"))), + Pipe(rows.Project(parseNamedExpr(t, "COUNT(*)"))), true, false}, {"With NEXT VALUE FOR", "SELECT NEXT VALUE FOR foo FROM test", stream.New(table.Scan("test")). - Pipe(rows.Project(testutil.ParseNamedExpr(t, "NEXT VALUE FOR foo"))), + Pipe(rows.Project(parseNamedExpr(t, "NEXT VALUE FOR foo"))), false, false}, {"WithUnionAll", "SELECT * FROM test1 UNION ALL SELECT * FROM test2", stream.New(stream.Concat( @@ -137,7 +175,7 @@ func TestParserSelect(t *testing.T) { {"CondWithUnionAll", "SELECT * FROM test1 WHERE age = 10 UNION ALL SELECT * FROM test2", stream.New(stream.Concat( stream.New(table.Scan("test1")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10", "test1"))). Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), @@ -162,7 +200,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.TempTreeSort(expr.Column("a"))), + )).Pipe(rows.TempTreeSort(parseExpr("a", "test1"))), true, false, }, {"WithUnionAllAndLimit", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 LIMIT 10", @@ -171,7 +209,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.Take(parser.MustParseExpr("10"))), + )).Pipe(rows.Take(parseExpr("10"))), true, false, }, {"WithUnionAllAndOffset", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 OFFSET 20", @@ -180,7 +218,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.Skip(parser.MustParseExpr("20"))), + )).Pipe(rows.Skip(parseExpr("20"))), true, false, }, {"WithUnionAllAndOrderByAndLimitAndOffset", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 ORDER BY a LIMIT 10 OFFSET 20", @@ -189,7 +227,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.TempTreeSort(expr.Column("a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), + )).Pipe(rows.TempTreeSort(parseExpr("a", "test1"))).Pipe(rows.Skip(parseExpr("20"))).Pipe(rows.Take(parseExpr("10"))), true, false, }, @@ -205,7 +243,7 @@ func TestParserSelect(t *testing.T) { {"CondWithUnion", "SELECT * FROM test1 WHERE age = 10 UNION SELECT * FROM test2", stream.New(stream.Union( stream.New(table.Scan("test1")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Filter(parseExpr("age = 10", "test1"))). Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), @@ -230,7 +268,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.TempTreeSort(expr.Column("a"))), + )).Pipe(rows.TempTreeSort(parseExpr("a", "test1"))), true, false, }, {"WithUnionAndLimit", "SELECT * FROM test1 UNION SELECT * FROM test2 LIMIT 10", @@ -239,7 +277,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.Take(parser.MustParseExpr("10"))), + )).Pipe(rows.Take(parseExpr("10"))), true, false, }, {"WithUnionAndOffset", "SELECT * FROM test1 UNION SELECT * FROM test2 OFFSET 20", @@ -248,7 +286,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.Skip(parser.MustParseExpr("20"))), + )).Pipe(rows.Skip(parseExpr("20"))), true, false, }, {"WithUnionAndOrderByAndLimitAndOffset", "SELECT * FROM test1 UNION SELECT * FROM test2 ORDER BY a LIMIT 10 OFFSET 20", @@ -257,7 +295,7 @@ func TestParserSelect(t *testing.T) { Pipe(rows.Project(expr.Wildcard{})), stream.New(table.Scan("test2")). Pipe(rows.Project(expr.Wildcard{})), - )).Pipe(rows.TempTreeSort(expr.Column("a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), + )).Pipe(rows.TempTreeSort(parseExpr("a", "test1"))).Pipe(rows.Skip(parseExpr("20"))).Pipe(rows.Take(parseExpr("10"))), true, false, }, {"WithMultipleCompoundOps/1", "SELECT * FROM a UNION ALL SELECT * FROM b UNION ALL SELECT * FROM c", @@ -343,7 +381,7 @@ func TestParserSelect(t *testing.T) { stream.New(table.Scan("c")). Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("d")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "NEXT VALUE FOR foo"))), + stream.New(table.Scan("d")).Pipe(rows.Project(parseNamedExpr(t, "NEXT VALUE FOR foo"))), )), false, false, }, @@ -352,33 +390,20 @@ func TestParserSelect(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { q, err := parser.ParseQuery(test.s) - if !test.mustFail { - db, tx, cleanup := testutil.NewTestTx(t) - defer cleanup() - - testutil.MustExec(t, db, tx, ` - CREATE TABLE test(a TEXT, b TEXT, age int); - CREATE TABLE test1(age INT, a INT); - CREATE TABLE test2(age INT, a INT); - CREATE TABLE a(age INT, a INT); - CREATE TABLE b(age INT, a INT); - CREATE TABLE c(age INT, a INT); - CREATE TABLE d(age INT, a INT); - `, - ) - - err = q.Prepare(&query.Context{ - Ctx: context.Background(), - DB: db, - Conn: tx.Connection(), - }) - require.NoError(t, err) - - require.Len(t, q.Statements, 1) - require.EqualValues(t, &statement.PreparedStreamStmt{ReadOnly: test.readOnly, Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt)) - } else { + if test.mustFail { require.Error(t, err) + return } + + err = q.Prepare(&query.Context{ + Ctx: context.Background(), + DB: db, + Conn: tx.Connection(), + }) + require.NoError(t, err) + + require.Len(t, q.Statements, 1) + require.EqualValues(t, &statement.PreparedStreamStmt{ReadOnly: test.readOnly, Stream: test.expected}, q.Statements[0].(*statement.PreparedStreamStmt)) }) } } diff --git a/internal/sql/parser/update_test.go b/internal/sql/parser/update_test.go index 91a73e5e..7f3fe791 100644 --- a/internal/sql/parser/update_test.go +++ b/internal/sql/parser/update_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" @@ -16,6 +17,22 @@ import ( ) func TestParserUpdate(t *testing.T) { + db, tx, cleanup := testutil.NewTestTx(t) + defer cleanup() + + testutil.MustExec(t, db, tx, "CREATE TABLE test(a INT, b TEXT)") + + parseExpr := func(s string, table ...string) expr.Expr { + e := parser.MustParseExpr(s) + tb := "test" + if len(table) > 0 { + tb = table[0] + } + err := statement.BindExpr(&statement.Context{DB: db, Tx: tx, Conn: tx.Connection()}, tb, e) + require.NoError(t, err) + return e + } + tests := []struct { name string s string @@ -32,9 +49,9 @@ func TestParserUpdate(t *testing.T) { }, {"SET/With cond", "UPDATE test SET a = 1, b = 2 WHERE a = 10", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("a = 10"))). + Pipe(rows.Filter(parseExpr("a = 10"))). Pipe(path.Set("a", testutil.IntegerValue(1))). - Pipe(path.Set("b", parser.MustParseExpr("2"))). + Pipe(path.Set("b", parseExpr("2"))). Pipe(table.Validate("test")). Pipe(table.Replace("test")). Pipe(stream.Discard()), @@ -49,11 +66,6 @@ func TestParserUpdate(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db, tx, cleanup := testutil.NewTestTx(t) - defer cleanup() - - testutil.MustExec(t, db, tx, "CREATE TABLE test(a INT, b TEXT)") - q, err := parser.ParseQuery(test.s) if test.errored { require.Error(t, err) diff --git a/internal/stream/rows/project.go b/internal/stream/rows/project.go index 07e90398..a30af2db 100644 --- a/internal/stream/rows/project.go +++ b/internal/stream/rows/project.go @@ -4,11 +4,11 @@ import ( "fmt" "strings" + "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" - "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -59,26 +59,68 @@ func (op *ProjectOperator) Columns(env *environment.Environment) ([]string, erro // Iterate implements the Operator interface. func (op *ProjectOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { - var mask RowMask + cb := row.NewColumnBuffer() + var br database.BasicRow + var newEnv environment.Environment if op.Prev == nil { - mask.Env = in - mask.Exprs = op.Exprs - newEnv.SetRow(&mask) + for _, e := range op.Exprs { + if _, ok := e.(expr.Wildcard); ok { + return errors.New("no table specified") + } + + v, err := e.Eval(in) + if err != nil { + return err + } + + cb.Add(e.String(), v) + } + + br.ResetWith("", nil, cb) + newEnv.SetRow(&br) newEnv.SetOuter(in) return f(&newEnv) } return op.Prev.Iterate(in, func(env *environment.Environment) error { - r, ok := env.GetDatabaseRow() + cb.Reset() + + for _, e := range op.Exprs { + if _, ok := e.(expr.Wildcard); ok { + r, ok := env.GetRow() + if !ok { + return errors.New("no table specified") + } + + err := r.Iterate(func(field string, value types.Value) error { + cb.Add(field, value) + return nil + }) + if err != nil { + return err + } + + continue + } + + v, err := e.Eval(env) + if err != nil { + return err + } + + cb.Add(e.String(), v) + } + + dr, ok := env.GetDatabaseRow() if ok { - mask.tableName = r.TableName() - mask.key = r.Key() + br.ResetWith(dr.TableName(), dr.Key(), cb) + } else { + br.ResetWith("", nil, cb) } - mask.Env = env - mask.Exprs = op.Exprs - newEnv.SetRow(&mask) + newEnv.SetRow(&br) + newEnv.SetOuter(env) return f(&newEnv) }) @@ -97,91 +139,3 @@ func (op *ProjectOperator) String() string { b.WriteString(")") return b.String() } - -type RowMask struct { - Env *environment.Environment - Exprs []expr.Expr - key *tree.Key - tableName string -} - -func (m *RowMask) Key() *tree.Key { - return m.key -} - -func (m *RowMask) TableName() string { - return m.tableName -} - -func (m *RowMask) Get(column string) (v types.Value, err error) { - for _, e := range m.Exprs { - if _, ok := e.(expr.Wildcard); ok { - r, ok := m.Env.GetRow() - if !ok { - continue - } - - v, err = r.Get(column) - if errors.Is(err, types.ErrColumnNotFound) { - continue - } - return - } - - if ne, ok := e.(*expr.NamedExpr); ok && ne.Name() == column { - return e.Eval(m.Env) - } - - if col, ok := e.(expr.Column); ok && col.String() == column { - return e.Eval(m.Env) - } - - if e.(fmt.Stringer).String() == column { - return e.Eval(m.Env) - } - } - - err = errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) - return -} - -func (m *RowMask) Iterate(fn func(field string, value types.Value) error) error { - for _, e := range m.Exprs { - if _, ok := e.(expr.Wildcard); ok { - r, ok := m.Env.GetRow() - if !ok { - return nil - } - - err := r.Iterate(fn) - if err != nil { - return errors.Wrap(err, "wildcard iteration") - } - - continue - } - - var col string - if ne, ok := e.(*expr.NamedExpr); ok { - col = ne.Name() - } else { - col = e.(fmt.Stringer).String() - } - - v, err := e.Eval(m.Env) - if err != nil { - return err - } - - err = fn(col, v) - if err != nil { - return err - } - } - - return nil -} - -func (m *RowMask) MarshalJSON() ([]byte, error) { - return row.MarshalJSON(m) -} diff --git a/internal/stream/rows/project_test.go b/internal/stream/rows/project_test.go index 95893df4..a056a575 100644 --- a/internal/stream/rows/project_test.go +++ b/internal/stream/rows/project_test.go @@ -8,6 +8,7 @@ import ( "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" + "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/types" @@ -18,35 +19,35 @@ func TestProject(t *testing.T) { tests := []struct { name string exprs []expr.Expr - in row.Row + in expr.Row out string fails bool }{ { "Constant", []expr.Expr{parser.MustParseExpr("10")}, - testutil.MakeRow(t, `{"a":1,"b":true}`), + testutil.MakeRowExpr(t, `{"a":1,"b":true}`), `{"10":10}`, false, }, { "Wildcard", []expr.Expr{expr.Wildcard{}}, - testutil.MakeRow(t, `{"a":1,"b":true}`), + testutil.MakeRowExpr(t, `{"a":1,"b":true}`), `{"a":1,"b":true}`, false, }, { "Multiple", []expr.Expr{expr.Wildcard{}, expr.Wildcard{}, parser.MustParseExpr("10")}, - testutil.MakeRow(t, `{"a":1,"b":true}`), + testutil.MakeRowExpr(t, `{"a":1,"b":true}`), `{"a":1,"b":true,"a":1,"b":true,"10":10}`, false, }, { "Named", []expr.Expr{&expr.NamedExpr{Expr: parser.MustParseExpr("10"), ExprName: "foo"}}, - testutil.MakeRow(t, `{"a":1,"b":true}`), + testutil.MakeRowExpr(t, `{"a":1,"b":true}`), `{"foo":10}`, false, }, @@ -55,10 +56,9 @@ func TestProject(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var inEnv environment.Environment - inEnv.SetRow(test.in) - err := rows.Project(test.exprs...).Iterate(&inEnv, func(out *environment.Environment) error { - require.Equal(t, &inEnv, out.GetOuter()) + err := stream.New(rows.Emit([]string{"a", "b"}, test.in)). + Pipe(rows.Project(test.exprs...)).Iterate(&inEnv, func(out *environment.Environment) error { r, ok := out.GetRow() require.True(t, ok) tt, err := json.Marshal(r) diff --git a/internal/stream/rows/temp_tree_sort.go b/internal/stream/rows/temp_tree_sort.go index 33586a63..125c11aa 100644 --- a/internal/stream/rows/temp_tree_sort.go +++ b/internal/stream/rows/temp_tree_sort.go @@ -68,7 +68,7 @@ func (op *TempTreeSortOperator) Iterate(in *environment.Environment, fn func(out if v == nil { // the expression might be pointing to the original row. - v, err = op.Expr.Eval(out.Outer) + v, err = op.Expr.Eval(out.GetOuter()) if err != nil { // the only valid error here is a missing column. if !errors.Is(err, types.ErrColumnNotFound) { @@ -82,8 +82,6 @@ func (op *TempTreeSortOperator) Iterate(in *environment.Environment, fn func(out return errors.New("missing row") } - // TODO: we should find a way to encode using the table info. - buf, err = encodeTempRow(buf, r) if err != nil { return errors.Wrap(err, "failed to encode row") diff --git a/internal/testutil/stream.go b/internal/testutil/stream.go index 848ed413..f1afa7e5 100644 --- a/internal/testutil/stream.go +++ b/internal/testutil/stream.go @@ -161,3 +161,24 @@ func RequireStreamEqf(t *testing.T, raw string, res *chai.Result, msg string, ar require.Equal(t, expected.String(), actual.String()) } } + +// rows, err := db.Query(test.Expr) +// if err != nil { +// return err +// } +// defer rows.Close() + +// cols, err := rows.Columns() +// if err != nil { +// return err +// } + +// for rows.Next() { +// var vals []interface{} +// for range cols { +// vals = append(vals, new(interface{})) +// } +// if err := rows.Scan(vals...); err != nil { +// return err +// } +// } diff --git a/internal/types/types.go b/internal/types/types.go index 536cb380..0fe9037a 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -201,21 +201,6 @@ func (t Type) IsAny() bool { return t == TypeAny } -type Value interface { - Comparable - - Type() Type - V() any - IsZero() (bool, error) - String() string - MarshalJSON() ([]byte, error) - MarshalText() ([]byte, error) - TypeDef() TypeDefinition - Encode(dst []byte) ([]byte, error) - EncodeAsKey(dst []byte) ([]byte, error) - CastAs(t Type) (Value, error) -} - type TypeDefinition interface { New(v any) Value Type() Type diff --git a/internal/types/value.go b/internal/types/value.go index 89f76b61..2bfce60b 100644 --- a/internal/types/value.go +++ b/internal/types/value.go @@ -6,6 +6,21 @@ import ( "time" ) +type Value interface { + Comparable + + Type() Type + V() any + IsZero() (bool, error) + String() string + MarshalJSON() ([]byte, error) + MarshalText() ([]byte, error) + TypeDef() TypeDefinition + Encode(dst []byte) ([]byte, error) + EncodeAsKey(dst []byte) ([]byte, error) + CastAs(t Type) (Value, error) +} + func AsBool(v Value) bool { return v.V().(bool) } @@ -89,3 +104,43 @@ func IsTruthy(v Value) (bool, error) { b, err := v.IsZero() return !b, err } + +// ValueScanner implements the sql.Scanner interface for Value. +// The src value will be of one of the following types: +// +// int32 +// int64 +// float64 +// bool +// []byte +// string +// time.Time +// nil - for NULL values +type ValueScanner struct { + V Value +} + +func (v *ValueScanner) Scan(src any) error { + switch t := src.(type) { + case int32: + v.V = NewIntegerValue(t) + case int64: + v.V = NewBigintValue(t) + case float64: + v.V = NewDoubleValue(t) + case bool: + v.V = NewBooleanValue(t) + case []byte: + v.V = NewBlobValue(t) + case string: + v.V = NewTextValue(t) + case time.Time: + v.V = NewTimestampValue(t) + case nil: + v.V = NewNullValue() + default: + return fmt.Errorf("unexpected type: %T", src) + } + + return nil +} diff --git a/sqltests/SELECT/union.sql b/sqltests/SELECT/union.sql index bac79823..19829b65 100644 --- a/sqltests/SELECT/union.sql +++ b/sqltests/SELECT/union.sql @@ -24,8 +24,8 @@ SELECT * FROM baz; /* result: {"a": 1.0, "b": 1.0} {"a": 2.0, "b": 2.0} -{"x": "a", "y": "a"} -{"x": "b", "y": "b"} +{"a": "a", "b": "a"} +{"a": "b", "b": "b"} */ -- test: union all with conditions @@ -34,7 +34,7 @@ UNION ALL SELECT * FROM baz WHERE x != "b"; /* result: {"a": 2.0, "b": 2.0} -{"x": "a", "y": "a"} +{"a": "a", "b": "a"} */ -- test: self union all @@ -57,8 +57,8 @@ SELECT * FROM baz; {"a": 2.0, "b": 2.0} {"a": 2.0, "b": 2.0} {"a": 3.0, "b": 3.0} -{"x": "a", "y": "a"} -{"x": "b", "y": "b"} +{"a": "a", "b": "a"} +{"a": "b", "b": "b"} */ -- test: basic union @@ -78,8 +78,8 @@ SELECT * FROM baz; /* result: {"a": 1.0, "b": 1.0} {"a": 2.0, "b": 2.0} -{"x": "a", "y": "a"} -{"x": "b", "y": "b"} +{"a": "a", "b": "a"} +{"a": "b", "b": "b"} */ -- test: union with conditions @@ -88,7 +88,7 @@ UNION SELECT * FROM baz WHERE x != "b"; /* result: {"a": 2.0, "b": 2.0} -{"x": "a", "y": "a"} +{"a": "a", "b": "a"} */ -- test: self union @@ -120,8 +120,8 @@ SELECT * FROM baz; {"a": 2.0, "b": 2.0} {"a": 2.0, "b": 2.0} {"a": 3.0, "b": 3.0} -{"x": "a", "y": "a"} -{"x": "b", "y": "b"} +{"a": "a", "b": "a"} +{"a": "b", "b": "b"} */ -- test: combined unions @@ -134,5 +134,5 @@ SELECT * FROM baz; {"a": 1.0, "b": 1.0} {"a": 2.0, "b": 2.0} {"a": 3.0, "b": 3.0} -{"x": "a", "y": "a"} -{"x": "b", "y": "b"} \ No newline at end of file +{"a": "a", "b": "a"} +{"a": "b", "b": "b"} \ No newline at end of file diff --git a/sqltests/sql_test.go b/sqltests/sql_test.go index 0ebc81ac..27e65780 100644 --- a/sqltests/sql_test.go +++ b/sqltests/sql_test.go @@ -2,6 +2,8 @@ package sql_test import ( "bufio" + "database/sql" + "errors" "io" "io/fs" "log" @@ -10,8 +12,11 @@ import ( "strings" "testing" - "github.com/chaisql/chai" + _ "github.com/chaisql/chai/driver" + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil" + "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) @@ -65,9 +70,9 @@ func TestSQL(t *testing.T) { } t.Run(ts.Filename, func(t *testing.T) { - setup := func(t *testing.T, db *chai.DB) { + setup := func(t *testing.T, db *sql.DB) { t.Helper() - err := db.Exec(ts.Setup) + _, err := db.Exec(ts.Setup) require.NoError(t, err) } @@ -95,7 +100,7 @@ func TestSQL(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - db, err := chai.Open(":memory:") + db, err := sql.Open("chai", ":memory:") require.NoError(t, err) defer db.Close() @@ -105,28 +110,14 @@ func TestSQL(t *testing.T) { // post setup if suite.PostSetup != "" { - err = db.Exec(suite.PostSetup) + _, err = db.Exec(suite.PostSetup) require.NoError(t, err) } if test.Fails { exec := func() error { - conn, err := db.Connect() - if err != nil { - return err - } - defer conn.Close() - - res, err := conn.Query(test.Expr) - if err != nil { - return err - } - defer res.Close() - - return res.Iterate(func(r *chai.Row) error { - _, err := r.MarshalJSON() - return err - }) + _, err := db.Exec(test.Expr) + return err } err := exec() @@ -137,15 +128,11 @@ func TestSQL(t *testing.T) { require.Errorf(t, err, "\nSource:%s:%d expected\n%s\nto raise an error but got none", absPath, test.Line, test.Expr) } } else { - conn, err := db.Connect() + rows, err := db.Query(test.Expr) require.NoError(t, err, "Source: %s:%d", absPath, test.Line) - defer conn.Close() + defer rows.Close() - res, err := conn.Query(test.Expr) - require.NoError(t, err, "Source: %s:%d", absPath, test.Line) - defer res.Close() - - testutil.RequireStreamEqf(t, test.Result, res, "Source: %s:%d", absPath, test.Line) + RequireRowsEqf(t, test.Result, rows, "Source: %s:%d", absPath, test.Line) } }) } @@ -285,3 +272,83 @@ func parse(r io.Reader, filename string) *testSuite { return &ts } + +func RequireRowsEqf(t *testing.T, raw string, rows *sql.Rows, msg string, args ...any) { + errMsg := append([]any{msg}, args...) + t.Helper() + r := testutil.ParseResultStream(raw) + + var want []row.Row + + for { + v, err := r.Next() + if err != nil { + if perr, ok := err.(*parser.ParseError); ok { + if perr.Found == "EOF" { + break + } + } else if perr, ok := errors.Unwrap(err).(*parser.ParseError); ok { + if perr.Found == "EOF" { + break + } + } + } + require.NoError(t, err, errMsg...) + + want = append(want, v) + } + + var got []row.Row + + cols, err := rows.Columns() + require.NoError(t, err, errMsg...) + + for rows.Next() { + vals := make([]any, len(cols)) + for i := range vals { + vals[i] = new(types.ValueScanner) + } + err := rows.Scan(vals...) + require.NoError(t, err, errMsg...) + + var cb row.ColumnBuffer + + for i := range vals { + cb.Add(cols[i], vals[i].(*types.ValueScanner).V) + } + + got = append(got, &cb) + } + + if err := rows.Err(); err != nil { + require.NoError(t, err, errMsg...) + } + + var expected strings.Builder + for i := range want { + data, err := row.MarshalTextIndent(want[i], "\n", " ") + require.NoError(t, err, errMsg...) + if i > 0 { + expected.WriteString("\n") + } + + expected.WriteString(string(data)) + } + + var actual strings.Builder + for i := range got { + data, err := row.MarshalTextIndent(got[i], "\n", " ") + require.NoError(t, err, errMsg...) + if i > 0 { + actual.WriteString("\n") + } + + actual.WriteString(string(data)) + } + + if msg != "" { + require.Equal(t, expected.String(), actual.String(), errMsg...) + } else { + require.Equal(t, expected.String(), actual.String()) + } +}