Remove overheads in library #328

Merged 3 commits on Oct 3, 2024
optimum/quanto/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.2.5dev"
+__version__ = "0.2.6dev"

 from .calibrate import *
 from .library import *
optimum/quanto/library/__init__.py (3 changes: 1 addition & 2 deletions)
@@ -13,7 +13,6 @@
 # limitations under the License.

 from .extensions import *
-from .ops import *
-from .python import *
 from .qbytes_mm import *
 from .quantize import *
+from .unpack import *
optimum/quanto/library/extensions/README.md (42 changes: 34 additions & 8 deletions)
@@ -1,23 +1,49 @@
 # Quanto library extensions

-This folder contains the implementations of all `quanto_ext::` operations.
-
-This namespace corresponds to the device-specifc optimized implementations of quanto operations.
+This folder contains device-specific `quanto::` operations.

 Implementations can be provided as part of:

 - the generic C++ pytorch extension under `cpp`,
 - the CUDA extension under `cuda`,
 - the Metal Performance Shader extension under `mps`.

-The operations are defined in `library/ops.py`.
-
-To provide an implementation for specific device types, use the following syntax:
+To provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax:

 ```python
-@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
+@torch.library.impl("quanto::unpack", ["CPU", "CUDA"])
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
-    return ext().unpack(t, bits)
+    return ext.unpack(packed, bits)
 ```

+To declare a new device-specific operation, you need to add it to the library:
+
+```python
+torch.library.define(
+    "quanto::gemm_f16i4",
+    "(Tensor input,"
+    " Tensor other,"
+    " Tensor other_scale,"
+    " Tensor other_shift,"
+    " int group_size)"
+    " -> Tensor",
+)
+```
+
-Please refer to each extension folder to see how to add the actual implementation.
+Then you can provide its implementation:
+
+```python
+@torch.library.impl("quanto::gemm_f16i4", ["CUDA"])
+def gemm_f16i4(
+    input: torch.Tensor,
+    other: torch.Tensor,
+    scales: torch.Tensor,
+    shift: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    ...
+```
+
+Please refer to each extension folder for examples.
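For orientation, an operation registered under `quanto::` is reachable through `torch.ops` like any built-in operator. A minimal sketch (not part of the diff), assuming `optimum.quanto` is installed so that importing it registers the ops:

```python
import torch
import optimum.quanto  # importing the package registers the quanto:: ops

# torch.ops resolves the operator by qualified name; the dispatcher then
# routes the call to the CPU, CUDA or MPS implementation, depending on the
# device of the input tensor.
packed = torch.tensor([0x21, 0x43], dtype=torch.uint8)  # two int4 values per byte
unpacked = torch.ops.quanto.unpack(packed, 4)  # uint8 tensor of unpacked values
```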
optimum/quanto/library/extensions/cpp/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -30,6 +30,6 @@
 )


-@torch.library.impl("quanto_ext::unpack", ["CPU"])
+@torch.library.impl("quanto::unpack", ["CPU"])
 def unpack_cpp(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
optimum/quanto/library/extensions/cuda/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def get_max_cuda_arch():
 )


-@torch.library.impl("quanto_ext::unpack", ["CUDA"])
+@torch.library.impl("quanto::unpack", ["CUDA"])
 def unpack_cuda(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
optimum/quanto/library/extensions/mps/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -30,6 +30,6 @@
 )


-@torch.library.impl("quanto_ext::unpack", "MPS")
+@torch.library.impl("quanto::unpack", "MPS")
 def unpack_mps(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
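All three backends register against the same schema, so the call site never changes; only the dispatch key does. Here is a self-contained sketch of the define/impl pattern these files use, built on a hypothetical `mylib::scale` operator (the namespace and op are illustrative, not part of quanto):

```python
import torch

# Declare the operator schema once.
torch.library.define("mylib::scale", "(Tensor self, float factor) -> Tensor")

# Register a fallback implementation, used when no device-specific one exists.
@torch.library.impl("mylib::scale", "default")
def scale_default(t: torch.Tensor, factor: float) -> torch.Tensor:
    return t * factor

# Register a CUDA-specific implementation; a real extension would call into
# a compiled kernel here (as the quanto extensions do with ext.lib).
@torch.library.impl("mylib::scale", "CUDA")
def scale_cuda(t: torch.Tensor, factor: float) -> torch.Tensor:
    return t * factor

x = torch.arange(4, dtype=torch.float32)  # a CPU tensor...
y = torch.ops.mylib.scale(x, 2.0)         # ...runs the default implementation
```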
optimum/quanto/library/ops.py (0 additions & 70 deletions)

This file was deleted.

optimum/quanto/library/python/README.md (0 additions & 18 deletions)

This file was deleted.

optimum/quanto/library/python/__init__.py (0 additions & 15 deletions)

This file was deleted.
optimum/quanto/library/quantize.py (17 changes: 14 additions & 3 deletions)
@@ -19,7 +19,12 @@
 from ..tensor import dtype_info, group


-@torch.library.custom_op("quanto::quantize_symmetric", mutates_args=())
+torch.library.define(
+    "quanto::quantize_symmetric", "(Tensor base, ScalarType dtype, int? axis, Tensor scale) -> Tensor"
+)
+
+
+@torch.library.impl("quanto::quantize_symmetric", "default")
 def quantize_symmetric(
     base: torch.Tensor, dtype: torch.dtype, axis: Union[int, None], scale: torch.Tensor
 ) -> torch.Tensor:
@@ -50,12 +55,18 @@ def quantize_symmetric(
     return torch.clamp(data, min=info.min, max=info.max).to(dtype)


-@torch.library.custom_op("quanto::quantize_affine", mutates_args=())
+torch.library.define(
+    "quanto::quantize_affine",
+    "(Tensor base, int bits, int axis, int? group_size, Tensor scale, Tensor shift) -> Tensor",
+)
+
+
+@torch.library.impl("quanto::quantize_affine", "default")
 def quantize_affine(
     base: torch.Tensor, bits: int, axis: int, group_size: Union[int, None], scale: torch.Tensor, shift: torch.Tensor
 ) -> torch.Tensor:
     if axis not in (0, -1):
-        raise ValueError("QBitsTensor axis parameter must be 0 (first axis) or -1 (last axis)")
+        raise ValueError("axis parameter must be 0 (first axis) or -1 (last axis)")
     if group_size is not None:
         base = group(base, axis=axis, group_size=group_size)
     if shift.dtype.is_floating_point:
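For context, `torch.library.custom_op` adds a wrapper layer around the function it decorates, so declaring the schema with `torch.library.define` and registering a plain `torch.library.impl` avoids some per-call cost; that appears to be the overhead the PR title refers to. A usage sketch (mine, not from the diff), assuming `optimum.quanto` is installed:

```python
import torch
import optimum.quanto  # registers quanto::quantize_symmetric

base = torch.randn(16, 16)
scale = base.abs().max() / 127  # 0-dim tensor: per-tensor scale for int8
# Schema: (Tensor base, ScalarType dtype, int? axis, Tensor scale) -> Tensor
qdata = torch.ops.quanto.quantize_symmetric(base, torch.int8, None, scale)
assert qdata.dtype == torch.int8
```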
@@ -15,7 +15,10 @@
 import torch


-@torch.library.impl("quanto_py::unpack", "default")
+torch.library.define("quanto::unpack", "(Tensor self, int bits) -> Tensor")
+
+
+@torch.library.impl("quanto::unpack", "default")
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
     """
     Un-Pack int4 / int2 weights (packed in a uint8) into a torch.uint8 tensor
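To make the docstring concrete, a pure-PyTorch reference for the 4-bit case (my own sketch; the actual `pack_weights` layout may order or reshape the nibbles differently):

```python
import torch

def unpack_int4_reference(packed: torch.Tensor) -> torch.Tensor:
    # Each uint8 holds two 4-bit values: mask the low nibble, shift down the
    # high nibble, then concatenate the halves.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.cat([low, high], dim=0)

packed = torch.tensor([0x21], dtype=torch.uint8)
print(unpack_int4_reference(packed))  # tensor([1, 2], dtype=torch.uint8)
```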
test/library/test_unpack.py (9 changes: 2 additions & 7 deletions)
@@ -12,24 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import nullcontext

 import pytest
 import torch

-from optimum.quanto.library import disable_extensions
 from optimum.quanto.tensor.packed import pack_weights


 @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
 @pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
-@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
-def test_unpack(bits, shape, use_ext, device):
+def test_unpack(bits, shape, device):
     qmax = 2**bits
     a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
     packed_a = pack_weights(a, bits)
-    context = nullcontext() if use_ext else disable_extensions()
-    with context:
-        unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
+    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
     assert unpacked_a.dtype == torch.uint8
     assert torch.equal(unpacked_a, a)