diff options
Diffstat (limited to 'ops/op2/signature.rs')
-rw-r--r-- | ops/op2/signature.rs | 243 |
1 files changed, 228 insertions, 15 deletions
diff --git a/ops/op2/signature.rs b/ops/op2/signature.rs index 6158b2a55..15c40e007 100644 --- a/ops/op2/signature.rs +++ b/ops/op2/signature.rs @@ -4,12 +4,15 @@ use proc_macro2::Ident; use proc_macro2::Span; use quote::quote; use quote::ToTokens; +use std::collections::BTreeMap; use strum::IntoEnumIterator; use strum::IntoStaticStr; use strum_macros::EnumIter; use strum_macros::EnumString; use syn2::Attribute; use syn2::FnArg; +use syn2::GenericParam; +use syn2::Generics; use syn2::Pat; use syn2::ReturnType; use syn2::Signature; @@ -136,9 +139,16 @@ pub enum RetVal { #[derive(Clone, Debug, Eq, PartialEq)] pub struct ParsedSignature { + // The parsed arguments pub args: Vec<Arg>, + // The argument names pub names: Vec<String>, + // The parsed return value pub ret_val: RetVal, + // One and only one lifetime allowed + pub lifetime: Option<String>, + // Generic bounds: each generic must have one and only simple trait bound + pub generic_bounds: BTreeMap<String, String>, } #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -153,10 +163,24 @@ enum AttributeModifier { #[derive(Error, Debug)] pub enum SignatureError { - #[error("Invalid argument: {0}")] + #[error("Invalid argument: '{0}'")] ArgError(String, #[source] ArgError), #[error("Invalid return type")] RetError(#[from] ArgError), + #[error("Only one lifetime is permitted")] + TooManyLifetimes, + #[error("Generic '{0}' must have one and only bound (either <T> and 'where T: Trait', or <T: Trait>)")] + GenericBoundCardinality(String), + #[error("Where clause predicate '{0}' (eg: where T: Trait) must appear in generics list (eg: <T>)")] + WherePredicateMustAppearInGenerics(String), + #[error("All generics must appear only once in the generics parameter list or where clause")] + DuplicateGeneric(String), + #[error("Generic lifetime '{0}' may not have bounds (eg: <'a: 'b>)")] + LifetimesMayNotHaveBounds(String), + #[error("Invalid generic: '{0}' Only simple generics bounds are allowed (eg: T: Trait)")] + InvalidGeneric(String), + #[error("Invalid predicate: '{0}' Only simple where predicates are allowed (eg: T: Trait)")] + InvalidWherePredicate(String), } #[derive(Error, Debug)] @@ -216,13 +240,107 @@ pub fn parse_signature( parse_arg(input).map_err(|err| SignatureError::ArgError(name, err))?, ); } + let ret_val = + parse_return(parse_attributes(&attributes)?, &signature.output)?; + let lifetime = parse_lifetime(&signature.generics)?; + let generic_bounds = parse_generics(&signature.generics)?; Ok(ParsedSignature { args, names, - ret_val: parse_return(parse_attributes(&attributes)?, &signature.output)?, + ret_val, + lifetime, + generic_bounds, }) } +/// Extract one lifetime from the [`syn2::Generics`], ensuring that the lifetime is valid +/// and has no bounds. +fn parse_lifetime( + generics: &Generics, +) -> Result<Option<String>, SignatureError> { + let mut res = None; + for param in &generics.params { + if let GenericParam::Lifetime(lt) = param { + if !lt.bounds.is_empty() { + return Err(SignatureError::LifetimesMayNotHaveBounds( + lt.lifetime.to_string(), + )); + } + if res.is_some() { + return Err(SignatureError::TooManyLifetimes); + } + res = Some(lt.lifetime.ident.to_string()); + } + } + Ok(res) +} + +/// Parse and validate generics. We require one and only one trait bound for each generic +/// parameter. Tries to sanity check and return reasonable errors for possible signature errors. +fn parse_generics( + generics: &Generics, +) -> Result<BTreeMap<String, String>, SignatureError> { + let mut where_clauses = BTreeMap::new(); + + // First, extract the where clause so we can detect duplicated predicates + if let Some(where_clause) = &generics.where_clause { + for predicate in &where_clause.predicates { + let predicate = predicate.to_token_stream(); + let (generic_name, bound) = std::panic::catch_unwind(|| { + use syn2 as syn; + rules!(predicate => { + ($t:ident : $bound:path) => (t.to_string(), stringify_token(bound)), + }) + }) + .map_err(|_| { + SignatureError::InvalidWherePredicate(predicate.to_string()) + })?; + if where_clauses.insert(generic_name.clone(), bound).is_some() { + return Err(SignatureError::DuplicateGeneric(generic_name)); + } + } + } + + let mut res = BTreeMap::new(); + for param in &generics.params { + if let GenericParam::Type(ty) = param { + let ty = ty.to_token_stream(); + let (name, bound) = std::panic::catch_unwind(|| { + use syn2 as syn; + rules!(ty => { + ($t:ident : $bound:path) => (t.to_string(), Some(stringify_token(bound))), + ($t:ident) => (t.to_string(), None), + }) + }).map_err(|_| SignatureError::InvalidGeneric(ty.to_string()))?; + let bound = match bound { + Some(bound) => { + if where_clauses.contains_key(&name) { + return Err(SignatureError::GenericBoundCardinality(name)); + } + bound + } + None => { + let Some(bound) = where_clauses.remove(&name) else { + return Err(SignatureError::GenericBoundCardinality(name)); + }; + bound + } + }; + if res.contains_key(&name) { + return Err(SignatureError::DuplicateGeneric(name)); + } + res.insert(name, bound); + } + } + if !where_clauses.is_empty() { + return Err(SignatureError::WherePredicateMustAppearInGenerics( + where_clauses.into_keys().next().unwrap(), + )); + } + + Ok(res) +} + fn parse_attributes(attributes: &[Attribute]) -> Result<Attributes, ArgError> { let attrs = attributes .iter() @@ -447,11 +565,27 @@ mod tests { // We can't test pattern args :/ // https://github.com/rust-lang/rfcs/issues/2688 macro_rules! test { - ( $(# [ $fn_attr:ident ])? fn $name:ident ( $( $(# [ $attr:ident ])? $ident:ident : $ty:ty ),* ) $(-> $(# [ $ret_attr:ident ])? $ret:ty)?, ( $( $arg_res:expr ),* ) -> $ret_res:expr ) => { + ( + // Function attributes + $(# [ $fn_attr:ident ])? + // fn name < 'scope, GENERIC1, GENERIC2, ... > + fn $name:ident $( < $scope:lifetime $( , $generic:ident)* >)? + ( + // Argument attribute, argument + $( $(# [ $attr:ident ])? $ident:ident : $ty:ty ),* + ) + // Return value + $(-> $(# [ $ret_attr:ident ])? $ret:ty)? + // Where clause + $( where $($trait:ident : $bounds:path),* )? + ; + // Expected return value + $( < $( $lifetime_res:lifetime )? $(, $generic_res:ident : $bounds_res:path )* >)? ( $( $arg_res:expr ),* ) -> $ret_res:expr ) => { #[test] fn $name() { test( - stringify!($( #[$fn_attr] )? fn op( $( $( #[$attr] )? $ident : $ty ),* ) $(-> $( #[$ret_attr] )? $ret)? {}), + stringify!($( #[$fn_attr] )? fn op $( < $scope $( , $generic)* >)? ( $( $( #[$attr] )? $ident : $ty ),* ) $(-> $( #[$ret_attr] )? $ret)? $( where $($trait : $bounds),* )? {}), + stringify!($( < $( $lifetime_res )? $(, $generic_res : $bounds_res)* > )?), stringify!($($arg_res),*), stringify!($ret_res) ); @@ -459,14 +593,35 @@ mod tests { }; } - fn test(op: &str, args_expected: &str, return_expected: &str) { + fn test( + op: &str, + generics_expected: &str, + args_expected: &str, + return_expected: &str, + ) { + // Parse the provided macro input as an ItemFn let item_fn = parse_str::<ItemFn>(op) .unwrap_or_else(|_| panic!("Failed to parse {op} as a ItemFn")); + let attrs = item_fn.attrs; - let sig = parse_signature(attrs, item_fn.sig).unwrap_or_else(|_| { - panic!("Failed to successfully parse signature from {op}") + let sig = parse_signature(attrs, item_fn.sig).unwrap_or_else(|err| { + panic!("Failed to successfully parse signature from {op} ({err:?})") }); + println!("Raw parsed signatures = {sig:?}"); + let mut generics_res = vec![]; + if let Some(lifetime) = sig.lifetime { + generics_res.push(format!("'{lifetime}")); + } + for (name, bounds) in sig.generic_bounds { + generics_res.push(format!("{name} : {bounds}")); + } + if !generics_res.is_empty() { + assert_eq!( + generics_expected, + format!("< {} >", generics_res.join(", ")) + ); + } assert_eq!( args_expected, format!("{:?}", sig.args).trim_matches(|c| c == '[' || c == ']') @@ -474,38 +629,96 @@ mod tests { assert_eq!(return_expected, format!("{:?}", sig.ret_val)); } + macro_rules! expect_fail { + ($name:ident, $error:expr, $f:item) => { + #[test] + pub fn $name() { + expect_fail(stringify!($f), stringify!($error)); + } + }; + } + + fn expect_fail(op: &str, error: &str) { + // Parse the provided macro input as an ItemFn + let item_fn = parse_str::<ItemFn>(op) + .unwrap_or_else(|_| panic!("Failed to parse {op} as a ItemFn")); + let attrs = item_fn.attrs; + let err = parse_signature(attrs, item_fn.sig) + .expect_err("Expected function to fail to parse"); + assert_eq!(format!("{err:?}"), error.to_owned()); + } + test!( - fn op_state_and_number(opstate: &mut OpState, a: u32) -> (), + fn op_state_and_number(opstate: &mut OpState, a: u32) -> (); (Ref(Mut, OpState), Numeric(u32)) -> Infallible(Void) ); test!( - fn op_slices(r#in: &[u8], out: &mut [u8]), + fn op_slices(r#in: &[u8], out: &mut [u8]); (Slice(Ref, u8), Slice(Mut, u8)) -> Infallible(Void) ); test!( - #[serde] fn op_serde(#[serde] input: package::SerdeInputType) -> Result<package::SerdeReturnType, Error>, + #[serde] fn op_serde(#[serde] input: package::SerdeInputType) -> Result<package::SerdeReturnType, Error>; (SerdeV8("package::SerdeInputType")) -> Result(SerdeV8("package::SerdeReturnType")) ); test!( - fn op_local(input: v8::Local<v8::String>) -> Result<v8::Local<v8::String>, Error>, + fn op_local(input: v8::Local<v8::String>) -> Result<v8::Local<v8::String>, Error>; (V8Local(String)) -> Result(V8Local(String)) ); test!( - fn op_resource(#[smi] rid: ResourceId, buffer: &[u8]), + fn op_resource(#[smi] rid: ResourceId, buffer: &[u8]); (Numeric(__SMI__), Slice(Ref, u8)) -> Infallible(Void) ); test!( - fn op_option_numeric_result(state: &mut OpState) -> Result<Option<u32>, AnyError>, + fn op_option_numeric_result(state: &mut OpState) -> Result<Option<u32>, AnyError>; (Ref(Mut, OpState)) -> Result(OptionNumeric(u32)) ); test!( - fn op_ffi_read_f64(state: &mut OpState, ptr: * mut c_void, offset: isize) -> Result <f64, AnyError>, + fn op_ffi_read_f64(state: &mut OpState, ptr: * mut c_void, offset: isize) -> Result <f64, AnyError>; (Ref(Mut, OpState), Ptr(Mut, __VOID__), Numeric(isize)) -> Result(Numeric(f64)) ); test!( - fn op_print(#[string] msg: &str, is_err: bool) -> Result<(), Error>, + fn op_print(#[string] msg: &str, is_err: bool) -> Result<(), Error>; (Special(RefStr), Numeric(bool)) -> Result(Void) ); + test!( + fn op_scope<'s>(#[string] msg: &'s str); + <'s> (Special(RefStr)) -> Infallible(Void) + ); + test!( + fn op_scope_and_generics<'s, AB, BC>(#[string] msg: &'s str) where AB: some::Trait, BC: OtherTrait; + <'s, AB: some::Trait, BC: OtherTrait> (Special(RefStr)) -> Infallible(Void) + ); + + expect_fail!(op_with_two_lifetimes, TooManyLifetimes, fn f<'a, 'b>() {}); + expect_fail!( + op_with_lifetime_bounds, + LifetimesMayNotHaveBounds("'a"), + fn f<'a: 'b, 'b>() {} + ); + expect_fail!( + op_with_missing_bounds, + GenericBoundCardinality("B"), + fn f<'a, B>() {} + ); + expect_fail!( + op_with_duplicate_bounds, + GenericBoundCardinality("B"), + fn f<'a, B: Trait>() + where + B: Trait, + { + } + ); + expect_fail!( + op_with_extra_bounds, + WherePredicateMustAppearInGenerics("C"), + fn f<'a, B>() + where + B: Trait, + C: Trait, + { + } + ); #[test] fn test_parse_result() { |