From e2c4db5e9b4ec97a891f3178c5de0218c97d45ac Mon Sep 17 00:00:00 2001 From: Campbell He Date: Mon, 29 Jul 2024 04:44:25 +0000 Subject: [PATCH] fix: filter catch_all variant in EnumInfo methods This commit filters out catch_all variant in the `variant_idents` and `variant_expressions` of `EnumInfo`. This allows the catch_all variant to be in the middle of enum. A test of catch_all variant in the middle of enum is also added. Fix #149 --- num_enum/tests/from_primitive.rs | 24 ++++++++++++++++++++++++ num_enum_derive/src/parsing.rs | 2 ++ 2 files changed, 26 insertions(+) diff --git a/num_enum/tests/from_primitive.rs b/num_enum/tests/from_primitive.rs index da52d43..de61d6b 100644 --- a/num_enum/tests/from_primitive.rs +++ b/num_enum/tests/from_primitive.rs @@ -114,6 +114,30 @@ fn from_primitive_number_catch_all() { assert_eq!(two, Enum::NonZero(2_u8)); } +#[test] +fn from_primitive_number_catch_all_in_middle() { + #[derive(Debug, PartialEq, Eq, FromPrimitive)] + #[repr(u8)] + enum Enum { + Zero = 0, + #[num_enum(catch_all)] + Else(u8) = 2, + One = 1, + } + + let zero = Enum::from_primitive(0_u8); + assert_eq!(zero, Enum::Zero); + + let one = Enum::from_primitive(1_u8); + assert_eq!(one, Enum::One); + + let two = Enum::from_primitive(2_u8); + assert_eq!(two, Enum::Else(2_u8)); + + let three = Enum::from_primitive(3_u8); + assert_eq!(three, Enum::Else(3_u8)); +} + #[cfg(feature = "complex-expressions")] #[test] fn from_primitive_number_with_inclusive_range() { diff --git a/num_enum_derive/src/parsing.rs b/num_enum_derive/src/parsing.rs index cfc0b46..1a24b0c 100644 --- a/num_enum_derive/src/parsing.rs +++ b/num_enum_derive/src/parsing.rs @@ -61,6 +61,7 @@ impl EnumInfo { pub(crate) fn variant_idents(&self) -> Vec { self.variants .iter() + .filter(|variant| !variant.is_catch_all) .map(|variant| variant.ident.clone()) .collect() } @@ -81,6 +82,7 @@ impl EnumInfo { pub(crate) fn variant_expressions(&self) -> Vec> { self.variants .iter() + .filter(|variant| !variant.is_catch_all) .map(|variant| variant.all_values().cloned().collect()) .collect() }