diff options
author | Luca Casonato <hello@lcas.dev> | 2021-08-24 20:32:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-24 20:32:25 +0200 |
commit | 4853be20f2d649842ebc97124d8479c7aad7cc9b (patch) | |
tree | 1be4dcc96c72166b3e1e4f19ee70eb791e1304aa /ext/webgpu/compute_pass.rs | |
parent | e10d30c8eaf41ad68b48f21c8d563d192b82afe8 (diff) |
refactor(webgpu): use op interface idiomatically (#11835)
Diffstat (limited to 'ext/webgpu/compute_pass.rs')
-rw-r--r-- | ext/webgpu/compute_pass.rs | 56 |
1 files changed, 35 insertions, 21 deletions
diff --git a/ext/webgpu/compute_pass.rs b/ext/webgpu/compute_pass.rs index 4fc0af538..fe1186c4e 100644 --- a/ext/webgpu/compute_pass.rs +++ b/ext/webgpu/compute_pass.rs @@ -1,6 +1,5 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. -use deno_core::error::null_opbuf; use deno_core::error::AnyError; use deno_core::ResourceId; use deno_core::ZeroCopyBuf; @@ -24,7 +23,7 @@ impl Resource for WebGpuComputePass { #[serde(rename_all = "camelCase")] pub struct ComputePassSetPipelineArgs { compute_pass_rid: ResourceId, - pipeline: u32, + pipeline: ResourceId, } pub fn op_webgpu_compute_pass_set_pipeline( @@ -80,7 +79,7 @@ pub fn op_webgpu_compute_pass_dispatch( #[serde(rename_all = "camelCase")] pub struct ComputePassDispatchIndirectArgs { compute_pass_rid: ResourceId, - indirect_buffer: u32, + indirect_buffer: ResourceId, indirect_offset: u64, } @@ -109,7 +108,7 @@ pub fn op_webgpu_compute_pass_dispatch_indirect( #[serde(rename_all = "camelCase")] pub struct ComputePassBeginPipelineStatisticsQueryArgs { compute_pass_rid: ResourceId, - query_set: u32, + query_set: ResourceId, query_index: u32, } @@ -160,7 +159,7 @@ pub fn op_webgpu_compute_pass_end_pipeline_statistics_query( #[serde(rename_all = "camelCase")] pub struct ComputePassWriteTimestampArgs { compute_pass_rid: ResourceId, - query_set: u32, + query_set: ResourceId, query_index: u32, } @@ -220,8 +219,8 @@ pub fn op_webgpu_compute_pass_end_pass( pub struct ComputePassSetBindGroupArgs { compute_pass_rid: ResourceId, index: u32, - bind_group: u32, - dynamic_offsets_data: Option<Vec<u32>>, + bind_group: ResourceId, + dynamic_offsets_data: ZeroCopyBuf, dynamic_offsets_data_start: usize, dynamic_offsets_data_length: usize, } @@ -229,7 +228,7 @@ pub struct ComputePassSetBindGroupArgs { pub fn op_webgpu_compute_pass_set_bind_group( state: &mut OpState, args: ComputePassSetBindGroupArgs, - zero_copy: Option<ZeroCopyBuf>, + _: (), ) -> Result<WebGpuResult, AnyError> { let bind_group_resource = state @@ -239,22 +238,33 @@ pub fn op_webgpu_compute_pass_set_bind_group( .resource_table .get::<WebGpuComputePass>(args.compute_pass_rid)?; + // Align the data + assert!(args.dynamic_offsets_data_start % std::mem::size_of::<u32>() == 0); + // SAFETY: A u8 to u32 cast is safe because we asserted that the length is a + // multiple of 4. + let (prefix, dynamic_offsets_data, suffix) = + unsafe { args.dynamic_offsets_data.align_to::<u32>() }; + assert!(prefix.is_empty()); + assert!(suffix.is_empty()); + + let start = args.dynamic_offsets_data_start; + let len = args.dynamic_offsets_data_length; + + // Assert that length and start are both in bounds + assert!(start <= dynamic_offsets_data.len()); + assert!(len <= dynamic_offsets_data.len() - start); + + let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len]; + + // SAFETY: the raw pointer and length are of the same slice, and that slice + // lives longer than the below function invocation. unsafe { wgpu_core::command::compute_ffi::wgpu_compute_pass_set_bind_group( &mut compute_pass_resource.0.borrow_mut(), args.index, bind_group_resource.0, - match args.dynamic_offsets_data { - Some(data) => data.as_ptr(), - None => { - let zero_copy = zero_copy.ok_or_else(null_opbuf)?; - let (prefix, data, suffix) = zero_copy.align_to::<u32>(); - assert!(prefix.is_empty()); - assert!(suffix.is_empty()); - data[args.dynamic_offsets_data_start..].as_ptr() - } - }, - args.dynamic_offsets_data_length, + dynamic_offsets_data.as_ptr(), + dynamic_offsets_data.len(), ); } @@ -277,8 +287,10 @@ pub fn op_webgpu_compute_pass_push_debug_group( .resource_table .get::<WebGpuComputePass>(args.compute_pass_rid)?; + let label = std::ffi::CString::new(args.group_label).unwrap(); + // SAFETY: the string the raw pointer points to lives longer than the below + // function invocation. unsafe { - let label = std::ffi::CString::new(args.group_label).unwrap(); wgpu_core::command::compute_ffi::wgpu_compute_pass_push_debug_group( &mut compute_pass_resource.0.borrow_mut(), label.as_ptr(), @@ -327,8 +339,10 @@ pub fn op_webgpu_compute_pass_insert_debug_marker( .resource_table .get::<WebGpuComputePass>(args.compute_pass_rid)?; + let label = std::ffi::CString::new(args.marker_label).unwrap(); + // SAFETY: the string the raw pointer points to lives longer than the below + // function invocation. unsafe { - let label = std::ffi::CString::new(args.marker_label).unwrap(); wgpu_core::command::compute_ffi::wgpu_compute_pass_insert_debug_marker( &mut compute_pass_resource.0.borrow_mut(), label.as_ptr(), |