diff --git a/Cargo.toml b/Cargo.toml index 2a1e1907c..82020ecb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ members = ["borsh", "borsh-derive", "fuzz/fuzz-run", "benchmarks"] [workspace.package] # shared version of all public crates in the workspace version = "1.5.1" -rust-version = "1.67.0" +rust-version = "1.81.0" diff --git a/borsh-derive/src/internals/attributes/item/mod.rs b/borsh-derive/src/internals/attributes/item/mod.rs index ff3551486..9bb3baeb2 100644 --- a/borsh-derive/src/internals/attributes/item/mod.rs +++ b/borsh-derive/src/internals/attributes/item/mod.rs @@ -1,8 +1,8 @@ use crate::internals::attributes::{BORSH, CRATE, INIT, USE_DISCRIMINANT}; use quote::ToTokens; -use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path}; +use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path, TypePath}; -use super::{get_one_attribute, parsing}; +use super::{get_one_attribute, parsing, REPR}; pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> { let borsh = get_one_attribute(&derive_input.attrs)?; @@ -34,10 +34,11 @@ pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> { } pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result { - if input.variants.len() > 256 { + const MAX_VARIANTS: usize = u16::MAX as usize + 1; + if input.variants.len() > MAX_VARIANTS { return Err(syn::Error::new( input.span(), - "up to 256 enum variants are supported", + "up to {MAX_VARIANTS} enum variants are supported", )); } @@ -80,6 +81,14 @@ pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result Option { + input.attrs.iter().find(|x| { + x.path() == REPR + }) + ?.parse_args().ok() +} + pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result, Error> { let mut res = None; let attr = attrs.iter().find(|attr| attr.path() == BORSH); diff --git a/borsh-derive/src/internals/attributes/mod.rs b/borsh-derive/src/internals/attributes/mod.rs index 4c5a69d4d..cc8ea1d4a 100644 --- a/borsh-derive/src/internals/attributes/mod.rs +++ b/borsh-derive/src/internals/attributes/mod.rs @@ -29,6 +29,7 @@ pub const SERIALIZE_WITH: Symbol = Symbol("serialize_with", "serialize_with = .. pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with", "deserialize_with = ..."); /// crate - sub-borsh nested meta, item-level only, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts pub const CRATE: Symbol = Symbol("crate", "crate = ..."); +pub const REPR: Symbol = Symbol("repr", "repr(...)"); #[cfg(feature = "schema")] pub mod schema_keys { diff --git a/borsh-derive/src/internals/deserialize/enums/mod.rs b/borsh-derive/src/internals/deserialize/enums/mod.rs index fb405e90d..9d90e75ac 100644 --- a/borsh-derive/src/internals/deserialize/enums/mod.rs +++ b/borsh-derive/src/internals/deserialize/enums/mod.rs @@ -11,7 +11,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut where_clause = generics::default_where(where_clause); let mut variant_arms = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; - let discriminants = Discriminants::new(&input.variants); + let maybe_reprc_attribute = item::get_maybe_reprc_attribute(input); + let discriminants: Discriminants = Discriminants::new(&input.variants, maybe_reprc_attribute); let mut generics_output = deserialize::GenericsOutput::new(&generics); for (variant_idx, variant) in input.variants.iter().enumerate() { @@ -20,7 +21,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?; variant_arms.extend(quote! { - if variant_tag == #discriminant_value { #name::#variant_ident #variant_body } else + if variant_tag == #discriminant_value.into() { #name::#variant_ident #variant_body } else }); } let init = if let Some(method_ident) = item::contains_initialize_with(&input.attrs)? { @@ -32,18 +33,21 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { }; generics_output.extend(&mut where_clause, &cratename); - Ok(quote! { + let discriminant_type = discriminants.discriminant_type(); + let x = quote! { impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { fn deserialize_reader<__R: #cratename::io::Read>(reader: &mut __R) -> ::core::result::Result { - let tag = ::deserialize_reader(reader)?; - ::deserialize_variant(reader, tag) + let tag = <#discriminant_type as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?; + ::deserialize_variant::<_, #discriminant_type>(reader, tag) } } impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause { - fn deserialize_variant<__R: #cratename::io::Read>( + fn deserialize_variant<__R: #cratename::io::Read, + Tag: borsh::BorshDeserialize + ::core::fmt::Debug + Eq + >( reader: &mut __R, - variant_tag: u8, + variant_tag: u8, //#discriminant_type, ) -> ::core::result::Result { let mut return_value = #variant_arms { @@ -56,7 +60,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { Ok(return_value) } } - }) + }; + Ok(x) } fn process_variant( diff --git a/borsh-derive/src/internals/enum_discriminant.rs b/borsh-derive/src/internals/enum_discriminant.rs index 03b3f5829..71e257722 100644 --- a/borsh-derive/src/internals/enum_discriminant.rs +++ b/borsh-derive/src/internals/enum_discriminant.rs @@ -1,15 +1,16 @@ +use core::convert::TryInto; use std::collections::HashMap; use std::convert::TryFrom; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::{punctuated::Punctuated, token::Comma, Variant}; +use syn::{parse::{Parse, ParseBuffer}, punctuated::Punctuated, token::{Comma, Type}, Path, Variant}; -pub struct Discriminants(HashMap); +pub struct Discriminants((HashMap, syn::TypePath)); impl Discriminants { /// Calculates the discriminant that will be assigned by the compiler. /// See: https://doc.rust-lang.org/reference/items/enumerations.html#assigning-discriminant-values - pub fn new(variants: &Punctuated) -> Self { + pub fn new(variants: &Punctuated, maybe_discriminant_type: Option) -> Self { let mut map = HashMap::new(); let mut next_discriminant_if_not_specified = quote! {0}; @@ -18,12 +19,20 @@ impl Discriminants { || quote! { #next_discriminant_if_not_specified }, |(_, e)| quote! { #e }, ); - + next_discriminant_if_not_specified = quote! { #this_discriminant + 1 }; map.insert(variant.ident.clone(), this_discriminant); } + let discriminant_type = //maybe_discriminant_type.unwrap_or( + syn::parse_str("u8").expect("numeric") + //) + ; + + Self((map, discriminant_type)) + } - Self(map) + pub fn discriminant_type(&self) -> &syn::TypePath { + &self.0.1 } pub fn get( @@ -32,14 +41,14 @@ impl Discriminants { use_discriminant: bool, variant_idx: usize, ) -> syn::Result { - let variant_idx = u8::try_from(variant_idx).map_err(|err| { + let variant_idx: u8 = u8::try_from(variant_idx).map_err(|err| { syn::Error::new( variant_ident.span(), - format!("up to 256 enum variants are supported: {}", err), + format!("up to {} enum variants are supported: {}", u8::MAX as usize + 1, err), ) })?; let result = if use_discriminant { - let discriminant_value = self.0.get(variant_ident).unwrap(); + let discriminant_value = self.0.0.get(variant_ident).unwrap(); quote! { #discriminant_value } } else { quote! { #variant_idx } diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 4e86ca2d4..a62ec324e 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -17,7 +17,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut all_variants_idx_body = TokenStream2::new(); let mut fields_body = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; - let discriminants = Discriminants::new(&input.variants); + let maybe_discriminant_type = item::get_maybe_reprc_attribute(input); + let discriminants = Discriminants::new(&input.variants, maybe_discriminant_type); let mut has_unit_variant = false; for (variant_idx, variant) in input.variants.iter().enumerate() { @@ -42,11 +43,11 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { } let fields_body = optimize_fields_body(fields_body, has_unit_variant); generics_output.extend(&mut where_clause, &cratename); - + let discriminant_type = discriminants.discriminant_type(); Ok(quote! { impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause { fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> { - let variant_idx: u8 = match self { + let variant_idx: #discriminant_type = match self { #all_variants_idx_body }; writer.write_all(&variant_idx.to_le_bytes())?; diff --git a/borsh/src/de/mod.rs b/borsh/src/de/mod.rs index 0abb36e3e..47bd3f5f9 100644 --- a/borsh/src/de/mod.rs +++ b/borsh/src/de/mod.rs @@ -132,7 +132,7 @@ pub trait EnumExt: BorshDeserialize { /// # #[cfg(feature = "derive")] /// assert!(from_slice::(&data[..]).is_err()); /// ``` - fn deserialize_variant(reader: &mut R, tag: u8) -> Result; + fn deserialize_variant(reader: &mut R, tag: u8) -> Result; } fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error { diff --git a/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs b/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs index ac719fe57..60c62ff91 100644 --- a/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs +++ b/borsh/tests/roundtrip/requires_derive_category/test_enum_discriminants.rs @@ -39,6 +39,34 @@ fn test_discriminant_serde_no_unit_type() { } } +#[test] +pub fn u16_discriminat() { + use borsh::{BorshSerialize, BorshDeserialize}; + #[derive(BorshSerialize, BorshDeserialize, Debug)] + #[borsh(use_discriminant = true)] + #[repr(u16)] + enum ZEnum { + AA = 42, + Z=2, + // A { a: u16, b: u64, d: bool, s: String } = 1u16, + // Z { a: u16, b: u64, d: bool, s: String } = 257u16, + } + let mut ss = vec![]; + ZEnum::AA.serialize(&mut ss).unwrap(); + assert!(ss[0] == 42); + assert!(ss[1] == 0); + // let s = Enum::A { + // a: 13, + // b: 42, + // d: true, + // s: "hello my bonny".to_string(), + // }; + // let mut buf = Vec::new(); + // s.serialize(&mut buf).expect("must serialize"); + // panic!("{:?}", buf); +} + + #[test] fn test_discriminant_serde_no_unit_type_no_use_discriminant() { let values = vec![ diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..f602afad8 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,10 @@ +[toolchain] +channel = "1.81.0" + +# Default profile includes `rustfmt`, `clippy`, `rust-docs`. +# https://rust-lang.github.io/rustup/concepts/profiles.html +profile = "default" +components = ["rust-analyzer"] + +# For static linking for deployment. +targets = ["x86_64-unknown-linux-musl"]