diff --git a/optimum/quanto/library/extensions/__init__.py b/optimum/quanto/library/extensions/__init__.py
index 2089cef8..44979e5d 100644
--- a/optimum/quanto/library/extensions/__init__.py
+++ b/optimum/quanto/library/extensions/__init__.py
@@ -19,7 +19,10 @@
 
 
 if torch.cuda.is_available():
-    from .cuda import *
+    if torch.version.cuda:
+        from .cuda import *
+    elif torch.version.hip:
+        from .hip import *
 
 if torch.backends.mps.is_available():
     from .mps import *
diff --git a/optimum/quanto/library/extensions/hip/__init__.py b/optimum/quanto/library/extensions/hip/__init__.py
new file mode 100644
index 00000000..79836e3b
--- /dev/null
+++ b/optimum/quanto/library/extensions/hip/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+
+from ..extension import Extension, register_extension
+
+
+__all__ = []
+
+
+ext = Extension(
+    "quanto_hip",
+    root_dir=os.path.dirname(__file__),
+    sources=["unpack.cu", "pybind_module.cpp"],
+    extra_cflags=["-std=c++17"],
+)
+register_extension(ext)
+
+
+@torch.library.impl("quanto::unpack", ["CUDA"])
+def unpack_hip(t: torch.Tensor, bits: int):
+    return ext.lib.unpack(t, bits)
diff --git a/optimum/quanto/library/extensions/hip/pybind_module.cpp b/optimum/quanto/library/extensions/hip/pybind_module.cpp
new file mode 100644
index 00000000..0d0baea4
--- /dev/null
+++ b/optimum/quanto/library/extensions/hip/pybind_module.cpp
@@ -0,0 +1,21 @@
+// Copyright 2024 The HuggingFace Team. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/extension.h>
+#include "unpack.h"
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("unpack", &unpack, "unpack");
+}
diff --git a/optimum/quanto/library/extensions/hip/unpack.cu b/optimum/quanto/library/extensions/hip/unpack.cu
new file mode 100644
index 00000000..1309c833
--- /dev/null
+++ b/optimum/quanto/library/extensions/hip/unpack.cu
@@ -0,0 +1,97 @@
+// Copyright 2024 The HuggingFace Team. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <c10/cuda/CUDAException.h>
+
+inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
+#define BLOCK_SIZE 256
+
+using namespace at;
+
+
+static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
+    int n_packed = 8 / bits;
+    auto output_shape = input.sizes().vec();
+    output_shape[0] = output_shape[0] * n_packed;
+    return torch::empty(output_shape, input.options());
+}
+
+__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
+    int i = blockIdx.x*blockDim.x + threadIdx.x;
+    if (i>=n) return;
+
+    output[i] = (input[i] & 0x0F);
+    output[i + n] = (input[i] & 0xF0) >> 4;
+}
+
+static torch::Tensor unpack_4bit(const torch::Tensor& input){
+
+    auto output = allocate_output(input, 4);
+
+    const auto numel = input.numel();
+    int blocks = cdiv(numel, BLOCK_SIZE);
+    unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
+        input.data_ptr<unsigned char>(),
+        output.data_ptr<unsigned char>(),
+        numel
+    );
+
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    return output;
+}
+
+__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
+    int i = blockIdx.x*blockDim.x + threadIdx.x;
+    if (i>=n) return;
+
+    output[i] = (input[i] & 0x03);
+    output[i + n] = (input[i] & 0x0C) >> 2;
+    output[i + n*2] = (input[i] & 0x30) >> 4;
+    output[i + n*3] = (input[i] & 0xC0) >> 6;
+}
+
+static torch::Tensor unpack_2bit(const torch::Tensor& input){
+
+    auto output = allocate_output(input, 2);
+
+    const auto numel = input.numel();
+    int blocks = cdiv(numel, BLOCK_SIZE);
+    unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
+        input.data_ptr<unsigned char>(),
+        output.data_ptr<unsigned char>(),
+        numel
+    );
+
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    return output;
+}
+
+torch::Tensor unpack(torch::Tensor &t, int bits) {
+    TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
+    TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
+    TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
+    switch(bits) {
+      case 4:
+        return unpack_4bit(t);
+      case 2:
+        return unpack_2bit(t);
+      default:
+        throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
+    }
+}
diff --git a/optimum/quanto/library/extensions/hip/unpack.h b/optimum/quanto/library/extensions/hip/unpack.h
new file mode 100644
index 00000000..788024fa
--- /dev/null
+++ b/optimum/quanto/library/extensions/hip/unpack.h
@@ -0,0 +1,17 @@
+// Copyright 2024 The HuggingFace Team. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/extension.h>
+
+torch::Tensor unpack(torch::Tensor &t, int bits);
diff --git a/optimum/quanto/tensor/weights/qbits.py b/optimum/quanto/tensor/weights/qbits.py
index a63a906f..3afce3f5 100644
--- a/optimum/quanto/tensor/weights/qbits.py
+++ b/optimum/quanto/tensor/weights/qbits.py
@@ -101,7 +101,7 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_g
             and axis == 0
             and group_size == 128
             and len(size) == 2
-            and data.device.type == "cuda"
+            and (data.device.type == "cuda" and torch.version.cuda)
             and torch.cuda.get_device_capability(data.device)[0] >= 8
         ):
             if type(data) is PackedTensor:
@@ -109,7 +109,7 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_g
             return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
         if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
             if data.device.type == "cpu" or (
-                data.device.type == "cuda"
+                (data.device.type == "cuda" and torch.version.cuda)
                 and version.parse(torch.version.cuda).release >= (12, 1)
                 and torch.cuda.get_device_capability(data.device)[0] >= 8
             ):
diff --git a/optimum/quanto/tensor/weights/qbytes.py b/optimum/quanto/tensor/weights/qbytes.py
index d6b10355..6d316218 100644
--- a/optimum/quanto/tensor/weights/qbytes.py
+++ b/optimum/quanto/tensor/weights/qbytes.py
@@ -124,7 +124,7 @@ def create(
         and activation_qtype is None
         and scale.dtype in [torch.float16, torch.bfloat16]
         and len(size) == 2
-        and data.device.type == "cuda"
+        and (data.device.type == "cuda" and torch.version.cuda)
         and axis == 0
         and torch.cuda.get_device_capability(data.device)[0] >= 8
     ):
diff --git a/test/library/test_extensions.py b/test/library/test_extensions.py
index d4fee049..ab05f2aa 100644
--- a/test/library/test_extensions.py
+++ b/test/library/test_extensions.py
@@ -6,7 +6,10 @@
 extension_names = ["quanto_cpp"]
 
 if torch.cuda.is_available():
-    extension_names.append("quanto_cuda")
+    if torch.version.cuda:
+        extension_names.append("quanto_cuda")
+    if torch.version.hip:
+        extension_names.append("quanto_hip")
 
 if torch.backends.mps.is_available():
     extension_names.append("quanto_mps")
diff --git a/test/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py b/test/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py
index c7d3b6fc..b0448c26 100644
--- a/test/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py
+++ b/test/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py
@@ -27,6 +27,8 @@
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
 def test_tinygemm_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
     if device.type == "cuda":
+        if torch.version.hip:
+            pytest.skip(reason="TinyGemm not available for ROCm devices")
         if version.parse(torch.version.cuda).release < (12, 1):
             pytest.skip(reason="CUDA runtime must be at least 12.1")
         if torch.cuda.get_device_capability()[0] < 8:
@@ -98,6 +100,8 @@ def test_tinygemm_weight_qbits_tensor_move(device):
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
 def test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):
     if device.type == "cuda":
+        if torch.version.hip:
+            pytest.skip(reason="TinyGemm not available for ROCm devices")
         if version.parse(torch.version.cuda).release < (12, 1):
             pytest.skip(reason="CUDA runtime must be at least 12.1")
         if torch.cuda.get_device_capability()[0] < 8:
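
A quick round-trip check for the new HIP unpack kernel, as a minimal sketch rather than part of the patch: it assumes a ROCm build of torch (where torch.version.hip is set but tensors still use the "cuda" device type) and that importing optimum.quanto registers the quanto::unpack op, which the code above dispatches to quanto_hip. The packing mirrors the 4-bit kernel layout: each byte holds two 4-bit values, low nibble first.

import torch
import optimum.quanto  # importing the package registers the quanto ops, including quanto::unpack

if torch.cuda.is_available() and torch.version.hip:
    device = torch.device("cuda")  # ROCm devices still report the "cuda" device type
    # Pack two blocks of 4-bit values into one block of bytes, low nibble first,
    # matching what unpack_4bit_kernel expects.
    low = torch.randint(0, 16, (8, 4), dtype=torch.uint8, device=device)
    high = torch.randint(0, 16, (8, 4), dtype=torch.uint8, device=device)
    packed = (high << 4) | low

    unpacked = torch.ops.quanto.unpack(packed, 4)

    # The kernel doubles dim 0: low nibbles in the first half, high nibbles in the second.
    assert torch.equal(unpacked, torch.cat([low, high], dim=0))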