diff options
author | Divy Srivastava <dj.srivastava23@gmail.com> | 2022-11-10 03:53:31 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-10 17:23:31 +0530 |
commit | bc33a4b2e06dd5518e0d1bbf7b538d0b00df214d (patch) | |
tree | e139e95178892521ecb5807959e324422ed29045 /ops/optimizer.rs | |
parent | 92764c0decb370b0f8a78770314ceda7228d315f (diff) |
refactor(ops): Rewrite fast call optimizer and codegen (#16514)
Diffstat (limited to 'ops/optimizer.rs')
-rw-r--r-- | ops/optimizer.rs | 600 |
1 files changed, 600 insertions, 0 deletions
diff --git a/ops/optimizer.rs b/ops/optimizer.rs new file mode 100644 index 000000000..3e3887549 --- /dev/null +++ b/ops/optimizer.rs @@ -0,0 +1,600 @@ +/// Optimizer for #[op] +use crate::Op; +use pmutil::{q, Quote}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::fmt::Formatter; +use syn::{ + parse_quote, punctuated::Punctuated, token::Colon2, + AngleBracketedGenericArguments, FnArg, GenericArgument, PatType, Path, + PathArguments, PathSegment, ReturnType, Signature, Type, TypePath, + TypeReference, TypeSlice, +}; + +#[derive(Debug)] +pub(crate) enum BailoutReason { + // Recoverable errors + MustBeSingleSegment, + FastUnsupportedParamType, + + FastAsync, +} + +impl ToTokens for BailoutReason { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + BailoutReason::FastAsync => { + tokens.extend(quote! { "fast async calls are not supported" }); + } + BailoutReason::MustBeSingleSegment + | BailoutReason::FastUnsupportedParamType => { + unreachable!("error not recovered"); + } + } + } +} + +#[derive(Debug, PartialEq)] +enum TransformKind { + // serde_v8::Value + V8Value, + SliceU32(bool), + SliceU8(bool), +} + +impl Transform { + fn serde_v8_value(index: usize) -> Self { + Transform { + kind: TransformKind::V8Value, + index, + } + } + + fn slice_u32(index: usize, is_mut: bool) -> Self { + Transform { + kind: TransformKind::SliceU32(is_mut), + index, + } + } + + fn slice_u8(index: usize, is_mut: bool) -> Self { + Transform { + kind: TransformKind::SliceU8(is_mut), + index, + } + } +} + +#[derive(Debug, PartialEq)] +pub(crate) struct Transform { + kind: TransformKind, + index: usize, +} + +impl Transform { + pub(crate) fn apply_for_fast_call( + &self, + core: &TokenStream, + input: &mut FnArg, + ) -> Quote { + let (ty, ident) = match input { + FnArg::Typed(PatType { + ref mut ty, + ref pat, + .. + }) => { + let ident = match &**pat { + syn::Pat::Ident(ident) => &ident.ident, + _ => unreachable!("error not recovered"), + }; + (ty, ident) + } + _ => unreachable!("error not recovered"), + }; + + match &self.kind { + // serde_v8::Value + TransformKind::V8Value => { + *ty = parse_quote! { #core::v8::Local<v8::Value> }; + + q!(Vars { var: &ident }, { + let var = serde_v8::Value { v8_value: var }; + }) + } + // &[u32] + TransformKind::SliceU32(_) => { + *ty = + parse_quote! { *const #core::v8::fast_api::FastApiTypedArray<u32> }; + + q!(Vars { var: &ident }, { + let var = match unsafe { &*var }.get_storage_if_aligned() { + Some(v) => v, + None => { + unsafe { &mut *fast_api_callback_options }.fallback = true; + return Default::default(); + } + }; + }) + } + // &[u8] + TransformKind::SliceU8(_) => { + *ty = + parse_quote! { *const #core::v8::fast_api::FastApiTypedArray<u8> }; + + q!(Vars { var: &ident }, { + let var = match unsafe { &*var }.get_storage_if_aligned() { + Some(v) => v, + None => { + unsafe { &mut *fast_api_callback_options }.fallback = true; + return Default::default(); + } + }; + }) + } + } + } +} + +fn get_fast_scalar(s: &str) -> Option<FastValue> { + match s { + "u32" => Some(FastValue::U32), + "i32" => Some(FastValue::I32), + "u64" => Some(FastValue::U64), + "i64" => Some(FastValue::I64), + "f32" => Some(FastValue::F32), + "f64" => Some(FastValue::F64), + "bool" => Some(FastValue::Bool), + "ResourceId" => Some(FastValue::U32), + _ => None, + } +} + +fn can_return_fast(v: &FastValue) -> bool { + !matches!( + v, + FastValue::U64 + | FastValue::I64 + | FastValue::Uint8Array + | FastValue::Uint32Array + ) +} + +#[derive(Debug, PartialEq, Clone)] +pub(crate) enum FastValue { + Void, + U32, + I32, + U64, + I64, + F32, + F64, + Bool, + V8Value, + Uint8Array, + Uint32Array, +} + +impl Default for FastValue { + fn default() -> Self { + Self::Void + } +} + +#[derive(Default, PartialEq)] +pub(crate) struct Optimizer { + pub(crate) returns_result: bool, + + pub(crate) has_ref_opstate: bool, + + pub(crate) has_rc_opstate: bool, + + pub(crate) has_fast_callback_option: bool, + + pub(crate) fast_result: Option<FastValue>, + pub(crate) fast_parameters: Vec<FastValue>, + + pub(crate) transforms: HashMap<usize, Transform>, + pub(crate) fast_compatible: bool, +} + +impl Debug for Optimizer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "=== Optimizer Dump ===")?; + writeln!(f, "returns_result: {}", self.returns_result)?; + writeln!(f, "has_ref_opstate: {}", self.has_ref_opstate)?; + writeln!(f, "has_rc_opstate: {}", self.has_rc_opstate)?; + writeln!( + f, + "has_fast_callback_option: {}", + self.has_fast_callback_option + )?; + writeln!(f, "fast_result: {:?}", self.fast_result)?; + writeln!(f, "fast_parameters: {:?}", self.fast_parameters)?; + writeln!(f, "transforms: {:?}", self.transforms)?; + Ok(()) + } +} + +impl Optimizer { + pub(crate) fn new() -> Self { + Default::default() + } + + pub(crate) const fn has_opstate_in_parameters(&self) -> bool { + self.has_ref_opstate || self.has_rc_opstate + } + + pub(crate) const fn needs_opstate(&self) -> bool { + self.has_ref_opstate || self.has_rc_opstate || self.returns_result + } + + pub(crate) fn analyze(&mut self, op: &mut Op) -> Result<(), BailoutReason> { + if op.is_async && op.attrs.must_be_fast { + self.fast_compatible = false; + return Err(BailoutReason::FastAsync); + } + + if op.attrs.is_v8 || op.is_async { + self.fast_compatible = false; + return Ok(()); + } + + self.fast_compatible = true; + let sig = &op.item.sig; + + // Analyze return type + match &sig { + Signature { + output: ReturnType::Default, + .. + } => self.fast_result = Some(FastValue::default()), + Signature { + output: ReturnType::Type(_, ty), + .. + } => self.analyze_return_type(ty)?, + }; + + // The reciever, which we don't actually care about. + self.fast_parameters.push(FastValue::V8Value); + + // Analyze parameters + for (index, param) in sig.inputs.iter().enumerate() { + self.analyze_param_type(index, param)?; + } + + Ok(()) + } + + fn analyze_return_type(&mut self, ty: &Type) -> Result<(), BailoutReason> { + match ty { + Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) => { + let segment = single_segment(segments)?; + + match segment { + // Result<T, E> + PathSegment { + ident, arguments, .. + } if ident == "Result" => { + self.returns_result = true; + + if let PathArguments::AngleBracketed( + AngleBracketedGenericArguments { args, .. }, + ) = arguments + { + match args.first() { + Some(GenericArgument::Type(Type::Path(TypePath { + path: Path { segments, .. }, + .. + }))) => { + let PathSegment { ident, .. } = single_segment(segments)?; + // Is `T` a scalar FastValue? + if let Some(val) = get_fast_scalar(ident.to_string().as_str()) + { + if can_return_fast(&val) { + self.fast_result = Some(val); + return Ok(()); + } + } + + self.fast_compatible = false; + return Err(BailoutReason::FastUnsupportedParamType); + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + } + // Is `T` a scalar FastValue? + PathSegment { ident, .. } => { + if let Some(val) = get_fast_scalar(ident.to_string().as_str()) { + self.fast_result = Some(val); + return Ok(()); + } + + self.fast_compatible = false; + return Err(BailoutReason::FastUnsupportedParamType); + } + }; + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + }; + + Ok(()) + } + + fn analyze_param_type( + &mut self, + index: usize, + arg: &FnArg, + ) -> Result<(), BailoutReason> { + match arg { + FnArg::Typed(typed) => match &*typed.ty { + Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) if segments.len() == 2 => { + match double_segment(segments)? { + // -> serde_v8::Value + [PathSegment { ident: first, .. }, PathSegment { ident: last, .. }] + if first == "serde_v8" && last == "Value" => + { + self.fast_parameters.push(FastValue::V8Value); + assert!(self + .transforms + .insert(index, Transform::serde_v8_value(index)) + .is_none()); + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) => { + let segment = single_segment(segments)?; + + match segment { + // -> Option<T> + PathSegment { + ident, arguments, .. + } if ident == "Option" => { + if let PathArguments::AngleBracketed( + AngleBracketedGenericArguments { args, .. }, + ) = arguments + { + // -> Option<&mut T> + if let Some(GenericArgument::Type(Type::Reference( + TypeReference { elem, .. }, + ))) = args.last() + { + if let Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) = &**elem + { + let segment = single_segment(segments)?; + match segment { + // Is `T` a FastApiCallbackOption? + PathSegment { ident, .. } + if ident == "FastApiCallbackOption" => + { + self.has_fast_callback_option = true; + } + _ => {} + } + } + } + } + } + // -> Rc<T> + PathSegment { + ident, arguments, .. + } if ident == "Rc" => { + if let PathArguments::AngleBracketed( + AngleBracketedGenericArguments { args, .. }, + ) = arguments + { + match args.last() { + Some(GenericArgument::Type(Type::Path(TypePath { + path: Path { segments, .. }, + .. + }))) => { + let segment = single_segment(segments)?; + match segment { + // -> Rc<RefCell<T>> + PathSegment { ident, .. } if ident == "RefCell" => { + if let PathArguments::AngleBracketed( + AngleBracketedGenericArguments { args, .. }, + ) = arguments + { + match args.last() { + // -> Rc<RefCell<OpState>> + Some(GenericArgument::Type(Type::Path( + TypePath { + path: Path { segments, .. }, + .. + }, + ))) => { + let segment = single_segment(segments)?; + match segment { + PathSegment { ident, .. } + if ident == "OpState" => + { + self.has_rc_opstate = true; + } + _ => { + return Err( + BailoutReason::FastUnsupportedParamType, + ) + } + } + } + _ => { + return Err( + BailoutReason::FastUnsupportedParamType, + ) + } + } + } + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + } + // Is `T` a fast scalar? + PathSegment { ident, .. } => { + if let Some(val) = get_fast_scalar(ident.to_string().as_str()) { + self.fast_parameters.push(val); + } else { + return Err(BailoutReason::FastUnsupportedParamType); + } + } + }; + } + // &mut T + Type::Reference(TypeReference { + elem, mutability, .. + }) => match &**elem { + Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) => { + let segment = single_segment(segments)?; + match segment { + // Is `T` a OpState? + PathSegment { ident, .. } if ident == "OpState" => { + self.has_ref_opstate = true; + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + // &mut [T] + Type::Slice(TypeSlice { elem, .. }) => match &**elem { + Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) => { + let segment = single_segment(segments)?; + let is_mut_ref = mutability.is_some(); + match segment { + // Is `T` a u8? + PathSegment { ident, .. } if ident == "u8" => { + self.has_fast_callback_option = true; + self.fast_parameters.push(FastValue::Uint8Array); + assert!(self + .transforms + .insert(index, Transform::slice_u8(index, is_mut_ref)) + .is_none()); + } + // Is `T` a u32? + PathSegment { ident, .. } if ident == "u32" => { + self.has_fast_callback_option = true; + self.fast_parameters.push(FastValue::Uint32Array); + assert!(self + .transforms + .insert(index, Transform::slice_u32(index, is_mut_ref)) + .is_none()); + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + } + } + _ => return Err(BailoutReason::FastUnsupportedParamType), + }, + _ => return Err(BailoutReason::FastUnsupportedParamType), + }, + _ => return Err(BailoutReason::FastUnsupportedParamType), + }, + _ => return Err(BailoutReason::FastUnsupportedParamType), + }; + Ok(()) + } +} + +fn single_segment( + segments: &Punctuated<PathSegment, Colon2>, +) -> Result<&PathSegment, BailoutReason> { + if segments.len() != 1 { + return Err(BailoutReason::MustBeSingleSegment); + } + + match segments.last() { + Some(segment) => Ok(segment), + None => Err(BailoutReason::MustBeSingleSegment), + } +} + +fn double_segment( + segments: &Punctuated<PathSegment, Colon2>, +) -> Result<[&PathSegment; 2], BailoutReason> { + match (segments.first(), segments.last()) { + (Some(first), Some(last)) => Ok([first, last]), + // Caller ensures that there are only two segments. + _ => unreachable!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Op; + use std::path::PathBuf; + use syn::parse_quote; + + #[test] + fn test_single_segment() { + let segments = parse_quote!(foo); + assert!(single_segment(&segments).is_ok()); + + let segments = parse_quote!(foo::bar); + assert!(single_segment(&segments).is_err()); + } + + #[test] + fn test_double_segment() { + let segments = parse_quote!(foo::bar); + assert!(double_segment(&segments).is_ok()); + assert_eq!(double_segment(&segments).unwrap()[0].ident, "foo"); + assert_eq!(double_segment(&segments).unwrap()[1].ident, "bar"); + } + + #[testing_macros::fixture("optimizer_tests/**/*.rs")] + fn test_analyzer(input: PathBuf) { + let update_expected = std::env::var("UPDATE_EXPECTED").is_ok(); + + let source = + std::fs::read_to_string(&input).expect("Failed to read test file"); + let expected = std::fs::read_to_string(input.with_extension("expected")) + .expect("Failed to read expected file"); + + let item = syn::parse_str(&source).expect("Failed to parse test file"); + let mut op = Op::new(item, Default::default()); + let mut optimizer = Optimizer::new(); + if let Err(e) = optimizer.analyze(&mut op) { + let e_str = format!("{:?}", e); + if update_expected { + std::fs::write(input.with_extension("expected"), e_str) + .expect("Failed to write expected file"); + } else { + assert_eq!(e_str, expected); + } + return; + } + + if update_expected { + std::fs::write( + input.with_extension("expected"), + format!("{:#?}", optimizer), + ) + .expect("Failed to write expected file"); + } else { + assert_eq!(format!("{:#?}", optimizer), expected); + } + } +} |