feat: add attention cuda
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 29, 2024
1 parent bbed31e commit f1faf3e
Showing 4 changed files with 82 additions and 7 deletions.
27 changes: 22 additions & 5 deletions src/04kernel/src/collectors/attention.cc
@@ -1,8 +1,6 @@
#include "kernel/collectors/attention.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
// #include "../kernels/attention/cpu_kernel.hh"
// #include "../kernels/attention/cuda_kernel.hh"
#include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

@@ -14,12 +12,31 @@ namespace refactor::kernel {

     std::vector<KernelBox>
     AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        auto const &query = inputs[0].get();
+        auto const &key = inputs[1].get();
+        auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
+        auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];
 
         std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 break;
-            case decltype(_target)::Nvidia:
-                break;
+            case decltype(_target)::Nvidia: {
+                decltype(AttentionCuda::info) info{
+                    .dataType = query.dataType,
+                    .batch = query.shape[0],
+                    .nHead = query.shape[1],
+                    .nKVHead = key.shape[1],
+                    .pastSeqLen = static_cast<dim_t>(pastSeqLen),
+                    .seqLen = query.shape[2],
+                    .cacheLen = cacheLen,
+                    .headDim = query.shape[3],
+                    .resetCache = false,
+                };
+                if (auto ptr = AttentionCuda::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+            } break;
             case decltype(_target)::Mlu:
                 break;
             default:
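
The designated initializers above imply fixed layouts: query is read as [batch, nHead, seqLen, headDim], key supplies nKVHead from its second axis, and the cache's sequence length comes from outputs[1].shape[2]. A minimal hand-filled example of the same struct for a single decode step, assuming those layouts and a 32-head half-precision model (all values and the DataType spelling are illustrative, not taken from the commit):

decltype(AttentionCuda::info) info{
    .dataType = DataType::FP16,// assumed spelling of the fp16 tag
    .batch = 1,
    .nHead = 32,
    .nKVHead = 32,// a value smaller than nHead would indicate grouped-query attention
    .pastSeqLen = 8,// tokens already held in the kv cache
    .seqLen = 1,// tokens processed in this step
    .cacheLen = 512,// the S column of the README table below
    .headDim = 128,
    .resetCache = false,
};
auto box = AttentionCuda::build(info);// nullptr in a non-CUDA build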
28 changes: 28 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.cc
@@ -0,0 +1,28 @@
#include "cuda_kernel.hh"

namespace refactor::kernel {
    using K = AttentionCuda;

    K::AttentionCuda(decltype(info) info_) noexcept
        : Kernel(), info(info_) {}

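    // Built without CUDA this kernel can never run, so build() signals
    // "unavailable" by returning an empty KernelBox.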
    auto K::build(decltype(info) info) noexcept -> KernelBox {
#ifndef USE_CUDA
        return nullptr;
#endif

        return std::make_unique<K>(info);
    }
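    // The address of the function-local static below is unique per kernel
    // class, giving a process-wide type id without RTTI (a reading of the
    // pattern, not a comment from the commit).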
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing multi-head attention on NVIDIA GPU";
    }

}// namespace refactor::kernel
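
Callers never need to branch on the build configuration: the same typeId()/build() pair works either way. A hedged illustration of matching a kernel box back to this class, assuming only the declarations above:

// Each kernel class returns the address of its own static, so ids from
// different classes cannot collide within one process.
bool isAttentionCuda = box && box->kernelTypeId() == AttentionCuda::typeId();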
30 changes: 30 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.hh
@@ -0,0 +1,30 @@
#ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
#define KERNEL_ATTENTION_CUDA_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

    struct AttentionCuda final : public Kernel {
        struct {
            DataType dataType;
            dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
            bool resetCache;
        } info;

        AttentionCuda(decltype(info)) noexcept;

        static KernelBox build(decltype(info)) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        RoutineWorkspace lower(Resources &) const final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_CUDA_KERNEL_HH
4 changes: 2 additions & 2 deletions src/08-01llm/README.md
@@ -26,7 +26,7 @@ y = (x^2 + δ)^(-1/2) * w * x

 - **Y(heterogeneous) - T**: the output tensor. Same shape as `X`.
 
-## Attention
+## Attention (with Causal Decoder Mask)
 
 ### Summary

@@ -36,7 +36,7 @@ An encapsulation of Multi-head Self Attention for transformer models.

 | No. | inputs | `max_seq_len` | kv cache use | outputs | cache s dim | notes
 |:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:-
-| 1 | 3 | 0 | none | 1 | - |
+| 1 | 3 | 0 | none | 1 | - | -
 | 2 | 3 | S > 0 | init | 3 | `S` | `assert(S >= seq_len)`
 | 3 | 4 | 0 | inplace | 3 | `past_seq_len + seq_len` | `past_seq_len` must be a constant
 | 4 | 4 | S > 0 | inplace | 3 | `S` | `assert(S >= past_seq_len + seq_len)`
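
Taken together, the four rows fix the cache's s dimension from the input count and max_seq_len alone. A compact sketch of that rule using the table's own names (illustrative C++, not code from this repository):

#include <cassert>

// Returns the s dimension of the emitted kv cache, or 0 for row 1 (no cache).
// numInputs is 3 or 4; maxSeqLen is the S column of the table above.
int cacheSDim(int numInputs, int maxSeqLen, int pastSeqLen, int seqLen) {
    if (numInputs == 3) {
        if (maxSeqLen == 0) return 0;            // row 1: no cache emitted
        assert(maxSeqLen >= seqLen);             // row 2: cache initialized with S slots
        return maxSeqLen;
    }
    if (maxSeqLen == 0)                          // row 3: inplace cache grows each step
        return pastSeqLen + seqLen;
    assert(maxSeqLen >= pastSeqLen + seqLen);    // row 4: inplace cache fixed at S slots
    return maxSeqLen;
}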
