diff options
Diffstat (limited to 'ext/ffi/jit_trampoline.rs')
-rw-r--r-- | ext/ffi/jit_trampoline.rs | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/ext/ffi/jit_trampoline.rs b/ext/ffi/jit_trampoline.rs index 6a3efa876..92e63348d 100644 --- a/ext/ffi/jit_trampoline.rs +++ b/ext/ffi/jit_trampoline.rs @@ -32,7 +32,8 @@ fn native_arg_to_c(ty: &NativeType) -> &'static str { NativeType::I64 => "int64_t", NativeType::ISize => "intptr_t", NativeType::USize => "uintptr_t", - NativeType::Pointer | NativeType::Function => "void*", + NativeType::Pointer => "struct FastApiTypedArray*", + NativeType::Function => "void*", } } @@ -85,11 +86,15 @@ pub(crate) fn codegen(sym: &crate::Symbol) -> String { c += ") {\n"; // return func(p0, p1, ...); c += " return func("; - for (i, _) in sym.parameter_types.iter().enumerate() { + for (i, ty) in sym.parameter_types.iter().enumerate() { if i > 0 { c += ", "; } - let _ = write!(c, "p{i}"); + if matches!(ty, NativeType::Pointer) { + let _ = write!(c, "p{i}->data"); + } else { + let _ = write!(c, "p{i}"); + } } c += ");\n}\n\n"; c @@ -103,7 +108,6 @@ pub(crate) fn gen_trampoline( // SAFETY: symbol satisfies ABI requirement. unsafe { ctx.add_symbol(cstr!("func"), sym.ptr.0 as *const c_void) }; let c = codegen(&sym); - ctx.compile_string(cstr!(c))?; let alloc = Allocation { addr: ctx.relocate_and_get_symbol(cstr!("func_trampoline"))?, @@ -172,6 +176,20 @@ mod tests { \n return func(p0, p1);\n\ }\n\n", ); + assert_codegen( + codegen(vec![NativeType::Pointer, NativeType::U32], NativeType::U32), + "extern uint32_t func(void* p0, uint32_t p1);\n\n\ + uint32_t func_trampoline(void* recv, struct FastApiTypedArray* p0, uint32_t p1) {\ + \n return func(p0->data, p1);\n\ + }\n\n", + ); + assert_codegen( + codegen(vec![NativeType::Pointer, NativeType::Pointer], NativeType::U32), + "extern uint32_t func(void* p0, void* p1);\n\n\ + uint32_t func_trampoline(void* recv, struct FastApiTypedArray* p0, struct FastApiTypedArray* p1) {\ + \n return func(p0->data, p1->data);\n\ + }\n\n", + ); } #[test] |