summaryrefslogtreecommitdiff
path: root/ext/webgpu/src/pipeline.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/webgpu/src/pipeline.rs')
-rw-r--r--ext/webgpu/src/pipeline.rs142
1 files changed, 67 insertions, 75 deletions
diff --git a/ext/webgpu/src/pipeline.rs b/ext/webgpu/src/pipeline.rs
index 8dd0e7e0f..1b69e118d 100644
--- a/ext/webgpu/src/pipeline.rs
+++ b/ext/webgpu/src/pipeline.rs
@@ -2,12 +2,12 @@
use deno_core::error::AnyError;
use deno_core::op;
+use deno_core::OpState;
+use deno_core::Resource;
use deno_core::ResourceId;
-use deno_core::{OpState, Resource};
use serde::Deserialize;
use serde::Serialize;
use std::borrow::Cow;
-use std::convert::{TryFrom, TryInto};
use super::error::WebGpuError;
use super::error::WebGpuResult;
@@ -43,59 +43,69 @@ impl Resource for WebGpuRenderPipeline {
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GpuProgrammableStage {
- module: ResourceId,
- entry_point: String,
- // constants: HashMap<String, GPUPipelineConstantValue>
+pub enum GPUAutoLayoutMode {
+ Auto,
+}
+
+#[derive(Deserialize)]
+#[serde(untagged)]
+pub enum GPUPipelineLayoutOrGPUAutoLayoutMode {
+ Layout(ResourceId),
+ Auto(GPUAutoLayoutMode),
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
-pub struct CreateComputePipelineArgs {
- device_rid: ResourceId,
- label: Option<String>,
- layout: Option<ResourceId>,
- compute: GpuProgrammableStage,
+pub struct GpuProgrammableStage {
+ module: ResourceId,
+ entry_point: String,
+ // constants: HashMap<String, GPUPipelineConstantValue>
}
#[op]
pub fn op_webgpu_create_compute_pipeline(
state: &mut OpState,
- args: CreateComputePipelineArgs,
+ device_rid: ResourceId,
+ label: Option<String>,
+ layout: GPUPipelineLayoutOrGPUAutoLayoutMode,
+ compute: GpuProgrammableStage,
) -> Result<WebGpuResult, AnyError> {
let instance = state.borrow::<super::Instance>();
let device_resource = state
.resource_table
- .get::<super::WebGpuDevice>(args.device_rid)?;
+ .get::<super::WebGpuDevice>(device_rid)?;
let device = device_resource.0;
- let pipeline_layout = if let Some(rid) = args.layout {
- let id = state.resource_table.get::<WebGpuPipelineLayout>(rid)?;
- Some(id.0)
- } else {
- None
+ let pipeline_layout = match layout {
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(rid) => {
+ let id = state.resource_table.get::<WebGpuPipelineLayout>(rid)?;
+ Some(id.0)
+ }
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Auto(GPUAutoLayoutMode::Auto) => None,
};
let compute_shader_module_resource =
state
.resource_table
- .get::<super::shader::WebGpuShaderModule>(args.compute.module)?;
+ .get::<super::shader::WebGpuShaderModule>(compute.module)?;
let descriptor = wgpu_core::pipeline::ComputePipelineDescriptor {
- label: args.label.map(Cow::from),
+ label: label.map(Cow::from),
layout: pipeline_layout,
stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: compute_shader_module_resource.0,
- entry_point: Cow::from(args.compute.entry_point),
+ entry_point: Cow::from(compute.entry_point),
// TODO(lucacasonato): support args.compute.constants
},
};
- let implicit_pipelines = match args.layout {
- Some(_) => None,
- None => Some(wgpu_core::device::ImplicitPipelineIds {
- root_id: std::marker::PhantomData,
- group_ids: &[std::marker::PhantomData; MAX_BIND_GROUPS],
- }),
+ let implicit_pipelines = match layout {
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(_) => None,
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Auto(GPUAutoLayoutMode::Auto) => {
+ Some(wgpu_core::device::ImplicitPipelineIds {
+ root_id: std::marker::PhantomData,
+ group_ids: &[std::marker::PhantomData; MAX_BIND_GROUPS],
+ })
+ }
};
let (compute_pipeline, maybe_err) = gfx_select!(device => instance.device_create_compute_pipeline(
@@ -112,13 +122,6 @@ pub fn op_webgpu_create_compute_pipeline(
Ok(WebGpuResult::rid_err(rid, maybe_err))
}
-#[derive(Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct ComputePipelineGetBindGroupLayoutArgs {
- compute_pipeline_rid: ResourceId,
- index: u32,
-}
-
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PipelineLayout {
@@ -130,15 +133,16 @@ pub struct PipelineLayout {
#[op]
pub fn op_webgpu_compute_pipeline_get_bind_group_layout(
state: &mut OpState,
- args: ComputePipelineGetBindGroupLayoutArgs,
+ compute_pipeline_rid: ResourceId,
+ index: u32,
) -> Result<PipelineLayout, AnyError> {
let instance = state.borrow::<super::Instance>();
let compute_pipeline_resource = state
.resource_table
- .get::<WebGpuComputePipeline>(args.compute_pipeline_rid)?;
+ .get::<WebGpuComputePipeline>(compute_pipeline_rid)?;
let compute_pipeline = compute_pipeline_resource.0;
- let (bind_group_layout, maybe_err) = gfx_select!(compute_pipeline => instance.compute_pipeline_get_bind_group_layout(compute_pipeline, args.index, std::marker::PhantomData));
+ let (bind_group_layout, maybe_err) = gfx_select!(compute_pipeline => instance.compute_pipeline_get_bind_group_layout(compute_pipeline, index, std::marker::PhantomData));
let label = gfx_select!(bind_group_layout => instance.bind_group_layout_label(bind_group_layout));
@@ -210,12 +214,9 @@ struct GpuDepthStencilState {
depth_bias_clamp: f32,
}
-impl TryFrom<GpuDepthStencilState> for wgpu_types::DepthStencilState {
- type Error = AnyError;
- fn try_from(
- state: GpuDepthStencilState,
- ) -> Result<wgpu_types::DepthStencilState, AnyError> {
- Ok(wgpu_types::DepthStencilState {
+impl From<GpuDepthStencilState> for wgpu_types::DepthStencilState {
+ fn from(state: GpuDepthStencilState) -> wgpu_types::DepthStencilState {
+ wgpu_types::DepthStencilState {
format: state.format,
depth_write_enabled: state.depth_write_enabled,
depth_compare: state.depth_compare,
@@ -230,7 +231,7 @@ impl TryFrom<GpuDepthStencilState> for wgpu_types::DepthStencilState {
slope_scale: state.depth_bias_slope_scale,
clamp: state.depth_bias_clamp,
},
- })
+ }
}
}
@@ -285,7 +286,7 @@ impl From<GpuMultisampleState> for wgpu_types::MultisampleState {
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GpuFragmentState {
- targets: Vec<wgpu_types::ColorTargetState>,
+ targets: Vec<Option<wgpu_types::ColorTargetState>>,
module: u32,
entry_point: String,
// TODO(lucacasonato): constants
@@ -296,7 +297,7 @@ struct GpuFragmentState {
pub struct CreateRenderPipelineArgs {
device_rid: ResourceId,
label: Option<String>,
- layout: Option<ResourceId>,
+ layout: GPUPipelineLayoutOrGPUAutoLayoutMode,
vertex: GpuVertexState,
primitive: GpuPrimitiveState,
depth_stencil: Option<GpuDepthStencilState>,
@@ -315,12 +316,13 @@ pub fn op_webgpu_create_render_pipeline(
.get::<super::WebGpuDevice>(args.device_rid)?;
let device = device_resource.0;
- let layout = if let Some(rid) = args.layout {
- let pipeline_layout_resource =
- state.resource_table.get::<WebGpuPipelineLayout>(rid)?;
- Some(pipeline_layout_resource.0)
- } else {
- None
+ let layout = match args.layout {
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(rid) => {
+ let pipeline_layout_resource =
+ state.resource_table.get::<WebGpuPipelineLayout>(rid)?;
+ Some(pipeline_layout_resource.0)
+ }
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Auto(GPUAutoLayoutMode::Auto) => None,
};
let vertex_shader_module_resource =
@@ -334,18 +336,12 @@ pub fn op_webgpu_create_render_pipeline(
.resource_table
.get::<super::shader::WebGpuShaderModule>(fragment.module)?;
- let mut targets = Vec::with_capacity(fragment.targets.len());
-
- for target in fragment.targets {
- targets.push(target);
- }
-
Some(wgpu_core::pipeline::FragmentState {
stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: fragment_shader_module_resource.0,
entry_point: Cow::from(fragment.entry_point),
},
- targets: Cow::from(targets),
+ targets: Cow::from(fragment.targets),
})
} else {
None
@@ -370,18 +366,20 @@ pub fn op_webgpu_create_render_pipeline(
buffers: Cow::Owned(vertex_buffers),
},
primitive: args.primitive.into(),
- depth_stencil: args.depth_stencil.map(TryInto::try_into).transpose()?,
+ depth_stencil: args.depth_stencil.map(Into::into),
multisample: args.multisample,
fragment,
multiview: None,
};
let implicit_pipelines = match args.layout {
- Some(_) => None,
- None => Some(wgpu_core::device::ImplicitPipelineIds {
- root_id: std::marker::PhantomData,
- group_ids: &[std::marker::PhantomData; MAX_BIND_GROUPS],
- }),
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(_) => None,
+ GPUPipelineLayoutOrGPUAutoLayoutMode::Auto(GPUAutoLayoutMode::Auto) => {
+ Some(wgpu_core::device::ImplicitPipelineIds {
+ root_id: std::marker::PhantomData,
+ group_ids: &[std::marker::PhantomData; MAX_BIND_GROUPS],
+ })
+ }
};
let (render_pipeline, maybe_err) = gfx_select!(device => instance.device_create_render_pipeline(
@@ -398,25 +396,19 @@ pub fn op_webgpu_create_render_pipeline(
Ok(WebGpuResult::rid_err(rid, maybe_err))
}
-#[derive(Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct RenderPipelineGetBindGroupLayoutArgs {
- render_pipeline_rid: ResourceId,
- index: u32,
-}
-
#[op]
pub fn op_webgpu_render_pipeline_get_bind_group_layout(
state: &mut OpState,
- args: RenderPipelineGetBindGroupLayoutArgs,
+ render_pipeline_rid: ResourceId,
+ index: u32,
) -> Result<PipelineLayout, AnyError> {
let instance = state.borrow::<super::Instance>();
let render_pipeline_resource = state
.resource_table
- .get::<WebGpuRenderPipeline>(args.render_pipeline_rid)?;
+ .get::<WebGpuRenderPipeline>(render_pipeline_rid)?;
let render_pipeline = render_pipeline_resource.0;
- let (bind_group_layout, maybe_err) = gfx_select!(render_pipeline => instance.render_pipeline_get_bind_group_layout(render_pipeline, args.index, std::marker::PhantomData));
+ let (bind_group_layout, maybe_err) = gfx_select!(render_pipeline => instance.render_pipeline_get_bind_group_layout(render_pipeline, index, std::marker::PhantomData));
let label = gfx_select!(bind_group_layout => instance.bind_group_layout_label(bind_group_layout));