Add hip support #330

Merged 2 commits on Oct 4, 2024
5 changes: 4 additions & 1 deletion optimum/quanto/library/extensions/__init__.py
@@ -19,7 +19,10 @@


 if torch.cuda.is_available():
-    from .cuda import *
+    if torch.version.cuda:
+        from .cuda import *
+    elif torch.version.hip:
+        from .hip import *
 
 if torch.backends.mps.is_available():
     from .mps import *
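For context (not part of the diff): torch.cuda.is_available() also returns True on ROCm builds of PyTorch, where torch.version.hip is a version string and torch.version.cuda is None; that is what the new branch relies on. A minimal sketch of the distinction, with the helper name and version values being illustrative only:

import torch

def gpu_backend() -> str:
    # torch.cuda.is_available() is True for both CUDA and ROCm builds.
    if not torch.cuda.is_available():
        return "none"
    # CUDA wheels: torch.version.cuda is e.g. "12.1" and torch.version.hip is None.
    # ROCm wheels: torch.version.cuda is None and torch.version.hip is e.g. "6.0".
    return "cuda" if torch.version.cuda else "hip"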
36 changes: 36 additions & 0 deletions optimum/quanto/library/extensions/hip/__init__.py
@@ -0,0 +1,36 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from ..extension import Extension, register_extension


__all__ = []


ext = Extension(
    "quanto_hip",
    root_dir=os.path.dirname(__file__),
    sources=["unpack.cu", "pybind_module.cpp"],
    extra_cflags=["-std=c++17"],
)
register_extension(ext)


# ROCm builds of PyTorch expose HIP devices through the "cuda" device type,
# so the HIP implementation is registered under the CUDA dispatch key.
@torch.library.impl("quanto::unpack", ["CUDA"])
def unpack_hip(t: torch.Tensor, bits: int):
    return ext.lib.unpack(t, bits)
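A short usage sketch (assuming a ROCm build of PyTorch and that the quanto_hip extension compiled successfully; the tensor shape is arbitrary):

import torch

# On ROCm, HIP devices are addressed as "cuda" devices.
packed = torch.randint(0, 256, (64, 32), dtype=torch.uint8, device="cuda")
unpacked = torch.ops.quanto.unpack(packed, 4)
print(unpacked.shape)  # torch.Size([128, 32]): dim 0 doubled, two 4-bit values per byte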
21 changes: 21 additions & 0 deletions optimum/quanto/library/extensions/hip/pybind_module.cpp
@@ -0,0 +1,21 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>
#include "unpack.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("unpack", &unpack, "unpack");
}
97 changes: 97 additions & 0 deletions optimum/quanto/library/extensions/hip/unpack.cu
@@ -0,0 +1,97 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;


static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
    int n_packed = 8 / bits;
    auto output_shape = input.sizes().vec();
    output_shape[0] = output_shape[0] * n_packed;
    return torch::empty(output_shape, input.options());
}

// Each thread unpacks one byte into two 4-bit values: the low nibble goes to
// output[i] and the high nibble to output[i + n].
__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if(i>=n) return;

    output[i] = (input[i] & 0x0F);
    output[i + n] = (input[i] & 0xF0) >> 4;
}

static torch::Tensor unpack_4bit(const torch::Tensor& input){

    auto output = allocate_output(input, 4);

    const auto numel = input.numel();
    int blocks = cdiv(numel, BLOCK_SIZE);
    unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
        input.data_ptr<unsigned char>(),
        output.data_ptr<unsigned char>(),
        numel
    );

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Each thread unpacks one byte into four 2-bit values, lowest bit-pair first.
__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if(i>=n) return;

    output[i] = (input[i] & 0x03);
    output[i + n] = (input[i] & 0x0C) >> 2;
    output[i + n*2] = (input[i] & 0x30) >> 4;
    output[i + n*3] = (input[i] & 0xC0) >> 6;
}

static torch::Tensor unpack_2bit(const torch::Tensor& input){

    auto output = allocate_output(input, 2);

    const auto numel = input.numel();
    int blocks = cdiv(numel, BLOCK_SIZE);
    unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
        input.data_ptr<unsigned char>(),
        output.data_ptr<unsigned char>(),
        numel
    );

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

torch::Tensor unpack(torch::Tensor &t, int bits) {
    TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
    TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
    TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
    switch(bits) {
      case 4:
        return unpack_4bit(t);
      case 2:
        return unpack_2bit(t);
      default:
        throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
    }
}
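For reference, the kernels write the lowest bit-group of every byte into the first n output elements and each higher bit-group into the next block of n, so the first dimension of the output grows by a factor of 8/bits. A rough pure-PyTorch equivalent of that layout (a sketch for illustration, not code from this PR):

import torch

def unpack_reference(t: torch.Tensor, bits: int) -> torch.Tensor:
    # Mirrors the kernel layout: low bit-group first, stacked along dim 0.
    assert t.dtype == torch.uint8 and bits in (2, 4)
    mask = (1 << bits) - 1
    groups = [(t >> (bits * i)) & mask for i in range(8 // bits)]
    return torch.cat(groups, dim=0)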
17 changes: 17 additions & 0 deletions optimum/quanto/library/extensions/hip/unpack.h
@@ -0,0 +1,17 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>

torch::Tensor unpack(torch::Tensor &t, int bits);
4 changes: 2 additions & 2 deletions optimum/quanto/tensor/weights/qbits.py
@@ -101,15 +101,15 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_g
             and axis == 0
             and group_size == 128
             and len(size) == 2
-            and data.device.type == "cuda"
+            and (data.device.type == "cuda" and torch.version.cuda)
             and torch.cuda.get_device_capability(data.device)[0] >= 8
         ):
             if type(data) is PackedTensor:
                 data = data.unpack()
             return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
         if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
             if data.device.type == "cpu" or (
-                data.device.type == "cuda"
+                (data.device.type == "cuda" and torch.version.cuda)
                 and version.parse(torch.version.cuda).release >= (12, 1)
                 and torch.cuda.get_device_capability(data.device)[0] >= 8
             ):
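The extra torch.version.cuda check matters because ROCm devices also report device.type == "cuda" while torch.version.cuda is None, so without it the CUDA-only AWQ/TinyGemm paths would be selected and version.parse(torch.version.cuda) would fail. A small illustration of the guarded version check (assumed ROCm behaviour, values illustrative):

import torch
from packaging import version

# Guard before parsing: on ROCm, torch.version.cuda is None and parsing it would raise.
if torch.version.cuda:
    cuda_12_1_or_newer = version.parse(torch.version.cuda).release >= (12, 1)
else:
    cuda_12_1_or_newer = False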
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/qbytes.py
@@ -124,7 +124,7 @@ def create(
             and activation_qtype is None
             and scale.dtype in [torch.float16, torch.bfloat16]
             and len(size) == 2
-            and data.device.type == "cuda"
+            and (data.device.type == "cuda" and torch.version.cuda)
             and axis == 0
             and torch.cuda.get_device_capability(data.device)[0] >= 8
         ):
5 changes: 4 additions & 1 deletion test/library/test_extensions.py
@@ -6,7 +6,10 @@

extension_names = ["quanto_cpp"]
if torch.cuda.is_available():
extension_names.append("quanto_cuda")
if torch.version.cuda:
extension_names.append("quanto_cuda")
if torch.version.hip:
extension_names.append("quanto_hip")
if torch.backends.mps.is_available():
extension_names.append("quanto_mps")

Marlin FP8 packed tensor test file (path not shown):
@@ -18,6 +18,7 @@
 import torch
 from helpers import device_eq
 
+from optimum.quanto.library.extensions import is_extension_available
 from optimum.quanto.tensor.weights.marlin.fp8 import MarlinF8PackedTensor


@@ -36,7 +37,7 @@ def get_fp8_tensor(shape, device, random=False):
     return t.view(torch.float8_e4m3fn).to(device)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not is_extension_available("quanto_cuda"), reason="CUDA extension is not available")
 @pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
 @pytest.mark.parametrize("random", [True, False])
@@ -50,7 +51,7 @@ def test_pack_marlin_fp8_tensor(in_features, out_features, random):
     assert torch.equal(t, packed.unpack())
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not is_extension_available("quanto_cuda"), reason="CUDA extension is not available")
 def test_move_marlin_fp8_tensor():
     shape = (256, 256)
     device = torch.device("cuda")
TinyGemm weight qbits tensor test file (path not shown):
@@ -27,6 +27,8 @@
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
def test_tinygemm_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
if device.type == "cuda":
if torch.version.hip:
pytest.skip(reason="TinyGemm not available for ROCm devices")
if version.parse(torch.version.cuda).release < (12, 1):
pytest.skip(reason="CUDA runtime must be at least 12.1")
if torch.cuda.get_device_capability()[0] < 8:
@@ -98,6 +100,8 @@ def test_tinygemm_weight_qbits_tensor_move(device):
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):
if device.type == "cuda":
if torch.version.hip:
pytest.skip(reason="TinyGemm not available for ROCm devices")
if version.parse(torch.version.cuda).release < (12, 1):
pytest.skip(reason="CUDA runtime must be at least 12.1")
if torch.cuda.get_device_capability()[0] < 8: