Skip to content

Commit

Permalink
did most of u16 disciminator, failed on deserialize_variant
Browse files Browse the repository at this point in the history
  • Loading branch information
dzmitry-lahoda committed Oct 14, 2024
1 parent b416d11 commit dca2e53
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ members = ["borsh", "borsh-derive", "fuzz/fuzz-run", "benchmarks"]
[workspace.package]
# shared version of all public crates in the workspace
version = "1.5.1"
rust-version = "1.67.0"
rust-version = "1.81.0"
17 changes: 13 additions & 4 deletions borsh-derive/src/internals/attributes/item/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::internals::attributes::{BORSH, CRATE, INIT, USE_DISCRIMINANT};
use quote::ToTokens;
use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path};
use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path, TypePath};

use super::{get_one_attribute, parsing};
use super::{get_one_attribute, parsing, REPR};

pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> {
let borsh = get_one_attribute(&derive_input.attrs)?;
Expand Down Expand Up @@ -34,10 +34,11 @@ pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> {
}

pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::Error> {
if input.variants.len() > 256 {
const MAX_VARIANTS: usize = u16::MAX as usize + 1;
if input.variants.len() > MAX_VARIANTS {
return Err(syn::Error::new(
input.span(),
"up to 256 enum variants are supported",
"up to {MAX_VARIANTS} enum variants are supported",
));
}

Expand Down Expand Up @@ -80,6 +81,14 @@ pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::E
Ok(use_discriminant.unwrap_or(false))
}

/// Gets type of reprc attribute if it exists
pub (crate) fn get_maybe_reprc_attribute(input: &ItemEnum) -> Option<TypePath> {
input.attrs.iter().find(|x| {
x.path() == REPR
})
?.parse_args().ok()
}

pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result<Option<Path>, Error> {
let mut res = None;
let attr = attrs.iter().find(|attr| attr.path() == BORSH);
Expand Down
1 change: 1 addition & 0 deletions borsh-derive/src/internals/attributes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub const SERIALIZE_WITH: Symbol = Symbol("serialize_with", "serialize_with = ..
pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with", "deserialize_with = ...");
/// crate - sub-borsh nested meta, item-level only, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts
pub const CRATE: Symbol = Symbol("crate", "crate = ...");
pub const REPR: Symbol = Symbol("repr", "repr(...)");

#[cfg(feature = "schema")]
pub mod schema_keys {
Expand Down
21 changes: 13 additions & 8 deletions borsh-derive/src/internals/deserialize/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
let mut where_clause = generics::default_where(where_clause);
let mut variant_arms = TokenStream2::new();
let use_discriminant = item::contains_use_discriminant(input)?;
let discriminants = Discriminants::new(&input.variants);
let maybe_reprc_attribute = item::get_maybe_reprc_attribute(input);
let discriminants: Discriminants = Discriminants::new(&input.variants, maybe_reprc_attribute);
let mut generics_output = deserialize::GenericsOutput::new(&generics);

for (variant_idx, variant) in input.variants.iter().enumerate() {
Expand All @@ -20,7 +21,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {

let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?;
variant_arms.extend(quote! {
if variant_tag == #discriminant_value { #name::#variant_ident #variant_body } else
if variant_tag == #discriminant_value.into() { #name::#variant_ident #variant_body } else
});
}
let init = if let Some(method_ident) = item::contains_initialize_with(&input.attrs)? {
Expand All @@ -32,18 +33,21 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
};
generics_output.extend(&mut where_clause, &cratename);

Ok(quote! {
let discriminant_type = discriminants.discriminant_type();
let x = quote! {
impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
fn deserialize_reader<__R: #cratename::io::Read>(reader: &mut __R) -> ::core::result::Result<Self, #cratename::io::Error> {
let tag = <u8 as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as #cratename::de::EnumExt>::deserialize_variant(reader, tag)
let tag = <#discriminant_type as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as #cratename::de::EnumExt>::deserialize_variant::<_, #discriminant_type>(reader, tag)
}
}

impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
fn deserialize_variant<__R: #cratename::io::Read>(
fn deserialize_variant<__R: #cratename::io::Read,
Tag: borsh::BorshDeserialize + ::core::fmt::Debug + Eq
>(
reader: &mut __R,
variant_tag: u8,
variant_tag: u8, //#discriminant_type,
) -> ::core::result::Result<Self, #cratename::io::Error> {
let mut return_value =
#variant_arms {
Expand All @@ -56,7 +60,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
Ok(return_value)
}
}
})
};
Ok(x)
}

fn process_variant(
Expand Down
25 changes: 17 additions & 8 deletions borsh-derive/src/internals/enum_discriminant.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use core::convert::TryInto;
use std::collections::HashMap;
use std::convert::TryFrom;

use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, Variant};
use syn::{parse::{Parse, ParseBuffer}, punctuated::Punctuated, token::{Comma, Type}, Path, Variant};

pub struct Discriminants(HashMap<Ident, TokenStream>);
pub struct Discriminants((HashMap<Ident, TokenStream>, syn::TypePath));
impl Discriminants {
/// Calculates the discriminant that will be assigned by the compiler.
/// See: https://doc.rust-lang.org/reference/items/enumerations.html#assigning-discriminant-values
pub fn new(variants: &Punctuated<Variant, Comma>) -> Self {
pub fn new(variants: &Punctuated<Variant, Comma>, maybe_discriminant_type: Option<syn::TypePath>) -> Self {
let mut map = HashMap::new();
let mut next_discriminant_if_not_specified = quote! {0};

Expand All @@ -18,12 +19,20 @@ impl Discriminants {
|| quote! { #next_discriminant_if_not_specified },
|(_, e)| quote! { #e },
);

next_discriminant_if_not_specified = quote! { #this_discriminant + 1 };
map.insert(variant.ident.clone(), this_discriminant);
}
let discriminant_type = //maybe_discriminant_type.unwrap_or(
syn::parse_str("u8").expect("numeric")
//)
;

Self((map, discriminant_type))
}

Self(map)
pub fn discriminant_type(&self) -> &syn::TypePath {
&self.0.1
}

pub fn get(
Expand All @@ -32,14 +41,14 @@ impl Discriminants {
use_discriminant: bool,
variant_idx: usize,
) -> syn::Result<TokenStream> {
let variant_idx = u8::try_from(variant_idx).map_err(|err| {
let variant_idx: u8 = u8::try_from(variant_idx).map_err(|err| {
syn::Error::new(
variant_ident.span(),
format!("up to 256 enum variants are supported: {}", err),
format!("up to {} enum variants are supported: {}", u8::MAX as usize + 1, err),
)
})?;
let result = if use_discriminant {
let discriminant_value = self.0.get(variant_ident).unwrap();
let discriminant_value = self.0.0.get(variant_ident).unwrap();
quote! { #discriminant_value }
} else {
quote! { #variant_idx }
Expand Down
7 changes: 4 additions & 3 deletions borsh-derive/src/internals/serialize/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
let mut all_variants_idx_body = TokenStream2::new();
let mut fields_body = TokenStream2::new();
let use_discriminant = item::contains_use_discriminant(input)?;
let discriminants = Discriminants::new(&input.variants);
let maybe_discriminant_type = item::get_maybe_reprc_attribute(input);
let discriminants = Discriminants::new(&input.variants, maybe_discriminant_type);
let mut has_unit_variant = false;

for (variant_idx, variant) in input.variants.iter().enumerate() {
Expand All @@ -42,11 +43,11 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
}
let fields_body = optimize_fields_body(fields_body, has_unit_variant);
generics_output.extend(&mut where_clause, &cratename);

let discriminant_type = discriminants.discriminant_type();
Ok(quote! {
impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause {
fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> {
let variant_idx: u8 = match self {
let variant_idx: #discriminant_type = match self {
#all_variants_idx_body
};
writer.write_all(&variant_idx.to_le_bytes())?;
Expand Down
2 changes: 1 addition & 1 deletion borsh/src/de/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub trait EnumExt: BorshDeserialize {
/// # #[cfg(feature = "derive")]
/// assert!(from_slice::<OneOrZero>(&data[..]).is_err());
/// ```
fn deserialize_variant<R: Read>(reader: &mut R, tag: u8) -> Result<Self>;
fn deserialize_variant<R: Read, Tag: BorshDeserialize + ::core::fmt::Debug + Eq>(reader: &mut R, tag: u8) -> Result<Self>;
}

fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,34 @@ fn test_discriminant_serde_no_unit_type() {
}
}

#[test]
pub fn u16_discriminat() {
use borsh::{BorshSerialize, BorshDeserialize};
#[derive(BorshSerialize, BorshDeserialize, Debug)]
#[borsh(use_discriminant = true)]
#[repr(u16)]
enum ZEnum {
AA = 42,
Z=2,
// A { a: u16, b: u64, d: bool, s: String } = 1u16,
// Z { a: u16, b: u64, d: bool, s: String } = 257u16,
}
let mut ss = vec![];
ZEnum::AA.serialize(&mut ss).unwrap();
assert!(ss[0] == 42);
assert!(ss[1] == 0);
// let s = Enum::A {
// a: 13,
// b: 42,
// d: true,
// s: "hello my bonny".to_string(),
// };
// let mut buf = Vec::new();
// s.serialize(&mut buf).expect("must serialize");
// panic!("{:?}", buf);
}


#[test]
fn test_discriminant_serde_no_unit_type_no_use_discriminant() {
let values = vec![
Expand Down
10 changes: 10 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[toolchain]
channel = "1.81.0"

# Default profile includes `rustfmt`, `clippy`, `rust-docs`.
# https://rust-lang.github.io/rustup/concepts/profiles.html
profile = "default"
components = ["rust-analyzer"]

# For static linking for deployment.
targets = ["x86_64-unknown-linux-musl"]

0 comments on commit dca2e53

Please sign in to comment.