Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,14 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-24.04, windows-2022, macOS-latest ]
target_env: [ "vulkan1.1,vulkan1.2,vulkan1.3,vulkan1.4,spv1.3,spv1.4" ]
experimental: [ false ]
include:
- os: ubuntu-24.04
target_env: naga-wgsl
experimental: true
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.experimental }}
steps:
- uses: actions/checkout@v4
- name: Install Vulkan SDK
Expand All @@ -143,7 +150,7 @@ jobs:
- name: cargo fetch --locked
run: cargo fetch --locked --target $TARGET
- name: compiletest
run: cargo run -p compiletests --release --no-default-features --features "use-installed-tools" -- --target-env vulkan1.1,vulkan1.2,vulkan1.3,vulkan1.4,spv1.3,spv1.4
run: cargo run -p compiletests --release --no-default-features --features "use-installed-tools" -- --target-env ${{ matrix.target_env }}

difftest:
name: Difftest
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions crates/rustc_codegen_spirv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ crate-type = ["dylib"]
default = ["use-compiled-tools"]
# If enabled, uses spirv-tools binaries installed in PATH, instead of
# compiling and linking the spirv-tools C++ code
use-installed-tools = ["spirv-tools/use-installed-tools"]
use-installed-tools = ["spirv-tools/use-installed-tools", "naga"]
# If enabled will compile and link the C++ code for the spirv tools, the compiled
# version is preferred if both this and `use-installed-tools` are enabled
use-compiled-tools = ["spirv-tools/use-compiled-tools"]
use-compiled-tools = ["spirv-tools/use-compiled-tools", "naga"]
Comment on lines +23 to +26
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this effectively makes naga non-optional if I'm reading it correctly. Is it intentional?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's from the old PR, added with comment

wgsl: enable naga feature by default, cargo-gpu can't handle it

So this was more a hack than anything else. I'll have a look what the compile time impact is, and if it's actually significant (which next to spirv-tools-sys may not be), look into making it optional and supported by cargo-gpu.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if naga should really be a dependency of either of those two features. Why not let user specify it explicitly instead?

Copy link
Member Author

@Firestar99 Firestar99 Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cargo-gpu can't handle it

we'd need to build the infra in cargo-gpu first

# If enabled, this will not check whether the current rustc version is set to the
# appropriate channel. rustc_cogeden_spirv requires a specific nightly version,
# and will likely produce compile errors when built against a different toolchain.
# Enable this feature to be able to experiment with other versions.
skip-toolchain-check = []
naga = ["dep:naga"]

[dependencies]
# HACK(eddyb) these only exist to unify features across dependency trees,
Expand Down Expand Up @@ -61,6 +62,8 @@ itertools = "0.14.0"
tracing.workspace = true
tracing-subscriber.workspace = true
tracing-tree = "0.4.0"
naga = { version = "27.0.3", features = ["spv-in", "wgsl-out"], optional = true }
strum = { version = "0.27.2", features = ["derive"] }

[dev-dependencies]
pretty_assertions = "1.0"
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ mod custom_decorations;
mod custom_insts;
mod link;
mod linker;
mod naga_transpile;
mod spirv_type;
mod spirv_type_constraints;
mod symbols;
Expand Down
5 changes: 5 additions & 0 deletions crates/rustc_codegen_spirv/src/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;

use crate::codegen_cx::{CodegenArgs, SpirvMetadata};
use crate::naga_transpile::should_transpile;
use crate::target::{SpirvTarget, SpirvTargetVariant};
use crate::{SpirvCodegenBackend, SpirvModuleBuffer, linker};
use ar::{Archive, GnuBuilder, Header};
Expand Down Expand Up @@ -323,6 +324,10 @@ fn post_link_single_module(

drop(save_modules_timer);
}

if let Ok(Some(transpile)) = should_transpile(sess) {
transpile(sess, cg_args, &spv_binary, out_filename).ok();
}
}

fn do_spirv_opt(
Expand Down
89 changes: 89 additions & 0 deletions crates/rustc_codegen_spirv/src/naga_transpile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use crate::codegen_cx::CodegenArgs;
use crate::target::{NagaTarget, SpirvTarget};
use rustc_session::Session;
use rustc_span::ErrorGuaranteed;
use std::path::Path;

pub type NagaTranspile = fn(
sess: &Session,
cg_args: &CodegenArgs,
spv_binary: &[u32],
out_filename: &Path,
) -> Result<(), ErrorGuaranteed>;

pub fn should_transpile(sess: &Session) -> Result<Option<NagaTranspile>, ErrorGuaranteed> {
let target = SpirvTarget::parse_target(sess.opts.target_triple.tuple())
.expect("parsing should fail earlier");
let result: Result<Option<NagaTranspile>, ()> = match target {
#[cfg(feature = "naga")]
SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Ok(Some(transpile::wgsl_transpile)),
#[cfg(not(feature = "naga"))]
SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Err(()),
_ => Ok(None),
};
result.map_err(|_e| {
sess.dcx().err(format!(
"Target `{}` requires feature \"naga\" on rustc_codegen_spirv",
target.target()
))
})
}

#[cfg(feature = "naga")]
mod transpile {
use crate::codegen_cx::CodegenArgs;
use naga::error::ShaderError;
use naga::valid::Capabilities;
use rustc_session::Session;
use rustc_span::ErrorGuaranteed;
use std::path::Path;

pub fn wgsl_transpile(
sess: &Session,
_cg_args: &CodegenArgs,
spv_binary: &[u32],
out_filename: &Path,
) -> Result<(), ErrorGuaranteed> {
// these should be params via spirv-builder
let opts = naga::front::spv::Options::default();
let capabilities = Capabilities::default();
let writer_flags = naga::back::wgsl::WriterFlags::empty();

let module = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(spv_binary), &opts)
.map_err(|err| {
sess.dcx().err(format!(
"Naga failed to parse spv: \n{}",
ShaderError {
source: String::new(),
label: None,
inner: Box::new(err),
}
))
})?;
let mut validator =
naga::valid::Validator::new(naga::valid::ValidationFlags::default(), capabilities);
let info = validator.validate(&module).map_err(|err| {
sess.dcx().err(format!(
"Naga validation failed: \n{}",
ShaderError {
source: String::new(),
label: None,
inner: Box::new(err),
}
))
})?;

let wgsl_dst = out_filename.with_extension("wgsl");
let wgsl = naga::back::wgsl::write_string(&module, &info, writer_flags).map_err(|err| {
sess.dcx()
.err(format!("Naga failed to write wgsl : \n{err}"))
})?;

std::fs::write(&wgsl_dst, wgsl).map_err(|err| {
sess.dcx()
.err(format!("failed to write wgsl to file: {err}"))
})?;

Ok(())
}
}
77 changes: 77 additions & 0 deletions crates/rustc_codegen_spirv/src/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@ use std::cmp::Ordering;
use std::fmt::{Debug, Display, Formatter};
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
use strum::{Display, EnumString, IntoStaticStr};

#[derive(Clone, Eq, PartialEq)]
pub enum TargetError {
/// If during parsing a target variant returns `UnknownTarget`, further variants will attempt to parse the string.
/// Returning another error means that you have recognized the target but something else is invalid, and we should
/// abort the parsing with your error.
UnknownTarget(String),
InvalidTargetVersion(SpirvTarget),
InvalidNagaVariant(String),
}

impl Display for TargetError {
Expand All @@ -21,6 +26,9 @@ impl Display for TargetError {
TargetError::InvalidTargetVersion(target) => {
write!(f, "Invalid version in target `{}`", target.env())
}
TargetError::InvalidNagaVariant(target) => {
write!(f, "Unknown naga out variant `{target}`")
}
}
}
}
Expand Down Expand Up @@ -439,13 +447,71 @@ impl Display for OpenGLTarget {
}
}

/// A naga target
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct NagaTarget {
pub out: NagaOut,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, IntoStaticStr, Display, EnumString)]
#[allow(clippy::upper_case_acronyms)]
pub enum NagaOut {
#[strum(to_string = "wgsl")]
WGSL,
}

impl NagaTarget {
pub const NAGA_WGSL: Self = NagaTarget::new(NagaOut::WGSL);
pub const ALL_NAGA_TARGETS: &'static [Self] = &[Self::NAGA_WGSL];
/// emit spirv like naga targets were this target
pub const EMIT_SPIRV_LIKE: SpirvTarget = SpirvTarget::VULKAN_1_3;

pub const fn new(out: NagaOut) -> Self {
Self { out }
}
}

impl SpirvTargetVariant for NagaTarget {
fn validate(&self) -> Result<(), TargetError> {
Ok(())
}

fn to_spirv_tools(&self) -> spirv_tools::TargetEnv {
Self::EMIT_SPIRV_LIKE.to_spirv_tools()
}

fn spirv_version(&self) -> SpirvVersion {
Self::EMIT_SPIRV_LIKE.spirv_version()
}
}

impl FromStr for NagaTarget {
type Err = TargetError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s
.strip_prefix("naga-")
.ok_or_else(|| TargetError::UnknownTarget(s.to_owned()))?;
Ok(Self::new(FromStr::from_str(s).map_err(|_e| {
TargetError::InvalidNagaVariant(s.to_owned())
})?))
}
}

impl Display for NagaTarget {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "naga-{}", self.out)
}
}

/// A rust-gpu target
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SpirvTarget {
Universal(UniversalTarget),
Vulkan(VulkanTarget),
OpenGL(OpenGLTarget),
Naga(NagaTarget),
}

impl SpirvTarget {
Expand All @@ -467,12 +533,15 @@ impl SpirvTarget {
pub const OPENGL_4_2: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_2);
pub const OPENGL_4_3: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_3);
pub const OPENGL_4_5: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_5);
pub const NAGA_WGSL: Self = Self::Naga(NagaTarget::NAGA_WGSL);

#[allow(clippy::match_same_arms)]
pub const fn memory_model(&self) -> MemoryModel {
match self {
SpirvTarget::Universal(_) => MemoryModel::Simple,
SpirvTarget::Vulkan(_) => MemoryModel::Vulkan,
SpirvTarget::OpenGL(_) => MemoryModel::GLSL450,
SpirvTarget::Naga(_) => MemoryModel::Vulkan,
}
}
}
Expand All @@ -483,6 +552,7 @@ impl SpirvTargetVariant for SpirvTarget {
SpirvTarget::Universal(t) => t.validate(),
SpirvTarget::Vulkan(t) => t.validate(),
SpirvTarget::OpenGL(t) => t.validate(),
SpirvTarget::Naga(t) => t.validate(),
}
}

Expand All @@ -491,6 +561,7 @@ impl SpirvTargetVariant for SpirvTarget {
SpirvTarget::Universal(t) => t.to_spirv_tools(),
SpirvTarget::Vulkan(t) => t.to_spirv_tools(),
SpirvTarget::OpenGL(t) => t.to_spirv_tools(),
SpirvTarget::Naga(t) => t.to_spirv_tools(),
}
}

Expand All @@ -499,6 +570,7 @@ impl SpirvTargetVariant for SpirvTarget {
SpirvTarget::Universal(t) => t.spirv_version(),
SpirvTarget::Vulkan(t) => t.spirv_version(),
SpirvTarget::OpenGL(t) => t.spirv_version(),
SpirvTarget::Naga(t) => t.spirv_version(),
}
}
}
Expand All @@ -513,6 +585,9 @@ impl SpirvTarget {
if matches!(result, Err(TargetError::UnknownTarget(..))) {
result = OpenGLTarget::from_str(s).map(Self::OpenGL);
}
if matches!(result, Err(TargetError::UnknownTarget(..))) {
result = NagaTarget::from_str(s).map(Self::Naga);
}
result
}

Expand All @@ -533,6 +608,7 @@ impl SpirvTarget {
SpirvTarget::Universal(t) => t.to_string(),
SpirvTarget::Vulkan(t) => t.to_string(),
SpirvTarget::OpenGL(t) => t.to_string(),
SpirvTarget::Naga(t) => t.to_string(),
}
}

Expand All @@ -555,6 +631,7 @@ impl SpirvTarget {
.iter()
.map(|t| Self::OpenGL(*t)),
)
.chain(NagaTarget::ALL_NAGA_TARGETS.iter().map(|t| Self::Naga(*t)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/compiletests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use-compiled-tools = ["rustc_codegen_spirv/use-compiled-tools"]

[dependencies]
compiletest = { version = "0.11.2", package = "compiletest_rs" }
rustc_codegen_spirv = { workspace = true }
rustc_codegen_spirv = { workspace = true, features = ["naga"] }
rustc_codegen_spirv-types = { workspace = true }
clap = { version = "4", features = ["derive"] }
itertools = "0.14.0"
Expand Down
1 change: 1 addition & 0 deletions tests/compiletests/ui/arch/all_memory_barrier.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// build-pass
// compile-flags: -C target-feature=+VulkanMemoryModelDeviceScopeKHR,+ext:SPV_KHR_vulkan_memory_model
// compile-flags: -C llvm-args=--disassemble-fn=all_memory_barrier::all_memory_barrier
// ignore-naga

use spirv_std::spirv;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// build-pass
// compile-flags: -Ctarget-feature=+Int64,+RayTracingKHR,+ext:SPV_KHR_ray_tracing
// ignore-naga

use spirv_std::spirv;

Expand Down
1 change: 1 addition & 0 deletions tests/compiletests/ui/arch/debug_printf.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// build-pass
// compile-flags: -Ctarget-feature=+ext:SPV_KHR_non_semantic_info
// ignore-naga

use spirv_std::spirv;
use spirv_std::{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// build-pass
//
// compile-flags: -C target-feature=+DemoteToHelperInvocationEXT,+ext:SPV_EXT_demote_to_helper_invocation
// ignore-naga

use spirv_std::spirv;

Expand Down
1 change: 1 addition & 0 deletions tests/compiletests/ui/arch/emit_stream_vertex.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// build-pass
// compile-flags: -C target-feature=+Int64,+GeometryStreams
// ignore-naga

use spirv_std::spirv;

Expand Down
Loading
Loading