Skip to content

Commit

Permalink
refactor driver.go
Browse files Browse the repository at this point in the history
  • Loading branch information
pirosiki197 committed Jan 27, 2025
1 parent ad46b6f commit 4bea800
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 221 deletions.
6 changes: 2 additions & 4 deletions template/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ func TestCacheRows(t *testing.T) {
if err != nil {
return err
}
cacheRows = newCacheRows(rows)
defer cacheRows.Close()

return cacheRows.createCache()
cacheRows, err = newCacheRows(rows)
return err
})
if err != nil {
t.Error(err)
Expand Down
181 changes: 83 additions & 98 deletions template/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"log"
"strings"
"sync"
"time"

"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -115,13 +115,12 @@ func (c *cacheConn) Prepare(rawQuery string) (driver.Stmt, error) {

queryInfo, ok := queryMap[normalizedQuery]
if !ok {
// unknown (insert, update, delete) query
if !strings.HasPrefix(strings.ToUpper(normalizedQuery), "SELECT") {
log.Println("unknown query:", normalizedQuery)
PurgeAllCaches()
}
return c.inner.Prepare(rawQuery)
} else if strings.Contains(normalizedQuery, "FOR UPDATE") {
return c.inner.Prepare(rawQuery)
}

if queryInfo.Type == domains.CachePlanQueryType_SELECT && !queryInfo.Select.Cache {
Expand Down Expand Up @@ -166,117 +165,50 @@ func (c *cacheConn) Ping(ctx context.Context) error {
var _ driver.Rows = &cacheRows{}

type cacheRows struct {
inner driver.Rows
cached bool
columns []string
rows sliceRows
limit int
}

func newCacheRows(inner driver.Rows) (*cacheRows, error) {
r := new(cacheRows)

err := r.cacheInnerRows(inner)
if err != nil {
return nil, err
}

mu sync.Mutex
return r, nil
}

func (r *cacheRows) Clone() *cacheRows {
// clone returns an independent copy of cached rows so that several
// consumers can iterate the same cached result set concurrently.
// It must only be called once the rows have been fully cached;
// calling it on uncached rows is a programming error and panics.
func (r *cacheRows) clone() *cacheRows {
	if !r.cached {
		panic("cannot clone uncached rows")
	}
	// Copy every field; rows gets a deep-enough copy (fresh slice
	// header via sliceRows.clone) so the clones iterate independently.
	copied := cacheRows{
		inner:   r.inner,
		cached:  r.cached,
		columns: r.columns,
		rows:    r.rows.clone(),
		limit:   r.limit,
	}
	return &copied
}

func newCacheRows(inner driver.Rows) *cacheRows {
return &cacheRows{inner: inner}
}

type row = []driver.Value

type sliceRows struct {
rows []row
idx int
}

func (r sliceRows) clone() sliceRows {
rows := make([]row, len(r.rows))
copy(rows, r.rows)
return sliceRows{rows: rows}
}

func (r *sliceRows) append(row ...row) {
r.rows = append(r.rows, row...)
}

func (r *sliceRows) reset() {
r.idx = 0
}

func (r *sliceRows) Next(dest []driver.Value, limit int) error {
if r.idx >= len(r.rows) {
r.reset()
return io.EOF
}
if limit > 0 && r.idx >= limit {
r.reset()
return io.EOF
}
row := r.rows[r.idx]
r.idx++
copy(dest, row)
return nil
}

func (r *cacheRows) Columns() []string {
if r.cached {
return r.columns
if !r.cached {
panic("cannot get columns of uncached rows")
}
columns := r.inner.Columns()
r.columns = make([]string, len(columns))
copy(r.columns, columns)
return columns
return r.columns
}

func (r *cacheRows) Close() error {
if r.cached {
r.rows.reset()
return nil
}
return r.inner.Close()
r.rows.reset()
return nil
}

func (r *cacheRows) Next(dest []driver.Value) error {
if r.cached {
return r.rows.Next(dest, r.limit)
}

err := r.inner.Next(dest)
if err != nil {
if err == io.EOF {
r.cached = true
return err
}
return err
}

cachedRow := make(row, len(dest))
for i := 0; i < len(dest); i++ {
switch v := dest[i].(type) {
case int64, uint64, float64, string, bool, time.Time, nil: // no need to copy
cachedRow[i] = v
case []byte: // copy to prevent mutation
data := make([]byte, len(v))
copy(data, v)
cachedRow[i] = data
default:
// TODO: handle other types
// Should we mark this row as uncacheable?
}
if !r.cached {
return fmt.Errorf("cannot get next row of uncached rows")
}
r.rows.append(cachedRow)

return nil
return r.rows.next(dest)
}

func mergeCachedRows(rows []*cacheRows) *cacheRows {
Expand All @@ -289,7 +221,7 @@ func mergeCachedRows(rows []*cacheRows) *cacheRows {

mergedSlice := sliceRows{}
for _, r := range rows {
mergedSlice.append(r.rows.rows...)
mergedSlice.concat(r.rows)
}

return &cacheRows{
Expand All @@ -299,20 +231,73 @@ func mergeCachedRows(rows []*cacheRows) *cacheRows {
}
}

func (r *cacheRows) createCache() error {
r.mu.Lock()
defer r.mu.Unlock()
columns := r.Columns()
func (r *cacheRows) cacheInnerRows(inner driver.Rows) error {
columns := inner.Columns()
r.columns = columns
dest := make([]driver.Value, len(columns))

for {
err := r.Next(dest)
err := inner.Next(dest)
if err == io.EOF {
break
}
if err != nil {
} else if err != nil {
return err
}

cachedRow := make(row, len(dest))
for i := 0; i < len(dest); i++ {
switch v := dest[i].(type) {
case int64, uint64, float64, string, bool, time.Time, nil: // no need to copy
cachedRow[i] = v
case []byte: // copy to prevent mutation
data := make([]byte, len(v))
copy(data, v)
cachedRow[i] = data
default:
// TODO: handle other types
// Should we mark this row as uncacheable?
}
}
r.rows.append(cachedRow)
}
r.Close()

r.cached = true

return nil
}

// row is a single result row as produced by a database driver.
type row = []driver.Value

// sliceRows is an in-memory, re-iterable collection of cached rows.
// idx is the cursor position used by next.
type sliceRows struct {
	rows []row
	idx  int
}

// clone returns a copy of the collection with its cursor rewound to the
// start. The backing slice of row headers is copied, so appends to the
// clone do not affect the original; the row values themselves are shared.
func (r sliceRows) clone() sliceRows {
	copied := append(make([]row, 0, len(r.rows)), r.rows...)
	return sliceRows{rows: copied}
}

// append adds one or more rows to the end of the collection.
func (r *sliceRows) append(rows ...row) {
	r.rows = append(r.rows, rows...)
}

// concat appends every row of another sliceRows to this one.
func (r *sliceRows) concat(other sliceRows) {
	r.rows = append(r.rows, other.rows...)
}

// reset rewinds the cursor so iteration restarts from the first row.
func (r *sliceRows) reset() {
	r.idx = 0
}

// next copies the current row into dest and advances the cursor.
// Once the rows are exhausted it rewinds the cursor and returns io.EOF,
// which makes the collection immediately re-iterable.
func (r *sliceRows) next(dest []driver.Value) error {
	if r.idx >= len(r.rows) {
		r.reset()
		return io.EOF
	}
	current := r.rows[r.idx]
	r.idx++
	copy(dest, current)
	return nil
}
Loading

0 comments on commit 4bea800

Please sign in to comment.