From d2d81a4bfddd0fa86cb0551c2aa359daeebcfc9f Mon Sep 17 00:00:00 2001 From: David Thomas Date: Wed, 29 Nov 2023 10:25:06 +0000 Subject: [PATCH] Fix argument parsing broken by serenity 0.12 port --- macros/src/command/slash.rs | 14 +-- src/slash_argument/slash_macro.rs | 54 +++++++++- src/slash_argument/slash_trait.rs | 162 +++++++++++------------------- 3 files changed, 110 insertions(+), 120 deletions(-) diff --git a/macros/src/command/slash.rs b/macros/src/command/slash.rs index 0e22e5edb6a9..d7a36d2dac4a 100644 --- a/macros/src/command/slash.rs +++ b/macros/src/command/slash.rs @@ -163,19 +163,7 @@ pub fn generate_slash_action(inv: &Invocation) -> Result #( (#param_names: #param_types), )* - ).await.map_err(|error| match error { - poise::SlashArgError::CommandStructureMismatch { description, .. } => { - poise::FrameworkError::new_command_structure_mismatch(ctx, description) - }, - poise::SlashArgError::Parse { error, input, .. } => { - poise::FrameworkError::new_argument_parse( - ctx.into(), - Some(input), - error, - ) - }, - poise::SlashArgError::__NonExhaustive => unreachable!(), - })?; + ).await.map_err(|error| error.to_framework_error(ctx))?; if !ctx.framework.options.manual_cooldowns { ctx.command.cooldowns.lock().unwrap().start_cooldown(ctx.cooldown_context()); diff --git a/src/slash_argument/slash_macro.rs b/src/slash_argument/slash_macro.rs index 15c419f81b7d..8f9baeb2b481 100644 --- a/src/slash_argument/slash_macro.rs +++ b/src/slash_argument/slash_macro.rs @@ -26,28 +26,73 @@ pub enum SlashArgError { /// Original input string input: String, }, + /// The argument passed by the user is invalid in this context. E.g. a Member parameter in DMs + #[non_exhaustive] + Invalid( + /// Human readable description of the error + &'static str, + ), + /// HTTP error occured while retrieving the model type from Discord + Http(serenity::Error), #[doc(hidden)] __NonExhaustive, } + /// Support functions for macro which can't create #[non_exhaustive] enum variants #[doc(hidden)] impl SlashArgError { pub fn new_command_structure_mismatch(description: &'static str) -> Self { Self::CommandStructureMismatch { description } } + + pub fn to_framework_error( + self, + ctx: crate::ApplicationContext<'_, U, E>, + ) -> crate::FrameworkError<'_, U, E> { + match self { + Self::CommandStructureMismatch { description } => { + crate::FrameworkError::CommandStructureMismatch { ctx, description } + } + Self::Parse { error, input } => crate::FrameworkError::ArgumentParse { + ctx: ctx.into(), + error, + input: Some(input), + }, + Self::Invalid(description) => crate::FrameworkError::ArgumentParse { + ctx: ctx.into(), + error: description.into(), + input: None, + }, + Self::Http(error) => crate::FrameworkError::ArgumentParse { + ctx: ctx.into(), + error: error.into(), + input: None, + }, + Self::__NonExhaustive => unreachable!(), + } + } } + impl std::fmt::Display for SlashArgError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::CommandStructureMismatch { description } => { write!( f, - "Bot author did not register their commands correctly ({})", - description + "Bot author did not register their commands correctly ({description})", ) } Self::Parse { error, input } => { - write!(f, "Failed to parse `{}` as argument: {}", input, error) + write!(f, "Failed to parse `{input}` as argument: {error}") + } + Self::Invalid(description) => { + write!(f, "You can't use this parameter here: {description}",) + } + Self::Http(error) => { + write!( + f, + "Error occured while retrieving data from Discord: {error}", + ) } Self::__NonExhaustive => unreachable!(), } @@ -56,8 +101,9 @@ impl std::fmt::Display for SlashArgError { impl std::error::Error for SlashArgError { fn cause(&self) -> Option<&dyn std::error::Error> { match self { + Self::Http(error) => Some(error), Self::Parse { error, input: _ } => Some(&**error), - Self::CommandStructureMismatch { description: _ } => None, + Self::Invalid { .. } | Self::CommandStructureMismatch { .. } => None, Self::__NonExhaustive => unreachable!(), } } diff --git a/src/slash_argument/slash_trait.rs b/src/slash_argument/slash_trait.rs index a7bcb1bd8e4f..df52687a0d32 100644 --- a/src/slash_argument/slash_trait.rs +++ b/src/slash_argument/slash_trait.rs @@ -139,14 +139,16 @@ macro_rules! impl_for_integer { _: &serenity::CommandInteraction, value: &serenity::ResolvedValue<'_>, ) -> Result<$t, SlashArgError> { - let value = match value { - serenity::ResolvedValue::Integer(int) => *int, - _ => return Err(SlashArgError::CommandStructureMismatch { description: "expected integer" }) - }; - - value - .try_into() - .map_err(|_| SlashArgError::CommandStructureMismatch { description: "received out of bounds integer" }) + match *value { + serenity::ResolvedValue::Integer(x) => x + .try_into() + .map_err(|_| SlashArgError::CommandStructureMismatch { + description: "received out of bounds integer", + }), + _ => Err(SlashArgError::CommandStructureMismatch { + description: "expected integer", + }), + } } fn create(builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption { @@ -160,90 +162,6 @@ macro_rules! impl_for_integer { } impl_for_integer!(i8 i16 i32 i64 isize u8 u16 u32 u64 usize); -/// Implements slash argument trait for float types -macro_rules! impl_for_float { - ($($t:ty)*) => { $( - #[async_trait::async_trait] - impl SlashArgumentHack<$t> for &PhantomData<$t> { - async fn extract( - self, - _: &serenity::Context, - _: &serenity::CommandInteraction, - value: &serenity::ResolvedValue<'_>, - ) -> Result<$t, SlashArgError> { - match value { - serenity::ResolvedValue::Number(float) => Ok(*float as $t), - _ => Err(SlashArgError::CommandStructureMismatch { description: "expected float" }) - } - } - - fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption { - builder.kind(serenity::CommandOptionType::Number) - } - } - )* }; -} -impl_for_float!(f32 f64); - -#[async_trait::async_trait] -impl SlashArgumentHack for &PhantomData { - async fn extract( - self, - _: &serenity::Context, - _: &serenity::CommandInteraction, - value: &serenity::ResolvedValue<'_>, - ) -> Result { - match value { - serenity::ResolvedValue::Boolean(val) => Ok(*val), - _ => Err(SlashArgError::CommandStructureMismatch { - description: "expected bool", - }), - } - } - - fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption { - builder.kind(serenity::CommandOptionType::Boolean) - } -} - -#[async_trait::async_trait] -impl SlashArgumentHack for &PhantomData { - async fn extract( - self, - _: &serenity::Context, - interaction: &serenity::CommandInteraction, - value: &serenity::ResolvedValue<'_>, - ) -> Result { - let attachment_id = match value { - serenity::ResolvedValue::String(val) => { - val.parse() - .map_err(|_| SlashArgError::CommandStructureMismatch { - description: "improper attachment id passed", - })? - } - _ => { - return Err(SlashArgError::CommandStructureMismatch { - description: "expected attachment id", - }) - } - }; - - interaction - .data - .resolved - .attachments - .get(&attachment_id) - .cloned() - .ok_or(SlashArgError::CommandStructureMismatch { - description: "attachment id with no attachment", - }) - } - - fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption { - builder.kind(serenity::CommandOptionType::Attachment) - } -} - #[async_trait::async_trait] impl SlashArgumentHack for &PhantomData { async fn extract( @@ -264,18 +182,22 @@ impl SlashArgumentHack for &PhantomData { } } -/// Implements `SlashArgumentHack` for a model type that is represented in interactions via an ID +/// Versatile macro to implement `SlashArgumentHack` for simple types macro_rules! impl_slash_argument { - ($type:ty, $slash_param_type:ident) => { + ($type:ty, |$ctx:pat, $interaction:pat, $slash_param_type:ident ( $($arg:pat),* )| $extractor:expr) => { #[async_trait::async_trait] impl SlashArgument for $type { async fn extract( - ctx: &serenity::Context, - interaction: &serenity::CommandInteraction, + $ctx: &serenity::Context, + $interaction: &serenity::CommandInteraction, value: &serenity::ResolvedValue<'_>, ) -> Result<$type, SlashArgError> { - // We can parse IDs by falling back to the generic serenity::ArgumentConvert impl - PhantomData::<$type>.extract(ctx, interaction, value).await + match *value { + serenity::ResolvedValue::$slash_param_type( $($arg),* ) => Ok( $extractor ), + _ => Err(SlashArgError::CommandStructureMismatch { + description: concat!("expected ", stringify!($slash_param_type)) + }), + } } fn create(builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption { @@ -284,8 +206,42 @@ macro_rules! impl_slash_argument { } }; } -impl_slash_argument!(serenity::Member, User); -impl_slash_argument!(serenity::User, User); -impl_slash_argument!(serenity::Channel, Channel); -impl_slash_argument!(serenity::GuildChannel, Channel); -impl_slash_argument!(serenity::Role, Role); + +impl_slash_argument!(f32, |_, _, Number(x)| x as f32); +impl_slash_argument!(f64, |_, _, Number(x)| x); +impl_slash_argument!(bool, |_, _, Boolean(x)| x); +impl_slash_argument!(serenity::Attachment, |_, _, Attachment(att)| att.clone()); +impl_slash_argument!(serenity::Member, |ctx, interaction, User(user, _)| { + interaction + .guild_id + .ok_or(SlashArgError::Invalid("cannot use member parameter in DMs"))? + .member(ctx, user.id) + .await + .map_err(SlashArgError::Http)? +}); +impl_slash_argument!(serenity::PartialMember, |_, _, User(_, member)| { + member + .ok_or(SlashArgError::Invalid("cannot use member parameter in DMs"))? + .clone() +}); +impl_slash_argument!(serenity::User, |_, _, User(user, _)| user.clone()); +impl_slash_argument!(serenity::UserId, |_, _, User(user, _)| user.id); +impl_slash_argument!(serenity::Channel, |ctx, _, Channel(channel)| { + channel + .id + .to_channel(ctx) + .await + .map_err(SlashArgError::Http)? +}); +impl_slash_argument!(serenity::ChannelId, |_, _, Channel(channel)| channel.id); +impl_slash_argument!(serenity::PartialChannel, |_, _, Channel(channel)| channel + .clone()); +impl_slash_argument!(serenity::GuildChannel, |ctx, _, Channel(channel)| { + let channel_res = channel.id.to_channel(ctx).await; + let channel = channel_res.map_err(SlashArgError::Http)?.guild(); + channel.ok_or(SlashArgError::Http(serenity::Error::Model( + serenity::ModelError::InvalidChannelType, + )))? +}); +impl_slash_argument!(serenity::Role, |_, _, Role(role)| role.clone()); +impl_slash_argument!(serenity::RoleId, |_, _, Role(role)| role.id);