diff --git a/CHANGELOG.md b/CHANGELOG.md index bed4bea..39795b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). +## v2.0.2 - 2019-07-01 + +### Fixed + - If an error is thrown whilst scanning/reading a single row, the original pointer will stay untouched (previously it may have been allocated with the initial value of your struct) + ## v2.0.1 - 2019-06-28 ### Changed diff --git a/doc.go b/doc.go index a637445..d9625a7 100644 --- a/doc.go +++ b/doc.go @@ -1,6 +1,6 @@ // Package gocassa is a high-level library on top of gocql // -// Current version: v2.0.1 +// Current version: v2.0.2 // Compared to gocql it provides query building, adds data binding, and provides // easy-to-use "recipe" tables for common query use-cases. Unlike cqlc, it does // not use code generation. diff --git a/scanner.go b/scanner.go index 66a0097..c7d5cf7 100644 --- a/scanner.go +++ b/scanner.go @@ -40,19 +40,25 @@ func (s *scanner) ScanIter(iter Scannable) (int, error) { } func (s *scanner) iterSlice(iter Scannable) (int, error) { + // Extract the type of the slice + sliceType := getNonPtrType(reflect.TypeOf(s.result)) + sliceElemType := sliceType.Elem() + sliceElemValType := getNonPtrType(sliceType.Elem()) + + // Extract the type of the underlying struct + structFields, err := s.structFields(sliceElemValType) + if err != nil { + return 0, err + } + // If we're given a pointer address to nil, we are responsible for // allocating it before we assign. Note that this could be a ptr to // a ptr (and so forth) - err := allocateNilReference(s.result) + err = allocateNilReference(s.result) if err != nil { return 0, err } - // Extract the type of the slice - sliceType := getNonPtrType(reflect.TypeOf(s.result)) - sliceElemType := sliceType.Elem() - sliceElemValType := getNonPtrType(sliceType.Elem()) - // To preserve prior behaviour, if the result slice is not empty // then allocate a new slice and set it as the value sliceElem := reflect.ValueOf(s.result) @@ -63,12 +69,6 @@ func (s *scanner) iterSlice(iter Scannable) (int, error) { sliceElem.Set(reflect.Zero(sliceType)) } - // Extract the type of the underlying struct - structFields, err := s.structFields(sliceElemValType) - if err != nil { - return 0, err - } - rowsScanned := 0 for iter.Next() { ptrs := generatePtrs(structFields) @@ -90,20 +90,6 @@ func (s *scanner) iterSlice(iter Scannable) (int, error) { } func (s *scanner) iterSingle(iter Scannable) (int, error) { - // If we're given a pointer address to nil, we are responsible for - // allocating it before we assign. Note that this could be a ptr to - // a ptr (and so forth) - err := allocateNilReference(s.result) - if err != nil { - return 0, err - } - - outPtr := reflect.ValueOf(s.result) - outVal := outPtr.Elem() - for outVal.Kind() == reflect.Ptr { - outVal = outVal.Elem() // we will eventually get to the underlying value - } - // Extract the type of the underlying struct resultBaseType := getNonPtrType(reflect.TypeOf(s.result)) structFields, err := s.structFields(resultBaseType) @@ -124,6 +110,20 @@ func (s *scanner) iterSingle(iter Scannable) (int, error) { return 0, err } + // If we're given a pointer address to nil, we are responsible for + // allocating it before we assign. Note that this could be a ptr to + // a ptr (and so forth) + err = allocateNilReference(s.result) + if err != nil { + return 0, err + } + + outPtr := reflect.ValueOf(s.result) + outVal := outPtr.Elem() + for outVal.Kind() == reflect.Ptr { + outVal = outVal.Elem() // we will eventually get to the underlying value + } + setPtrs(structFields, ptrs, outVal) s.rowsScanned++ diff --git a/scanner_test.go b/scanner_test.go index 0a4e70a..5bd78a9 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -183,6 +183,7 @@ func TestScanIterStruct(t *testing.T) { noResultsIter := newMockIterator([]map[string]interface{}{}, stmt.FieldNames()) rowsRead, err = newScanner(stmt, &f1).ScanIter(noResultsIter) assert.EqualError(t, err, ":0: No rows returned") + assert.Nil(t, f1) // Test for a non-rows-not-found error var g1 *Account @@ -193,6 +194,7 @@ func TestScanIterStruct(t *testing.T) { rowsRead, err = errorScanner.ScanIter(errorerIter) assert.Equal(t, 0, rowsRead) assert.Equal(t, err, expectedErr) + assert.Nil(t, g1) } func TestScanIterComposite(t *testing.T) {