diff --git a/borsh-derive/src/internals/attributes/item/mod.rs b/borsh-derive/src/internals/attributes/item/mod.rs index ff3551486..f85ec5ace 100644 --- a/borsh-derive/src/internals/attributes/item/mod.rs +++ b/borsh-derive/src/internals/attributes/item/mod.rs @@ -1,26 +1,28 @@ use crate::internals::attributes::{BORSH, CRATE, INIT, USE_DISCRIMINANT}; +use proc_macro2::Span; use quote::ToTokens; -use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path}; +use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path, TypePath}; -use super::{get_one_attribute, parsing}; +use super::{get_one_attribute, parsing, RUST_REPR, TAG_WIDTH}; pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> { let borsh = get_one_attribute(&derive_input.attrs)?; if let Some(attr) = borsh { attr.parse_nested_meta(|meta| { - if meta.path != USE_DISCRIMINANT && meta.path != INIT && meta.path != CRATE { + if meta.path != USE_DISCRIMINANT && meta.path != INIT && meta.path != CRATE && meta.path != TAG_WIDTH { return Err(syn::Error::new( meta.path.span(), - "`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`", + "`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`", )); } - if meta.path == USE_DISCRIMINANT { + if meta.path == USE_DISCRIMINANT || meta.path == TAG_WIDTH { + let msg = if meta.path == USE_DISCRIMINANT { "borsh(use_discriminant=)"} else { "borsh(tag_width=)"}; let _expr: Expr = meta.value()?.parse()?; if let syn::Data::Struct(ref _data) = derive_input.data { return Err(syn::Error::new( derive_input.ident.span(), - "borsh(use_discriminant=) does not support structs", + format!("{msg} does not support structs"), )); } } else if meta.path == INIT || meta.path == CRATE { @@ -34,14 +36,13 @@ pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> { } pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result { - if input.variants.len() > 256 { + if input.variants.len() > u8::MAX as usize + 1 { return Err(syn::Error::new( input.span(), - "up to 256 enum variants are supported", + format!("up to {} enum variants are supported", u8::MAX as usize + 1), )); } - - let attrs = &input.attrs; + let attrs: &Vec = &input.attrs; let mut use_discriminant = None; let attr = attrs.iter().find(|attr| attr.path() == BORSH); if let Some(attr) = attr { @@ -61,7 +62,7 @@ pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result Result Option<(TypePath, Span)> { + input + .attrs + .iter() + .find(|attr| attr.path() == RUST_REPR) + .map(|attr| { + attr.parse_args::() + .map(|value| (attr, value)) + .unwrap() + }) + .map(|(attr, value)| (value, attr.span())) +} + +pub(crate) fn get_maybe_borsh_tag_width( + input: &ItemEnum, +) -> Result, syn::Error> { + let mut maybe_borsh_tag_width = None; + let attr = input.attrs.iter().find(|attr| attr.path() == BORSH); + let Some(attr) = attr else { + return Ok(None); + }; + + attr.parse_nested_meta(|meta| { + if meta.path == TAG_WIDTH { + let value_expr: Expr = meta.value()?.parse()?; + let value = value_expr.to_token_stream().to_string(); + let value = value + .parse::() + .map_err(|_| syn::Error::new(value_expr.span(), "`tag_width` accepts only u8"))?; + if value > 8 { + return Err(syn::Error::new( + value_expr.span(), + "`tag_width` accepts only values from 0 to 8", + )); + } + maybe_borsh_tag_width = Some((value, value_expr.span())); + } else if meta.path == INIT + || meta.path == CRATE + || meta.path == TAG_WIDTH + || meta.path == USE_DISCRIMINANT + { + let _value_expr: Expr = meta.value()?.parse()?; + } + Ok(()) + })?; + Ok(maybe_borsh_tag_width) +} + pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result, Error> { let mut res = None; let attr = attrs.iter().find(|attr| attr.path() == BORSH); @@ -88,7 +137,8 @@ pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result Result, Error> { if meta.path == CRATE { let value_expr: Path = parsing::parse_lit_into(BORSH, CRATE, &meta)?; res = Some(value_expr); - } else if meta.path == USE_DISCRIMINANT || meta.path == INIT { + } else if meta.path == USE_DISCRIMINANT || meta.path == INIT || meta.path == TAG_WIDTH { let _value_expr: Expr = meta.value()?.parse()?; } diff --git a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_invalid_on_whole_item.snap b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_invalid_on_whole_item.snap index f2a856a52..f6d28ffa8 100644 --- a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_invalid_on_whole_item.snap +++ b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_invalid_on_whole_item.snap @@ -1,7 +1,8 @@ --- source: borsh-derive/src/internals/attributes/item/mod.rs expression: actual.unwrap_err() +snapshot_kind: text --- Error( - "`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`", + "`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`", ) diff --git a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_skip_on_whole_item.snap b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_skip_on_whole_item.snap index f2a856a52..f6d28ffa8 100644 --- a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_skip_on_whole_item.snap +++ b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_borsh_skip_on_whole_item.snap @@ -1,7 +1,8 @@ --- source: borsh-derive/src/internals/attributes/item/mod.rs expression: actual.unwrap_err() +snapshot_kind: text --- Error( - "`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`", + "`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`", ) diff --git a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_init_function_wrong_format.snap b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_init_function_wrong_format.snap index f2a856a52..f6d28ffa8 100644 --- a/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_init_function_wrong_format.snap +++ b/borsh-derive/src/internals/attributes/item/snapshots/check_attrs_init_function_wrong_format.snap @@ -1,7 +1,8 @@ --- source: borsh-derive/src/internals/attributes/item/mod.rs expression: actual.unwrap_err() +snapshot_kind: text --- Error( - "`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`", + "`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`", ) diff --git a/borsh-derive/src/internals/attributes/mod.rs b/borsh-derive/src/internals/attributes/mod.rs index 4c5a69d4d..cf97228cf 100644 --- a/borsh-derive/src/internals/attributes/mod.rs +++ b/borsh-derive/src/internals/attributes/mod.rs @@ -30,6 +30,11 @@ pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with", "deserialize_wit /// crate - sub-borsh nested meta, item-level only, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts pub const CRATE: Symbol = Symbol("crate", "crate = ..."); +/// tag_width - sub-borsh nested meta, item-level only attribute in `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts +pub const TAG_WIDTH: Symbol = Symbol("tag_width", "tag_width = ..."); + +pub const RUST_REPR: Symbol = Symbol("repr", "repr(...)"); + #[cfg(feature = "schema")] pub mod schema_keys { use super::Symbol; diff --git a/borsh-derive/src/internals/deserialize/enums/mod.rs b/borsh-derive/src/internals/deserialize/enums/mod.rs index fb405e90d..c9e7eb7d1 100644 --- a/borsh-derive/src/internals/deserialize/enums/mod.rs +++ b/borsh-derive/src/internals/deserialize/enums/mod.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{Fields, ItemEnum, Path, Variant}; +use syn::{Fields, ItemEnum, Path, TypePath, Variant}; use crate::internals::{attributes::item, deserialize, enum_discriminant::Discriminants, generics}; @@ -11,14 +11,21 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut where_clause = generics::default_where(where_clause); let mut variant_arms = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; - let discriminants = Discriminants::new(&input.variants); + let maybe_borsh_tag_width = item::get_maybe_borsh_tag_width(input)?; + let maybe_rust_repr = item::get_maybe_rust_repr(input); + let discriminants = Discriminants::new( + &input.variants, + maybe_borsh_tag_width, + maybe_rust_repr, + use_discriminant, + )?; let mut generics_output = deserialize::GenericsOutput::new(&generics); - + let discriminant_type = discriminants.discriminant_type(); for (variant_idx, variant) in input.variants.iter().enumerate() { let variant_body = process_variant(variant, &cratename, &mut generics_output)?; let variant_ident = &variant.ident; - let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?; + let discriminant_value = discriminants.get(variant_ident, variant_idx)?; variant_arms.extend(quote! { if variant_tag == #discriminant_value { #name::#variant_ident #variant_body } else }); @@ -32,30 +39,48 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { }; generics_output.extend(&mut where_clause, &cratename); - Ok(quote! { + let deserialize_variant = quote! { + let mut return_value = + #variant_arms { + return Err(#cratename::io::Error::new( + #cratename::io::ErrorKind::InvalidData, + #cratename::__private::maybestd::format!("Unexpected variant tag: {:?}", variant_tag), + )) + }; + #init + Ok(return_value) + }; + + let deserialize = quote! { impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { fn deserialize_reader<__R: #cratename::io::Read>(reader: &mut __R) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = <#discriminant_type as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?; + #deserialize_variant } } + }; - impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause { - fn deserialize_variant<__R: #cratename::io::Read>( - reader: &mut __R, - variant_tag: u8, - ) -> ::core::result::Result { - let mut return_value = - #variant_arms { - return Err(#cratename::io::Error::new( - #cratename::io::ErrorKind::InvalidData, - #cratename::__private::maybestd::format!("Unexpected variant tag: {:?}", variant_tag), - )) - }; - #init - Ok(return_value) + let impl_trait = if discriminant_type.path.get_ident() + == (syn::parse_str::("u8").unwrap().path.get_ident()) + { + quote! { + impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause { + fn deserialize_variant<__R: #cratename::io::Read>( + reader: &mut __R, + variant_tag: u8, + ) -> ::core::result::Result { + #deserialize_variant + } } } + } else { + quote! {} + }; + + Ok(quote! { + #deserialize + + #impl_trait }) } diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_false.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_false.snap index 1e0446681..824f55a0b 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_false.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_false.snap @@ -1,13 +1,38 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for X { fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + X::A + } else if variant_tag == 1u8 { + X::B + } else if variant_tag == 2u8 { + X::C + } else if variant_tag == 3u8 { + X::D + } else if variant_tag == 4u8 { + X::E + } else if variant_tag == 5u8 { + X::F + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for X { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_true.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_true.snap index add0f62ed..aa8337a8f 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_true.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_discriminant_true.snap @@ -1,13 +1,38 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for X { fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0 { + X::A + } else if variant_tag == 20 { + X::B + } else if variant_tag == 20 + 1 { + X::C + } else if variant_tag == 20 + 1 + 1 { + X::D + } else if variant_tag == 10 { + X::E + } else if variant_tag == 10 + 1 { + X::F + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for X { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_init_func.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_init_func.snap index 28dd7dbce..51f383cd5 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_init_func.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_init_func.snap @@ -1,13 +1,39 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A { fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::A + } else if variant_tag == 1u8 { + A::B + } else if variant_tag == 2u8 { + A::C + } else if variant_tag == 3u8 { + A::D + } else if variant_tag == 4u8 { + A::E + } else if variant_tag == 5u8 { + A::F + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + return_value.initialization_method(); + Ok(return_value) } } impl borsh::de::EnumExt for A { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_struct_variant_field.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_struct_variant_field.snap index a8a2d9e1a..edbc475c0 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_struct_variant_field.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_struct_variant_field.snap @@ -1,13 +1,35 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for AA { fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + AA::B { + c: core::default::Default::default(), + d: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + AA::NegatedVariant { + beta: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for AA { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_tuple_variant_field.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_tuple_variant_field.snap index 60149fc60..9c406ba31 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_tuple_variant_field.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/borsh_skip_tuple_variant_field.snap @@ -1,13 +1,35 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for AAT { fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + AAT::B( + core::default::Default::default(), + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else if variant_tag == 1u8 { + AAT::NegatedVariant { + beta: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for AAT { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/bound_generics.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/bound_generics.snap index 6b31f4bc6..66ea7cf79 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/bound_generics.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/bound_generics.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -12,8 +13,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::B { + x: borsh::BorshDeserialize::deserialize_reader(reader)?, + y: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/check_deserialize_with_attr.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/check_deserialize_with_attr.snap index 968e0c3b5..12c008cc5 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/check_deserialize_with_attr.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/check_deserialize_with_attr.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for C where @@ -10,8 +11,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + C::C3( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else if variant_tag == 1u8 { + C::C4 { + x: borsh::BorshDeserialize::deserialize_reader(reader)?, + y: third_party_impl::deserialize_third_party(reader)?, + } + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for C diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_struct_field.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_struct_field.snap index cde62e488..20af08a4e 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_struct_field.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_struct_field.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -13,8 +14,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::B { + x: core::default::Default::default(), + y: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_tuple_field.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_tuple_field.snap index 0cc108681..58ea76ad7 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_tuple_field.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_borsh_skip_tuple_field.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -12,8 +13,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::B { + x: borsh::BorshDeserialize::deserialize_reader(reader)?, + y: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + borsh::BorshDeserialize::deserialize_reader(reader)?, + core::default::Default::default(), + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_deserialize_bound.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_deserialize_bound.snap index adf641118..ee110fa17 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/generic_deserialize_bound.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/generic_deserialize_bound.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -10,8 +11,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::C { + a: borsh::BorshDeserialize::deserialize_reader(reader)?, + b: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::D( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/recursive_enum.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/recursive_enum.snap index b3c8f7790..441e789ce 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/recursive_enum.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/recursive_enum.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -11,8 +12,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::B { + x: borsh::BorshDeserialize::deserialize_reader(reader)?, + y: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/simple_enum_with_custom_crate.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/simple_enum_with_custom_crate.snap index 88457ee99..cafb0cc08 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/simple_enum_with_custom_crate.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/simple_enum_with_custom_crate.snap @@ -1,15 +1,36 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl reexporter::borsh::de::BorshDeserialize for A { fn deserialize_reader<__R: reexporter::borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader( + let variant_tag = ::deserialize_reader( reader, )?; - ::deserialize_variant(reader, tag) + let mut return_value = if variant_tag == 0u8 { + A::B { + x: reexporter::borsh::BorshDeserialize::deserialize_reader(reader)?, + y: reexporter::borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + reexporter::borsh::BorshDeserialize::deserialize_reader(reader)?, + reexporter::borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + reexporter::borsh::io::Error::new( + reexporter::borsh::io::ErrorKind::InvalidData, + reexporter::borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl reexporter::borsh::de::EnumExt for A { diff --git a/borsh-derive/src/internals/deserialize/enums/snapshots/simple_generics.snap b/borsh-derive/src/internals/deserialize/enums/snapshots/simple_generics.snap index d8b85259c..f16009dd7 100644 --- a/borsh-derive/src/internals/deserialize/enums/snapshots/simple_generics.snap +++ b/borsh-derive/src/internals/deserialize/enums/snapshots/simple_generics.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/deserialize/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::de::BorshDeserialize for A where @@ -11,8 +12,30 @@ where fn deserialize_reader<__R: borsh::io::Read>( reader: &mut __R, ) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let variant_tag = ::deserialize_reader( + reader, + )?; + let mut return_value = if variant_tag == 0u8 { + A::B { + x: borsh::BorshDeserialize::deserialize_reader(reader)?, + y: borsh::BorshDeserialize::deserialize_reader(reader)?, + } + } else if variant_tag == 1u8 { + A::C( + borsh::BorshDeserialize::deserialize_reader(reader)?, + borsh::BorshDeserialize::deserialize_reader(reader)?, + ) + } else { + return Err( + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + borsh::__private::maybestd::format!( + "Unexpected variant tag: {:?}", variant_tag + ), + ), + ) + }; + Ok(return_value) } } impl borsh::de::EnumExt for A diff --git a/borsh-derive/src/internals/enum_discriminant.rs b/borsh-derive/src/internals/enum_discriminant.rs index 03b3f5829..232d0ba3b 100644 --- a/borsh-derive/src/internals/enum_discriminant.rs +++ b/borsh-derive/src/internals/enum_discriminant.rs @@ -1,15 +1,26 @@ use std::collections::HashMap; use std::convert::TryFrom; -use proc_macro2::{Ident, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{punctuated::Punctuated, token::Comma, Variant}; +use syn::{punctuated::Punctuated, token::Comma, TypePath, Variant}; + +pub struct Discriminants { + variants: HashMap, + discriminant_type: syn::TypePath, + use_discriminant: bool, + tag_width: u8, +} -pub struct Discriminants(HashMap); impl Discriminants { /// Calculates the discriminant that will be assigned by the compiler. /// See: https://doc.rust-lang.org/reference/items/enumerations.html#assigning-discriminant-values - pub fn new(variants: &Punctuated) -> Self { + pub fn new( + variants: &Punctuated, + maybe_borsh_tag_width: Option<(u8, Span)>, + maybe_rust_repr: Option<(syn::TypePath, Span)>, + use_discriminant: bool, + ) -> syn::Result { let mut map = HashMap::new(); let mut next_discriminant_if_not_specified = quote! {0}; @@ -23,25 +34,90 @@ impl Discriminants { map.insert(variant.ident.clone(), this_discriminant); } - Self(map) + let mut discriminant_type: TypePath = syn::parse_str("u8").unwrap(); + let mut tag_width = 1; + + if let Some((defined_tag_width, span)) = maybe_borsh_tag_width { + tag_width = defined_tag_width; + if !use_discriminant { + return Err(syn::Error::new( + span, + "`tag_width` specifier is only allowed when `use_discriminant` is set to true", + )); + } + let Some((rust_repr, span)) = maybe_rust_repr else { + return Err(syn::Error::new( + span, + "`tag_width` specifier is only allowed when `repr` is set", + )); + }; + match rust_repr.path.get_ident() { + Some(repr_type) => { + let repr_size= match repr_type.to_string().as_str() { + "u8" => { + 1 + }, + "u16" => { + 2 + }, + "u32" => { + 4 + }, + _ => return Err(syn::Error::new( + span, + "`tag_width` specifier is only allowed when `repr` is set to a u8, u16, or u32", + )), + }; + discriminant_type = rust_repr.clone(); + + if repr_size != tag_width { + return Err(syn::Error::new( + span, + "`tag_width` specifier must match the size of the `repr` type", + )); + } + } + None => { + return Err(syn::Error::new( + span, + "`tag_width` specifier is only allowed when `repr` is set to a specific numeric type", + )); + } + } + } + + Ok(Self { + variants: map, + discriminant_type, + use_discriminant, + tag_width, + }) } - pub fn get( - &self, - variant_ident: &Ident, - use_discriminant: bool, - variant_idx: usize, - ) -> syn::Result { - let variant_idx = u8::try_from(variant_idx).map_err(|err| { - syn::Error::new( - variant_ident.span(), - format!("up to 256 enum variants are supported: {}", err), - ) - })?; - let result = if use_discriminant { - let discriminant_value = self.0.get(variant_ident).unwrap(); + pub fn discriminant_type(&self) -> &syn::TypePath { + &self.discriminant_type + } + + #[allow(dead_code)] + pub fn tag_width(&self) -> u8 { + self.tag_width + } + + pub fn get(&self, variant_ident: &Ident, variant_idx: usize) -> syn::Result { + let result = if self.use_discriminant { + let discriminant_value = self.variants.get(variant_ident).unwrap(); quote! { #discriminant_value } } else { + let variant_idx = u8::try_from(variant_idx).map_err(|err| { + syn::Error::new( + variant_ident.span(), + format!( + "up to {} enum variants are supported: {}", + u8::MAX as usize + 1, + err + ), + ) + })?; quote! { #variant_idx } }; Ok(result) diff --git a/borsh-derive/src/internals/schema/enums/mod.rs b/borsh-derive/src/internals/schema/enums/mod.rs index 95bbed06b..452ccb225 100644 --- a/borsh-derive/src/internals/schema/enums/mod.rs +++ b/borsh-derive/src/internals/schema/enums/mod.rs @@ -36,7 +36,14 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut where_clause = generics::default_where(where_clause); let mut generics_output = schema::GenericsOutput::new(&generics); let use_discriminant = item::contains_use_discriminant(input)?; - let discriminants = Discriminants::new(&input.variants); + let maybe_borsh_tag_width = item::get_maybe_borsh_tag_width(input)?; + let maybe_rust_repr = item::get_maybe_rust_repr(input); + let discriminants = Discriminants::new( + &input.variants, + maybe_borsh_tag_width, + maybe_rust_repr, + use_discriminant, + )?; // Generate functions that return the schema for variants. let mut variants_defs = vec![]; @@ -46,7 +53,6 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let discriminant_info = DiscriminantInfo { variant_idx, discriminants: &discriminants, - use_discriminant, }; let variant_output = process_variant( variant, @@ -61,13 +67,15 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { variants_defs.push(variant_output.variant_entry); } + let tag_width = discriminants.tag_width(); let type_definitions = quote! { fn add_definitions_recursively(definitions: &mut #cratename::__private::maybestd::collections::BTreeMap<#cratename::schema::Declaration, #cratename::schema::Definition>) { #inner_defs #add_recursive_defs let definition = #cratename::schema::Definition::Enum { - tag_width: 1, + tag_width: #tag_width, variants: #cratename::__private::maybestd::vec![#(#variants_defs),*], + tag_signed: false, }; #cratename::schema::add_definition(::declaration(), definition, definitions); } @@ -97,15 +105,13 @@ struct VariantOutput { struct DiscriminantInfo<'a> { variant_idx: usize, discriminants: &'a Discriminants, - use_discriminant: bool, } fn process_discriminant( variant_ident: &Ident, info: DiscriminantInfo<'_>, ) -> syn::Result { - info.discriminants - .get(variant_ident, info.use_discriminant, info.variant_idx) + info.discriminants.get(variant_ident, info.variant_idx) } fn process_variant( @@ -128,6 +134,7 @@ fn process_variant( let variant_type = quote! { <#full_variant_ident #inner_struct_ty_generics as #cratename::BorshSchema> }; + let discriminant_type = discriminant_info.discriminants.discriminant_type().clone(); let discriminant_value = process_discriminant(&variant.ident, discriminant_info)?; Ok(VariantOutput { @@ -136,7 +143,7 @@ fn process_variant( #variant_type::add_definitions_recursively(definitions); }, variant_entry: quote! { - (u8::from(#discriminant_value) as i64, + (#discriminant_type::from(#discriminant_value) as i64, #variant_name.into(), #variant_type::declaration()) }, diff --git a/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_false.snap b/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_false.snap index 4fa592109..e8009a830 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_false.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_false.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for X { fn declaration() -> borsh::schema::Declaration { @@ -43,7 +44,7 @@ impl borsh::BorshSchema for X { ::add_definitions_recursively(definitions); ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "A".into(), < XA as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "B".into(), < XB as @@ -54,6 +55,7 @@ impl borsh::BorshSchema for X { (u8::from(5u8) as i64, "F".into(), < XF as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -62,4 +64,3 @@ impl borsh::BorshSchema for X { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_true.snap b/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_true.snap index 1b234f1df..2a600241c 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_true.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/borsh_discriminant_true.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for X { fn declaration() -> borsh::schema::Declaration { @@ -43,7 +44,7 @@ impl borsh::BorshSchema for X { ::add_definitions_recursively(definitions); ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0) as i64, "A".into(), < XA as borsh::BorshSchema > ::declaration()), (u8::from(20) as i64, "B".into(), < XB as @@ -54,6 +55,7 @@ impl borsh::BorshSchema for X { ::declaration()), (u8::from(10 + 1) as i64, "F".into(), < XF as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -62,4 +64,3 @@ impl borsh::BorshSchema for X { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum.snap b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum.snap index 60006941b..0d5a5b543 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A { fn declaration() -> borsh::schema::Declaration { @@ -36,7 +37,7 @@ impl borsh::BorshSchema for A { ::add_definitions_recursively(definitions); ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as @@ -45,6 +46,7 @@ impl borsh::BorshSchema for A { (u8::from(3u8) as i64, "Sausage".into(), < ASausage as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -53,4 +55,3 @@ impl borsh::BorshSchema for A { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics.snap b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics.snap index 46c0ca102..5a7c2bcea 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A where @@ -44,7 +45,7 @@ where as borsh::BorshSchema>::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as @@ -53,6 +54,7 @@ where (u8::from(3u8) as i64, "Sausage".into(), < ASausage < W > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -61,4 +63,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_named_field.snap b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_named_field.snap index c73bb784d..b3fa481a5 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_named_field.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_named_field.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A where @@ -46,7 +47,7 @@ where as borsh::BorshSchema>::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as @@ -55,6 +56,7 @@ where (u8::from(3u8) as i64, "Sausage".into(), < ASausage < W, U > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -63,4 +65,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_tuple_field.snap b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_tuple_field.snap index a531044c8..5d90a6aaa 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_tuple_field.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/complex_enum_generics_borsh_skip_tuple_field.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A where @@ -46,7 +47,7 @@ where as borsh::BorshSchema>::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as @@ -55,6 +56,7 @@ where (u8::from(3u8) as i64, "Sausage".into(), < ASausage < W > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -63,4 +65,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/filter_foreign_attrs.snap b/borsh-derive/src/internals/schema/enums/snapshots/filter_foreign_attrs.snap index 3eaab191e..e166faa44 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/filter_foreign_attrs.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/filter_foreign_attrs.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A { fn declaration() -> borsh::schema::Declaration { @@ -30,12 +31,13 @@ impl borsh::BorshSchema for A { ::add_definitions_recursively(definitions); ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "B".into(), < AB as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Negative".into(), < ANegative as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -44,4 +46,3 @@ impl borsh::BorshSchema for A { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type.snap b/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type.snap index c87de892b..d573051b7 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for EnumParametrized where @@ -57,12 +58,13 @@ where T, > as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "B".into(), < EnumParametrizedB < K, V > as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "C".into(), < EnumParametrizedC < T > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -71,4 +73,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type_param_override.snap b/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type_param_override.snap index 9389a848d..4f6bb9f3a 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type_param_override.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/generic_associated_type_param_override.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for EnumParametrized where @@ -58,12 +59,13 @@ where T, > as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "B".into(), < EnumParametrizedB < K, V > as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "C".into(), < EnumParametrizedC < T > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -72,4 +74,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/recursive_enum.snap b/borsh-derive/src/internals/schema/enums/snapshots/recursive_enum.snap index 856f9f630..df9e24cb6 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/recursive_enum.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/recursive_enum.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A where @@ -38,12 +39,13 @@ where as borsh::BorshSchema>::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "B".into(), < AB < K, V > as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "C".into(), < AC < K > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -52,4 +54,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/simple_enum.snap b/borsh-derive/src/internals/schema/enums/snapshots/simple_enum.snap index cbadbaf7f..65671829c 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/simple_enum.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/simple_enum.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A { fn declaration() -> borsh::schema::Declaration { @@ -23,12 +24,13 @@ impl borsh::BorshSchema for A { ::add_definitions_recursively(definitions); ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -37,4 +39,3 @@ impl borsh::BorshSchema for A { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/simple_enum_with_custom_crate.snap b/borsh-derive/src/internals/schema/enums/snapshots/simple_enum_with_custom_crate.snap index 368a39d92..ff25f7db9 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/simple_enum_with_custom_crate.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/simple_enum_with_custom_crate.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl reexporter::borsh::BorshSchema for A { fn declaration() -> reexporter::borsh::schema::Declaration { @@ -27,13 +28,14 @@ impl reexporter::borsh::BorshSchema for A { definitions, ); let definition = reexporter::borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: reexporter::borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as reexporter::borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Eggs".into(), < AEggs as reexporter::borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; reexporter::borsh::schema::add_definition( ::declaration(), @@ -42,4 +44,3 @@ impl reexporter::borsh::BorshSchema for A { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/single_field_enum.snap b/borsh-derive/src/internals/schema/enums/snapshots/single_field_enum.snap index f3a886f9b..3ed379674 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/single_field_enum.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/single_field_enum.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for A { fn declaration() -> borsh::schema::Declaration { @@ -18,11 +19,12 @@ impl borsh::BorshSchema for A { struct ABacon; ::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Bacon".into(), < ABacon as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -31,4 +33,3 @@ impl borsh::BorshSchema for A { ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/trailing_comma_generics.snap b/borsh-derive/src/internals/schema/enums/snapshots/trailing_comma_generics.snap index d6dfffc9a..1dc700234 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/trailing_comma_generics.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/trailing_comma_generics.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for Side where @@ -41,12 +42,13 @@ where as borsh::BorshSchema>::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "Left".into(), < SideLeft < A > as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "Right" .into(), < SideRight < B > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -55,4 +57,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/schema/enums/snapshots/with_funcs_attr.snap b/borsh-derive/src/internals/schema/enums/snapshots/with_funcs_attr.snap index 17f16077c..d3085df8e 100644 --- a/borsh-derive/src/internals/schema/enums/snapshots/with_funcs_attr.snap +++ b/borsh-derive/src/internals/schema/enums/snapshots/with_funcs_attr.snap @@ -1,6 +1,7 @@ --- source: borsh-derive/src/internals/schema/enums/mod.rs expression: pretty_print_syn_str(&actual).unwrap() +snapshot_kind: text --- impl borsh::BorshSchema for C where @@ -42,12 +43,13 @@ where ::add_definitions_recursively(definitions); as borsh::BorshSchema>::add_definitions_recursively(definitions); let definition = borsh::schema::Definition::Enum { - tag_width: 1, + tag_width: 1u8, variants: borsh::__private::maybestd::vec![ (u8::from(0u8) as i64, "C3".into(), < CC3 as borsh::BorshSchema > ::declaration()), (u8::from(1u8) as i64, "C4".into(), < CC4 < K, V > as borsh::BorshSchema > ::declaration()) ], + tag_signed: false, }; borsh::schema::add_definition( ::declaration(), @@ -56,4 +58,3 @@ where ); } } - diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 4e86ca2d4..9c85c09e2 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -17,12 +17,19 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut all_variants_idx_body = TokenStream2::new(); let mut fields_body = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; - let discriminants = Discriminants::new(&input.variants); + let maybe_borsh_tag_width = item::get_maybe_borsh_tag_width(input)?; + let maybe_rust_repr = item::get_maybe_rust_repr(input); + let discriminants = Discriminants::new( + &input.variants, + maybe_borsh_tag_width, + maybe_rust_repr, + use_discriminant, + )?; let mut has_unit_variant = false; for (variant_idx, variant) in input.variants.iter().enumerate() { let variant_ident = &variant.ident; - let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?; + let discriminant_value = discriminants.get(variant_ident, variant_idx)?; let variant_output = process_variant( variant, enum_ident, @@ -42,11 +49,11 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { } let fields_body = optimize_fields_body(fields_body, has_unit_variant); generics_output.extend(&mut where_clause, &cratename); - + let discriminant_type = discriminants.discriminant_type(); Ok(quote! { impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause { fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> { - let variant_idx: u8 = match self { + let variant_idx: #discriminant_type = match self { #all_variants_idx_body }; writer.write_all(&variant_idx.to_le_bytes())?; diff --git a/borsh/examples/serde_json_value.rs b/borsh/examples/serde_json_value.rs index 21cadf758..a4d11b7bf 100644 --- a/borsh/examples/serde_json_value.rs +++ b/borsh/examples/serde_json_value.rs @@ -75,7 +75,7 @@ mod serde_json_value { &(u32::try_from(array.len()).map_err(|_| ErrorKind::InvalidData)?).to_le_bytes(), )?; for item in array { - serialize_value(&item, writer)?; + serialize_value(item, writer)?; } Ok(()) } @@ -92,7 +92,7 @@ mod serde_json_value { for (key, value) in map { key.serialize(writer)?; - serialize_value(&value, writer)?; + serialize_value(value, writer)?; } Ok(()) @@ -260,8 +260,8 @@ fn main() { "negative_integer": -88888, "positive_float": 123.45, "negative_float": -888.88, - "positive_max": 1.7976931348623157e+308, - "negative_max": -1.7976931348623157e+308, + "positive_max": 1.797_693_134_862_315_7e308, + "negative_max": -1.797_693_134_862_315_7e308, "string": "Larry", "array_of_nulls": [null, null, null], "array_of_numbers": [0, -1, 1, 1.1, -1.1, 34798324], diff --git a/borsh/src/schema.rs b/borsh/src/schema.rs index 65b5bb1c9..810c8cfa5 100644 --- a/borsh/src/schema.rs +++ b/borsh/src/schema.rs @@ -112,6 +112,10 @@ pub enum Definition { /// invalid if the value is greater than eight. tag_width: u8, + /// If true, than tag is signed value. + /// If false, unsigned value. + tag_signed: bool, + /// Possible variants of the enumeration. /// `VariantName` is metadata, not present in a type's serialized representation. variants: Vec<(DiscriminantValue, VariantName, Declaration)>, @@ -601,6 +605,7 @@ where (0u8 as i64, "None".to_string(), <()>::declaration()), (1u8 as i64, "Some".to_string(), T::declaration()), ], + tag_signed: false, }; add_definition(Self::declaration(), definition, definitions); T::add_definitions_recursively(definitions); @@ -624,6 +629,7 @@ where (1u8 as i64, "Ok".to_string(), T::declaration()), (0u8 as i64, "Err".to_string(), E::declaration()), ], + tag_signed: false, }; add_definition(Self::declaration(), definition, definitions); T::add_definitions_recursively(definitions); diff --git a/borsh/src/schema/container_ext/max_size.rs b/borsh/src/schema/container_ext/max_size.rs index 193f3be1e..f963b9e77 100644 --- a/borsh/src/schema/container_ext/max_size.rs +++ b/borsh/src/schema/container_ext/max_size.rs @@ -126,6 +126,7 @@ fn max_serialized_size_impl<'a>( Ok(Definition::Enum { tag_width, variants, + tag_signed: _, }) => { let mut max = 0; for (_, _, variant) in variants { @@ -232,6 +233,7 @@ fn is_zero_size_impl<'a>( Ok(Definition::Enum { tag_width: 0, variants, + tag_signed: _, }) => all( variants.iter(), |(_variant_discrim, _variant_name, declaration)| declaration, diff --git a/borsh/src/schema/container_ext/validate.rs b/borsh/src/schema/container_ext/validate.rs index 6bd2738fb..524fe344f 100644 --- a/borsh/src/schema/container_ext/validate.rs +++ b/borsh/src/schema/container_ext/validate.rs @@ -106,10 +106,12 @@ fn validate_impl<'a>( Definition::Enum { tag_width, variants, + tag_signed: _, } => { if *tag_width > U64_LEN { return Err(Error::TagTooWide(declaration.to_string())); } + for (_, _, variant) in variants { validate_impl(variant, schema, stack)?; } diff --git a/borsh/tests/init_in_deserialize/test_init_in_deserialize.rs b/borsh/tests/init_in_deserialize/test_init_in_deserialize.rs index 652cfc923..1941bc0dc 100644 --- a/borsh/tests/init_in_deserialize/test_init_in_deserialize.rs +++ b/borsh/tests/init_in_deserialize/test_init_in_deserialize.rs @@ -27,7 +27,7 @@ fn test_simple_struct() { } #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] -#[borsh(init=initialization_method)] +#[borsh(init = initialization_method)] enum AEnum { A, B, diff --git a/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs b/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs index ac719fe57..237674bbb 100644 --- a/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs +++ b/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs @@ -29,7 +29,7 @@ enum XYNoDiscriminant { #[test] fn test_discriminant_serde_no_unit_type() { - let values = vec![XY::A, XY::B, XY::C, XY::E, XY::D(12, 14), XY::F(35325423)]; + let values = [XY::A, XY::B, XY::C, XY::E, XY::D(12, 14), XY::F(35325423)]; let expected_discriminants = [0u8, 20, 21, 10, 22, 11]; for (ind, value) in values.iter().enumerate() { @@ -39,16 +39,62 @@ fn test_discriminant_serde_no_unit_type() { } } +#[test] +pub fn u16_discriminant() { + use borsh::{BorshDeserialize, BorshSerialize}; + #[derive(BorshSerialize, BorshDeserialize, Debug, Eq, PartialEq)] + #[borsh(use_discriminant = true, tag_width = 2)] + #[repr(u16)] + enum U16Discriminant { + U8 = 42, + U16 = 666, + } + let mut buf = vec![]; + let data = U16Discriminant::U16; + data.serialize(&mut buf).unwrap(); + assert_eq!(buf[0], 154); + assert_eq!(buf[1], 2); + + assert_eq!(buf.len(), 2, "Serialized data should be 2 bytes long"); + + let deserialized = U16Discriminant::deserialize(&mut buf.as_slice()).unwrap(); + assert_eq!(deserialized, data); +} + +#[test] +pub fn u32_discriminant() { + use borsh::{BorshDeserialize, BorshSerialize}; + #[derive(BorshSerialize, BorshDeserialize, Debug, Eq, PartialEq)] + #[borsh(use_discriminant = true, tag_width = 4)] + #[repr(u32)] + enum U32Discriminant { + U8 = 42u32, + U32 = u32::MAX, + } + let mut buf = vec![]; + let data = U32Discriminant::U32; + data.serialize(&mut buf).unwrap(); + assert_eq!(buf.len(), 4, "Serialized data should be 4 bytes long"); + + assert_eq!(buf[0], 255); + assert_eq!(buf[1], 255); + assert_eq!(buf[2], 255); + assert_eq!(buf[3], 255); + + + let deserialized = U32Discriminant::deserialize(&mut buf.as_slice()).unwrap(); + assert_eq!(deserialized, data); +} + + #[test] fn test_discriminant_serde_no_unit_type_no_use_discriminant() { - let values = vec![ - XYNoDiscriminant::A, + let values = [XYNoDiscriminant::A, XYNoDiscriminant::B, XYNoDiscriminant::C, XYNoDiscriminant::D(12, 14), XYNoDiscriminant::E, - XYNoDiscriminant::F(35325423), - ]; + XYNoDiscriminant::F(35325423)]; let expected_discriminants = [0u8, 1, 2, 3, 4, 5]; for (ind, value) in values.iter().enumerate() { @@ -105,14 +151,12 @@ enum XNoDiscriminant { #[test] fn test_discriminant_serde_no_use_discriminant() { - let values = vec![ - XNoDiscriminant::A, + let values = [XNoDiscriminant::A, XNoDiscriminant::B, XNoDiscriminant::C, XNoDiscriminant::D, XNoDiscriminant::E, - XNoDiscriminant::F, - ]; + XNoDiscriminant::F]; let expected_discriminants = [0u8, 1, 2, 3, 4, 5]; for (index, value) in values.iter().enumerate() { let data = to_vec(value).unwrap(); @@ -169,7 +213,7 @@ fn test_deserialize_invalid_discriminant() { #[test] fn test_discriminant_serde() { - let values = vec![X::A, X::B, X::C, X::D, X::E, X::F]; + let values = [X::A, X::B, X::C, X::D, X::E, X::F]; let expected_discriminants = [0u8, 20, 21, 22, 10, 11]; for (index, value) in values.iter().enumerate() { let data = to_vec(value).unwrap(); diff --git a/borsh/tests/schema/container_extension/test_max_size.rs b/borsh/tests/schema/container_extension/test_max_size.rs index d179b0e83..7369d6502 100644 --- a/borsh/tests/schema/container_extension/test_max_size.rs +++ b/borsh/tests/schema/container_extension/test_max_size.rs @@ -133,6 +133,7 @@ fn max_serialized_size_custom_enum() { } fn add_definitions_recursively(definitions: &mut BTreeMap) { let definition = Definition::Enum { + tag_signed: false, tag_width: N, variants: vec![ (0, "Just".into(), T::declaration()), diff --git a/borsh/tests/schema/snapshots/tests__schema__test_ip_addr__ip_addr_schema.snap b/borsh/tests/schema/snapshots/tests__schema__test_ip_addr__ip_addr_schema.snap index 68f0c7109..8037a8da9 100644 --- a/borsh/tests/schema/snapshots/tests__schema__test_ip_addr__ip_addr_schema.snap +++ b/borsh/tests/schema/snapshots/tests__schema__test_ip_addr__ip_addr_schema.snap @@ -1,10 +1,12 @@ --- source: borsh/tests/schema/test_ip_addr.rs expression: "format!(\"{:#?}\", defs)" +snapshot_kind: text --- { "IpAddr": Enum { tag_width: 1, + tag_signed: false, variants: [ ( 0, diff --git a/borsh/tests/schema/snapshots/tests__schema__test_simple_enums__complex_enum_with_schema.snap b/borsh/tests/schema/snapshots/tests__schema__test_simple_enums__complex_enum_with_schema.snap index 1cdf94d5f..674c08f43 100644 --- a/borsh/tests/schema/snapshots/tests__schema__test_simple_enums__complex_enum_with_schema.snap +++ b/borsh/tests/schema/snapshots/tests__schema__test_simple_enums__complex_enum_with_schema.snap @@ -1,6 +1,7 @@ --- source: borsh/tests/schema/test_simple_enums.rs expression: data +snapshot_kind: text --- [ 1, @@ -19,6 +20,7 @@ expression: data 65, 3, 1, + 0, 4, 0, 0, diff --git a/borsh/tests/schema/test_enum_discriminants.rs b/borsh/tests/schema/test_enum_discriminants.rs index eb502ab12..770855715 100644 --- a/borsh/tests/schema/test_enum_discriminants.rs +++ b/borsh/tests/schema/test_enum_discriminants.rs @@ -21,6 +21,7 @@ fn test_schema_discriminant_no_unit_type() { assert_eq!( schema_map! { "XY" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "A".to_string(), "XYA".to_string()), @@ -29,7 +30,7 @@ fn test_schema_discriminant_no_unit_type() { (22, "D".to_string(), "XYD".to_string()), (10, "E".to_string(), "XYE".to_string()), (11, "F".to_string(), "XYF".to_string()) - ] + ], }, "XYA" => Definition::Struct{ fields: Fields::Empty }, "XYB" => Definition::Struct{ fields: Fields::Empty }, @@ -73,6 +74,7 @@ fn test_schema_discriminant_no_unit_type_no_use_discriminant() { assert_eq!( schema_map! { "XYNoDiscriminant" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "A".to_string(), "XYNoDiscriminantA".to_string()), @@ -100,3 +102,55 @@ fn test_schema_discriminant_no_unit_type_no_use_discriminant() { defs ); } + + +#[test] +fn tag_widths() { + #[derive(BorshSchema)] + #[borsh(use_discriminant=true, tag_width = 2)] + #[repr(u16)] + #[allow(dead_code)] + enum U16Discriminant { + A = 42u16, + } + + let mut defs = Default::default(); + U16Discriminant::add_definitions_recursively(&mut defs); + assert_eq!( + schema_map! { + "U16Discriminant" => Definition::Enum { + tag_signed: false, + tag_width: 2, + variants: vec![ + (42, "A".to_string(), "U16DiscriminantA".to_string()), + ], + }, + "U16DiscriminantA" => Definition::Struct{ fields: Fields::Empty } + }, + defs + ); + + #[derive(BorshSchema)] + #[borsh(use_discriminant = true, tag_width = 4)] + #[repr(u32)] + #[allow(dead_code)] + enum U32Discriminant { + A = 42u32, + } + + let mut defs = Default::default(); + U32Discriminant::add_definitions_recursively(&mut defs); + assert_eq!( + schema_map! { + "U32Discriminant" => Definition::Enum { + tag_signed: false, + tag_width: 4, + variants: vec![ + (42, "A".to_string(), "U32DiscriminantA".to_string()), + ], + }, + "U32DiscriminantA" => Definition::Struct{ fields: Fields::Empty } + }, + defs + ); +} \ No newline at end of file diff --git a/borsh/tests/schema/test_generic_enums.rs b/borsh/tests/schema/test_generic_enums.rs index 61965bbb8..72d1c7c1b 100644 --- a/borsh/tests/schema/test_generic_enums.rs +++ b/borsh/tests/schema/test_generic_enums.rs @@ -40,6 +40,7 @@ pub fn complex_enum_generics() { "ABacon" => Definition::Struct {fields: Fields::Empty}, "Oil" => Definition::Struct {fields: Fields::Empty}, "A" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "Bacon".to_string(), "ABacon".to_string()), @@ -102,6 +103,7 @@ pub fn complex_enum_generics2() { assert_eq!( schema_map! { "A>" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "Bacon".to_string(), "ABacon".to_string()), @@ -111,6 +113,7 @@ pub fn complex_enum_generics2() { ] }, "A" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "Bacon".to_string(), "ABacon".to_string()), @@ -134,6 +137,7 @@ pub fn complex_enum_generics2() { }, "Oil" => Definition::Struct { fields: Fields::NamedFields(vec![("seeds".to_string(), "HashMap".to_string()), ("liquid".to_string(), "Option".to_string())])}, "Option" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "None".to_string(), "()".to_string()), @@ -141,6 +145,7 @@ pub fn complex_enum_generics2() { ] }, "Option" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "None".to_string(), "()".to_string()), @@ -166,6 +171,7 @@ pub fn complex_enum_generics2() { fn common_map_associated() -> BTreeMap { schema_map! { "EnumParametrized" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "B".to_string(), "EnumParametrizedB".to_string()), diff --git a/borsh/tests/schema/test_option.rs b/borsh/tests/schema/test_option.rs index f04c1f906..e9baa9e2b 100644 --- a/borsh/tests/schema/test_option.rs +++ b/borsh/tests/schema/test_option.rs @@ -9,6 +9,7 @@ fn simple_option() { assert_eq!( schema_map! { "Option" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "None".to_string(), "()".to_string()), @@ -31,6 +32,7 @@ fn nested_option() { assert_eq!( schema_map! { "Option" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "None".to_string(), "()".to_string()), @@ -38,6 +40,7 @@ fn nested_option() { ] }, "Option>" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "None".to_string(), "()".to_string()), diff --git a/borsh/tests/schema/test_recursive_enums.rs b/borsh/tests/schema/test_recursive_enums.rs index f5bc64e49..dca9195dc 100644 --- a/borsh/tests/schema/test_recursive_enums.rs +++ b/borsh/tests/schema/test_recursive_enums.rs @@ -14,6 +14,7 @@ pub fn recursive_enum_schema() { assert_eq!( schema_map! { "ERecD" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "B".to_string(), "ERecDB".to_string()), diff --git a/borsh/tests/schema/test_schema_with_third_party.rs b/borsh/tests/schema/test_schema_with_third_party.rs index 30562f89a..c246c6d05 100644 --- a/borsh/tests/schema/test_schema_with_third_party.rs +++ b/borsh/tests/schema/test_schema_with_third_party.rs @@ -10,7 +10,7 @@ mod third_party_impl { pub(super) fn declaration( ) -> borsh::schema::Declaration { - let params = vec![::declaration(), ::declaration()]; + let params = [::declaration(), ::declaration()]; format!(r#"{}<{}>"#, "ThirdParty", params.join(", ")) } @@ -100,6 +100,7 @@ pub fn enum_overriden() { assert_eq!( schema_map! { "C" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "C3".to_string(), "CC3".to_string()), diff --git a/borsh/tests/schema/test_simple_enums.rs b/borsh/tests/schema/test_simple_enums.rs index 27e94804e..0bb647b95 100644 --- a/borsh/tests/schema/test_simple_enums.rs +++ b/borsh/tests/schema/test_simple_enums.rs @@ -25,6 +25,7 @@ pub fn simple_enum() { "ABacon" => Definition::Struct{ fields: Fields::Empty }, "AEggs" => Definition::Struct{ fields: Fields::Empty }, "A" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![(0, "Bacon".to_string(), "ABacon".to_string()), (1, "Eggs".to_string(), "AEggs".to_string())] } @@ -47,6 +48,7 @@ pub fn single_field_enum() { schema_map! { "ABacon" => Definition::Struct {fields: Fields::Empty}, "A" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![(0, "Bacon".to_string(), "ABacon".to_string())] } @@ -131,6 +133,7 @@ pub fn complex_enum_with_schema() { "ABacon" => Definition::Struct {fields: Fields::Empty}, "Oil" => Definition::Struct {fields: Fields::Empty}, "A" => Definition::Enum { + tag_signed: false, tag_width: 1, variants: vec![ (0, "Bacon".to_string(), "ABacon".to_string()),