summaryrefslogtreecommitdiff
path: root/ext/ffi/jit_trampoline.rs
diff options
context:
space:
mode:
authorDivy Srivastava <dj.srivastava23@gmail.com>2022-07-28 18:08:22 +0530
committerGitHub <noreply@github.com>2022-07-28 18:08:22 +0530
commitef7bc2e6cc4856a0372086b3ceb7d470508aaa52 (patch)
tree02045406d0fe21bf77bcb3697ed2e75006aebeed /ext/ffi/jit_trampoline.rs
parent519ed44ebb4bab71c6b80f7c1ef432354654da8c (diff)
perf(ext/ffi): use fast api calls for 64bit return types (#15313)
Diffstat (limited to 'ext/ffi/jit_trampoline.rs')
-rw-r--r--ext/ffi/jit_trampoline.rs64
1 files changed, 49 insertions, 15 deletions
diff --git a/ext/ffi/jit_trampoline.rs b/ext/ffi/jit_trampoline.rs
index 92e63348d..4785fd092 100644
--- a/ext/ffi/jit_trampoline.rs
+++ b/ext/ffi/jit_trampoline.rs
@@ -58,12 +58,15 @@ fn native_to_c(ty: &NativeType) -> &'static str {
pub(crate) fn codegen(sym: &crate::Symbol) -> String {
let mut c = String::from(include_str!("prelude.h"));
- let ret = native_to_c(&sym.result_type);
+ let needs_unwrap = crate::needs_unwrap(sym.result_type);
+
+ // Return type of the FFI call.
+ let ffi_ret = native_to_c(&sym.result_type);
+ // Return type of the trampoline.
+ let ret = if needs_unwrap { "void" } else { ffi_ret };
// extern <return_type> func(
- c += "\nextern ";
- c += ret;
- c += " func(";
+ let _ = write!(c, "\nextern {ffi_ret} func(");
// <param_type> p0, <param_type> p1, ...);
for (i, ty) in sym.parameter_types.iter().enumerate() {
if i > 0 {
@@ -83,20 +86,35 @@ pub(crate) fn codegen(sym: &crate::Symbol) -> String {
c += native_arg_to_c(ty);
let _ = write!(c, " p{i}");
}
+ if needs_unwrap {
+ let _ = write!(c, ", struct FastApiTypedArray* const p_ret");
+ }
c += ") {\n";
- // return func(p0, p1, ...);
- c += " return func(";
- for (i, ty) in sym.parameter_types.iter().enumerate() {
- if i > 0 {
- c += ", ";
- }
- if matches!(ty, NativeType::Pointer) {
- let _ = write!(c, "p{i}->data");
- } else {
- let _ = write!(c, "p{i}");
+ // func(p0, p1, ...);
+ let mut call_s = String::from("func(");
+ {
+ for (i, ty) in sym.parameter_types.iter().enumerate() {
+ if i > 0 {
+ call_s += ", ";
+ }
+ if matches!(ty, NativeType::Pointer) {
+ let _ = write!(call_s, "p{i}->data");
+ } else {
+ let _ = write!(call_s, "p{i}");
+ }
}
+ call_s += ");\n";
}
- c += ");\n}\n\n";
+ if needs_unwrap {
+ // <return_type> r = func(p0, p1, ...);
+ // ((<return_type>*)p_ret->data)[0] = r;
+ let _ = write!(c, " {ffi_ret} r = {call_s}");
+ let _ = writeln!(c, " (({ffi_ret}*)p_ret->data)[0] = r;");
+ } else {
+ // return func(p0, p1, ...);
+ let _ = write!(c, " return {call_s}");
+ }
+ c += "}\n\n";
c
}
@@ -190,6 +208,22 @@ mod tests {
\n return func(p0->data, p1->data);\n\
}\n\n",
);
+ assert_codegen(
+ codegen(vec![], NativeType::U64),
+ "extern uint64_t func();\n\n\
+ void func_trampoline(void* recv, struct FastApiTypedArray* const p_ret) {\
+ \n uint64_t r = func();\
+ \n ((uint64_t*)p_ret->data)[0] = r;\n\
+ }\n\n",
+ );
+ assert_codegen(
+ codegen(vec![NativeType::Pointer, NativeType::Pointer], NativeType::U64),
+ "extern uint64_t func(void* p0, void* p1);\n\n\
+ void func_trampoline(void* recv, struct FastApiTypedArray* p0, struct FastApiTypedArray* p1, struct FastApiTypedArray* const p_ret) {\
+ \n uint64_t r = func(p0->data, p1->data);\
+ \n ((uint64_t*)p_ret->data)[0] = r;\n\
+ }\n\n",
+ );
}
#[test]