diff --git a/pgtype/float4.go b/pgtype/float4.go index 2540f9e51..91ca01473 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -65,6 +66,29 @@ func (f Float4) Value() (driver.Value, error) { return float64(f.Float32), nil } +func (f Float4) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float32) +} + +func (f *Float4) UnmarshalJSON(b []byte) error { + var n *float32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float4{} + } else { + *f = Float4{Float32: *n, Valid: true} + } + + return nil +} + type Float4Codec struct{} func (Float4Codec) FormatSupported(format int16) bool { diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index f155ed976..bc74921cf 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -21,3 +21,44 @@ func TestFloat4Codec(t *testing.T) { {nil, new(*float32), isExpectedEq((*float32)(nil))}, }) } + +func TestFloat4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Float4 + result string + }{ + {source: pgtype.Float4{Float32: 0}, result: "null"}, + {source: pgtype.Float4{Float32: 1.23, Valid: true}, result: "1.23"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestFloat4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Float4 + }{ + {source: "null", result: pgtype.Float4{Float32: 0}}, + {source: "1.23", result: pgtype.Float4{Float32: 1.23, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/float8.go b/pgtype/float8.go index 65e5d8b32..9c923c9a3 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -74,6 +74,29 @@ func (f Float8) Value() (driver.Value, error) { return f.Float64, nil } +func (f Float8) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float64) +} + +func (f *Float8) UnmarshalJSON(b []byte) error { + var n *float64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float8{} + } else { + *f = Float8{Float64: *n, Valid: true} + } + + return nil +} + type Float8Codec struct{} func (Float8Codec) FormatSupported(format int16) bool { @@ -109,13 +132,6 @@ func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod return nil } -func (f *Float8) MarshalJSON() ([]byte, error) { - if !f.Valid { - return []byte("null"), nil - } - return json.Marshal(f.Float64) -} - type encodePlanFloat8CodecBinaryFloat64 struct{} func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 496b718b3..64593d97c 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -21,3 +21,44 @@ func TestFloat8Codec(t *testing.T) { {nil, new(*float64), isExpectedEq((*float64)(nil))}, }) } + +func TestFloat8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Float8 + result string + }{ + {source: pgtype.Float8{Float64: 0}, result: "null"}, + {source: pgtype.Float8{Float64: 1.23, Valid: true}, result: "1.23"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestFloat8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Float8 + }{ + {source: "null", result: pgtype.Float8{Float64: 0}}, + {source: "1.23", result: pgtype.Float8{Float64: 1.23, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +}