Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

🗞️ Only allocate pointers when we are ready to set/use the result #51

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).

## v2.0.2 - 2019-07-01

### Fixed
- If an error is thrown whilst scanning/reading a single row, the original pointer will stay untouched (previously it may have been allocated with the initial value of your struct)

## v2.0.1 - 2019-06-28

### Changed
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package gocassa is a high-level library on top of gocql
//
// Current version: v2.0.1
// Current version: v2.0.2
// Compared to gocql it provides query building, adds data binding, and provides
// easy-to-use "recipe" tables for common query use-cases. Unlike cqlc, it does
// not use code generation.
Expand Down
52 changes: 26 additions & 26 deletions scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,25 @@ func (s *scanner) ScanIter(iter Scannable) (int, error) {
}

func (s *scanner) iterSlice(iter Scannable) (int, error) {
// Extract the type of the slice
sliceType := getNonPtrType(reflect.TypeOf(s.result))
sliceElemType := sliceType.Elem()
sliceElemValType := getNonPtrType(sliceType.Elem())

// Extract the type of the underlying struct
structFields, err := s.structFields(sliceElemValType)
if err != nil {
return 0, err
}

// If we're given a pointer address to nil, we are responsible for
// allocating it before we assign. Note that this could be a ptr to
// a ptr (and so forth)
err := allocateNilReference(s.result)
err = allocateNilReference(s.result)
if err != nil {
return 0, err
}

// Extract the type of the slice
sliceType := getNonPtrType(reflect.TypeOf(s.result))
sliceElemType := sliceType.Elem()
sliceElemValType := getNonPtrType(sliceType.Elem())

// To preserve prior behaviour, if the result slice is not empty
// then allocate a new slice and set it as the value
sliceElem := reflect.ValueOf(s.result)
Expand All @@ -63,12 +69,6 @@ func (s *scanner) iterSlice(iter Scannable) (int, error) {
sliceElem.Set(reflect.Zero(sliceType))
}

// Extract the type of the underlying struct
structFields, err := s.structFields(sliceElemValType)
if err != nil {
return 0, err
}

rowsScanned := 0
for iter.Next() {
ptrs := generatePtrs(structFields)
Expand All @@ -90,20 +90,6 @@ func (s *scanner) iterSlice(iter Scannable) (int, error) {
}

func (s *scanner) iterSingle(iter Scannable) (int, error) {
// If we're given a pointer address to nil, we are responsible for
// allocating it before we assign. Note that this could be a ptr to
// a ptr (and so forth)
err := allocateNilReference(s.result)
if err != nil {
return 0, err
}

outPtr := reflect.ValueOf(s.result)
outVal := outPtr.Elem()
for outVal.Kind() == reflect.Ptr {
outVal = outVal.Elem() // we will eventually get to the underlying value
}

// Extract the type of the underlying struct
resultBaseType := getNonPtrType(reflect.TypeOf(s.result))
structFields, err := s.structFields(resultBaseType)
Expand All @@ -124,6 +110,20 @@ func (s *scanner) iterSingle(iter Scannable) (int, error) {
return 0, err
}

// If we're given a pointer address to nil, we are responsible for
// allocating it before we assign. Note that this could be a ptr to
// a ptr (and so forth)
err = allocateNilReference(s.result)
if err != nil {
return 0, err
}

outPtr := reflect.ValueOf(s.result)
outVal := outPtr.Elem()
for outVal.Kind() == reflect.Ptr {
outVal = outVal.Elem() // we will eventually get to the underlying value
}

setPtrs(structFields, ptrs, outVal)

s.rowsScanned++
Expand Down
2 changes: 2 additions & 0 deletions scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ func TestScanIterStruct(t *testing.T) {
noResultsIter := newMockIterator([]map[string]interface{}{}, stmt.FieldNames())
rowsRead, err = newScanner(stmt, &f1).ScanIter(noResultsIter)
assert.EqualError(t, err, ":0: No rows returned")
assert.Nil(t, f1)

// Test for a non-rows-not-found error
var g1 *Account
Expand All @@ -193,6 +194,7 @@ func TestScanIterStruct(t *testing.T) {
rowsRead, err = errorScanner.ScanIter(errorerIter)
assert.Equal(t, 0, rowsRead)
assert.Equal(t, err, expectedErr)
assert.Nil(t, g1)
}

func TestScanIterComposite(t *testing.T) {
Expand Down