Implement AbsKernelOp for WebGPU backend (#896)
* Removed some of the more low-level commands in favor of a wrapper struct

Also added tests for higher code coverage.

* Fix unsound `AtomicPtr` usage

* Partial implementation of `Device<E>` for Webgpu

* Remove foolish Mutex

* Add Mutex back, since evidently it was causing issues.

Hopefully I can figure out a way to remove it again.

* Removed `num_traits::Num` requirement from Zeros.

Had to figure out a way to store zeros in place.

* Implement the abs kernel, and wire the still-broken unary operation through to resolve the remaining compiler errors

* cargo fmt

* Disable f16, since we don't support it yet

* no-std

* Added test for abs on webgpu. Also added `backward` implementation,
though I won't be able to test that until I fix `mean`.

* cargo fmt

* Managed to get the built SPIR-V working, as long as we go through the
non-passthrough route.

Can't get `sum_to` working until wgpu supports atomic operations, which is
super unfortunate.

Maybe I'll work on that soon...

* Got the code working correctly; almost got `sum_to` working, too

Weird magic number issue that I can't figure out...

* Cargo fmt

* Do we need to skip webgpu features?
favilo authored Jan 3, 2024
1 parent 4615ac1 commit 630514f
Showing 35 changed files with 1,098 additions and 574 deletions.
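
The commit message above mentions a new forward-only test for `abs` on the WebGPU device (backward is blocked until `mean` works on this backend). A minimal sketch of what such a test could look like, assuming the Webgpu device follows the same tensor-construction and `.abs()`/`.array()` pattern as the existing backends; the device constructor and exact assertions here are illustrative, not taken from this diff:

```rust
#[cfg(feature = "webgpu")]
#[test]
fn test_webgpu_abs_forward() {
    use crate::prelude::*;
    use crate::tensor::webgpu::Webgpu;

    // Assumption: the Webgpu device can be built with Default::default(),
    // panicking if no suitable wgpu adapter is available.
    let dev: Webgpu = Default::default();
    let t = dev.tensor([-2.0f32, -1.0, 0.0, 1.0, 2.0]);
    let r = t.abs();
    assert_eq!(r.array(), [2.0, 1.0, 0.0, 1.0, 2.0]);
    // The backward pass would normally be checked via `.mean().backward()`,
    // but per the commit message that path is blocked on `mean` for WebGPU.
}
```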
4 changes: 2 additions & 2 deletions .github/workflows/cargo-check-features.yml
@@ -11,9 +11,9 @@ jobs:
matrix:
config:
- toolchain: stable
command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn
command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn,webgpu
- toolchain: nightly
command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn
command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn,webgpu

steps:
- uses: actions/checkout@v2
12 changes: 10 additions & 2 deletions dfdx-core/Cargo.toml
@@ -38,7 +38,8 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", optional = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
naga = { version = "0.14.1", optional = true }
futures-lite = { version = "2.0.1", optional = true }
thingbuf = { version = "0.1.4", optional = true }

@@ -62,7 +63,14 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]
webgpu = [
"dep:wgpu",
"dep:futures-lite",
"dep:thingbuf",
"dep:naga",
"dep:glob",
"wgpu/expose-ids",
]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

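
Downstream usage doesn't change: the backend is still gated behind the `webgpu` feature, which now also pulls in `naga` and `glob`. A hypothetical gate in consuming code (not part of this diff, and assuming `Webgpu` is re-exported from the tensor module like the other devices):

```rust
// Built with: cargo build --features webgpu
#[cfg(feature = "webgpu")]
use dfdx_core::tensor::Webgpu;

#[cfg(feature = "webgpu")]
fn webgpu_device() -> Webgpu {
    // Assumption: Default::default() panics if no wgpu adapter is found.
    Default::default()
}
```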
55 changes: 55 additions & 0 deletions dfdx-core/build.rs
@@ -9,6 +9,9 @@ fn main() {

#[cfg(feature = "cuda")]
cuda::build_ptx();

#[cfg(feature = "webgpu")]
webgpu::build_spv();
}

fn maybe_enable_nightly() {
@@ -210,3 +213,55 @@ mod cuda {
}
}
}

#[cfg(feature = "webgpu")]
mod webgpu {
pub fn build_spv() {
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<std::path::PathBuf> = glob::glob("src/**/*.glsl")
.unwrap()
.map(|p| p.unwrap())
.collect();
for path in &kernel_paths {
println!("cargo:rerun-if-changed={}", path.display());
}

kernel_paths
.iter()
.for_each(|p| println!("cargo:rerun-if-changed={}", p.display()));

let children = kernel_paths
.iter()
.map(|p| {
["float", "double"].iter().map(|ty| {
// TODO: we need to build this for both float and double
let out_path: std::path::PathBuf = out_dir.clone().into();
let base = p.file_stem().unwrap();
let new_name = format!("{}.{ty}.spv", base.to_str().unwrap());
let out_file = &out_path.join(new_name);
std::process::Command::new("glslc")
.args(["-std=460core"])
.args(["-fshader-stage=compute"])
.args([format!("-DTYPENAME={ty}")])
.args(["-o", &out_file.as_os_str().to_str().unwrap()])
.arg(p)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.")
}).collect::<Vec<_>>()
})
.collect::<Vec<_>>();
for (kernel_path, childs) in kernel_paths.iter().zip(children.into_iter()) {
for child in childs {
let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH.");
assert!(
output.status.success(),
"glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
}
}
}
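
`build_spv` writes one `<stem>.<type>.spv` file per kernel and per scalar type into `OUT_DIR` (e.g. `abs.fwd.float.spv` and `abs.fwd.double.spv`). A hedged sketch of how a kernel could pick up those artifacts at runtime; the constant name, the module path of `AbsKernelOp`, and the choice of cache key are assumptions, since that wiring is not shown in this excerpt:

```rust
// Illustrative only: embed the glslc output from OUT_DIR and register it once.
const ABS_FWD_F32_SPV: &[u8] =
    include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv"));

fn ensure_abs_fwd_loaded(dev: &crate::tensor::webgpu::Webgpu) {
    // Hypothetical cache key; the real implementation may key `cs_cache` differently.
    let key = core::any::TypeId::of::<crate::tensor_ops::abs::AbsKernelOp>();
    if !dev.shader_module_loaded(key) {
        // Goes through the SPIR-V passthrough path enabled in device.rs below.
        dev.load_shader_module::<f32>(key, ABS_FWD_F32_SPV);
    }
}
```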
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor/error.rs
@@ -22,6 +22,9 @@ pub enum Error {

#[cfg(feature = "webgpu")]
WebgpuRequestDeviceError(wgpu::RequestDeviceError),

#[cfg(feature = "webgpu")]
WebgpuSourceLoadError,
}

impl std::fmt::Display for Error {
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor/webgpu/allocate.rs
@@ -22,7 +22,7 @@ impl Webgpu {
shape: S,
buf: Vec<E>,
) -> Result<Tensor<S, E, Self>, Error> {
let buffer = unsafe { self.alloc_empty::<E>(buf.len()) }?;
let buffer = self.alloc_empty::<E>(buf.len())?;
buffer.copy_to_device::<E>(&self.dev, &self.queue, &buf);

Ok(self.build_tensor(shape, shape.strides(), buffer))
@@ -56,7 +56,7 @@ impl<E: Unit + SafeZeros> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
let data = self.alloc_empty::<E>(shape.num_elements())?;
data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]);

Ok(self.build_tensor(shape, strides, data))
100 changes: 90 additions & 10 deletions dfdx-core/src/tensor/webgpu/device.rs
@@ -1,9 +1,12 @@
use wgpu::{
Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue,
RequestDeviceError,
util::{make_spirv, make_spirv_raw, BufferInitDescriptor, DeviceExt},
Adapter, BufferDescriptor, BufferUsages, Device, DeviceDescriptor, Features, Instance,
InstanceDescriptor, Maintain, Queue, RequestDeviceError, ShaderModule, ShaderModuleDescriptor,
ShaderModuleDescriptorSpirV,
};

use crate::{
prelude::webgpu_kernels::HasGlslType,
shapes::{Shape, Unit},
tensor::{
cache::TensorCache, cpu::Cpu, Cache, Error, NoneTape, RandomU64, Storage, Synchronize,
@@ -12,12 +15,13 @@ use crate::{
};

#[cfg(feature = "no-std")]
use spin::Mutex;
use spin::{Mutex, RwLock};

use core::any::TypeId;
#[cfg(not(feature = "no-std"))]
use std::sync::Mutex;
use std::sync::{Mutex, RwLock};

use std::{marker::PhantomData, sync::Arc, vec::Vec};
use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec};

use super::allocate::round_to_buffer_alignment;

@@ -40,12 +44,16 @@ impl Buffer {
self.size
}

pub(crate) fn len<E: Unit>(&self) -> usize {
self.size / std::mem::size_of::<E>()
}

#[allow(unused)]
pub(crate) fn capacity(&self) -> usize {
self.data.size() as usize
}

pub(crate) fn copy_to_device<E: Unit>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
pub(crate) fn copy_to_device<E>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
let slice = unsafe {
std::slice::from_raw_parts(
slice.as_ptr() as *const u8,
@@ -102,6 +110,7 @@ pub struct Webgpu {
pub(crate) queue: Arc<Queue>,

pub(crate) cache: Arc<TensorCache<Buffer>>,
pub(crate) cs_cache: Arc<RwLock<HashMap<TypeId, Arc<ShaderModule>>>>,
}

impl From<RequestDeviceError> for Error {
@@ -134,8 +143,13 @@ impl Webgpu {
let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
.ok_or(Error::WebgpuAdapterNotFound)?;
let adapter = Arc::new(adapter);
let descriptor = DeviceDescriptor {
label: None,
features: Features::default() | Features::SPIRV_SHADER_PASSTHROUGH,
limits: Default::default(),
};
let (dev, queue) =
futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?;
futures_lite::future::block_on(adapter.request_device(&descriptor, None))?;
let dev = Arc::new(dev);
let queue = Arc::new(queue);

@@ -147,18 +161,19 @@ impl Webgpu {
queue,

cache: Default::default(),
cs_cache: Default::default(),
})
}
}

impl Webgpu {
pub(crate) unsafe fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
pub(crate) fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(len).map_or_else(
|| Buffer {
data: self.dev.create_buffer(&BufferDescriptor {
label: None,
size: round_to_buffer_alignment((len * std::mem::size_of::<E>()) as u64),
usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
mapped_at_creation: false,
}),
size: len * std::mem::size_of::<E>(),
@@ -168,6 +183,71 @@ impl Webgpu {
Ok(data)
}

pub(crate) fn alloc_init<E>(&self, init: &[E]) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(init.len()).map_or_else(
|| {
let contents = unsafe {
std::slice::from_raw_parts(
init.as_ptr() as *const u8,
init.len() * std::mem::size_of::<E>(),
)
};
Buffer {
data: self.dev.create_buffer_init(&BufferInitDescriptor {
label: None,
usage: BufferUsages::STORAGE
| BufferUsages::COPY_SRC
| BufferUsages::COPY_DST,
contents,
}),
size: init.len() * std::mem::size_of::<E>(),
}
},
|bfr| {
bfr.copy_to_device::<E>(&self.dev, &self.queue, init);
bfr
},
);
Ok(data)
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().unwrap().contains_key(&name)
}

#[cfg(feature = "no-std")]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().contains_key(&name)
}

pub(crate) fn load_shader_module<E>(&self, name: TypeId, source: &[u8])
where
E: HasGlslType,
{
let module = Arc::new(unsafe {
self.dev
.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
label: None,
source: make_spirv_raw(source),
})
});
#[cfg(not(feature = "no-std"))]
self.cs_cache.write().unwrap().insert(name, module);
#[cfg(feature = "no-std")]
self.cs_cache.write().insert(name, module);
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().unwrap().get(&name).cloned()
}

#[cfg(feature = "no-std")]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().get(&name).cloned()
}

// #[allow(unused)]
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
// let num_bytes_required = len * std::mem::size_of::<E>();
@@ -312,7 +392,7 @@ impl<E: Unit> Storage<E> for Webgpu {
type Vec = CachableBuffer<E>;

fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Error> {
let data = unsafe { self.alloc_empty::<E>(len) }?;
let data = self.alloc_empty::<E>(len)?;
Ok(CachableBuffer {
dev: self.dev.clone(),
queue: self.queue.clone(),
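
Once a `ShaderModule` is in `cs_cache`, the kernel side presumably builds a compute pipeline from it and dispatches over the element count. A rough sketch using plain wgpu 0.18 calls — the bind-group wiring, the `main` entry point, and the workgroup math here are assumptions matching the GLSL below, not code from this commit:

```rust
// Illustrative dispatch for a unary kernel (e.g. abs forward), written as if it lives
// inside the crate so the pub(crate) fields on Webgpu and Buffer are reachable.
fn dispatch_unary(
    dev: &Webgpu,
    module: &wgpu::ShaderModule,
    inp: &Buffer,
    outp: &Buffer,
    len: usize,
) {
    let pipeline = dev.dev.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: None, // let wgpu derive the layout from the shader's bindings
        module,
        entry_point: "main",
    });
    let bind_group = dev.dev.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry { binding: 1, resource: inp.data.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: outp.data.as_entire_binding() },
        ],
    });
    let mut encoder = dev.dev.create_command_encoder(&Default::default());
    {
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // The shaders declare local_size_x = 128, so round the element count up.
        pass.dispatch_workgroups(((len + 127) / 128) as u32, 1, 1);
    }
    dev.queue.submit(Some(encoder.finish()));
}
```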
28 changes: 28 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.bwd.glsl
@@ -0,0 +1,28 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock {
TYPENAME outp[];
};

layout(std430, binding = 3) buffer input_gradBlock {
TYPENAME input_grad[];
};

layout(std430, binding = 4) buffer output_gradBlock {
TYPENAME output_grad[];
};

void main() {
TYPENAME dx = sign(inp[gl_GlobalInvocationID.x]);

input_grad[gl_GlobalInvocationID.x] = dx * output_grad[gl_GlobalInvocationID.x];
}
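
For reference, the backward kernel above is just the chain rule for the absolute value: d|x|/dx = sign(x), so `input_grad[i] = sign(inp[i]) * output_grad[i]`. GLSL's `sign(0.0)` is `0.0`, so the (sub)gradient at zero comes out as zero, the usual convention for `abs`.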
22 changes: 22 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.fwd.glsl
@@ -0,0 +1,22 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock{
TYPENAME outp[];
};

void main() {
if (inp.length() == 0) {
outp[gl_GlobalInvocationID.x] = abs(outp[gl_GlobalInvocationID.x]);
} else {
outp[gl_GlobalInvocationID.x] = abs(inp[gl_GlobalInvocationID.x]);
}
}