
Commit

[tests] enable BNB test cases in tests/test_quantization.py on XPU (#3349)

* enable bnb tests

* bug fix

* fix quality issue

* further fix quality

* fix style
faaany authored Jan 17, 2025
1 parent 02d2561 commit 7e32410
Showing 5 changed files with 43 additions and 21 deletions.
13 changes: 12 additions & 1 deletion src/accelerate/hooks.py
@@ -26,6 +26,11 @@
send_to_device,
set_module_tensor_to_device,
)
from .utils.imports import (
is_mlu_available,
is_npu_available,
is_xpu_available,
)
from .utils.memory import clear_device_cache
from .utils.modeling import get_non_persistent_buffers
from .utils.other import recursive_getattr
@@ -381,9 +386,15 @@ def post_forward(self, module, output):
# We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
# this dictionary to allow the garbage collector to do its job.
for value_pointer, device in self.tied_pointers_to_remove:
if isinstance(device, int):
if is_npu_available():
device = f"npu:{device}"
elif is_mlu_available():
device = f"mlu:{device}"
elif is_xpu_available():
device = f"xpu:{device}"
del self.tied_params_map[value_pointer][device]
self.tied_pointers_to_remove = set()

if self.io_same_device and self.input_device is not None:
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

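The hunk above normalizes integer device indices before the tied-parameter entries are dropped: on NPU, MLU, and XPU the keys in tied_params_map are backend-prefixed strings, so a bare int index would miss the entry. A minimal standalone sketch of that normalization follows; the helper name and its boolean flags are illustrative, not part of accelerate:

def normalize_device_key(device, npu_available=False, mlu_available=False, xpu_available=False):
    """Rewrite an int device index as the backend-prefixed string key used on non-CUDA backends."""
    if isinstance(device, int):
        if npu_available:
            return f"npu:{device}"
        elif mlu_available:
            return f"mlu:{device}"
        elif xpu_available:
            return f"xpu:{device}"
    return device

# e.g. normalize_device_key(0, xpu_available=True) -> "xpu:0";
# non-integer keys such as "cpu" are returned unchanged.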
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -24,6 +24,7 @@
require_bnb,
require_cpu,
require_cuda,
require_cuda_or_xpu,
require_huggingface_suite,
require_mlu,
require_mps,
10 changes: 10 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -172,6 +172,16 @@ def require_xpu(test_case):
return unittest.skipUnless(is_xpu_available(), "test requires a XPU")(test_case)


def require_cuda_or_xpu(test_case):
"""
Decorator marking a test that requires CUDA or XPU. These tests are skipped when no GPU is available or when
TorchXLA is available.
"""
cuda_condition = is_cuda_available() and not is_torch_xla_available()
xpu_condition = is_xpu_available()
return unittest.skipUnless(cuda_condition or xpu_condition, "test requires a CUDA GPU or XPU")(test_case)


def require_non_xpu(test_case):
"""
Decorator marking a test that should be skipped for XPU.
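For reference, a test guarded by the new require_cuda_or_xpu decorator might look like the minimal sketch below; the test class and its body are illustrative only and not part of this commit:

import unittest

import torch

from accelerate.test_utils import require_cuda_or_xpu


@require_cuda_or_xpu
class DeviceSmokeTest(unittest.TestCase):
    def test_tensor_on_accelerator(self):
        # Runs on whichever backend is present; skipped if neither CUDA nor XPU is available.
        device = "cuda" if torch.cuda.is_available() else "xpu"
        x = torch.ones(2, 2, device=device)
        assert x.device.type == device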
6 changes: 4 additions & 2 deletions src/accelerate/utils/bnb.py
@@ -149,11 +149,13 @@ def load_and_quantize_model(
torch.cuda.empty_cache()
elif torch.cuda.is_available():
model.to(torch.cuda.current_device())
elif torch.xpu.is_available():
model.to(torch.xpu.current_device())
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
logger.info(
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
"We move the model to cuda."
f"The model device type is {model_device.type}. However, gpu is needed for quantization."
"We move the model to gpu."
)
return model

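In the hunk above, the two logger strings that mention "cuda" are the removed lines and the two that mention "gpu" are their replacements; the new elif adds the XPU fallback. A self-contained sketch of the resulting device selection is shown below; the helper name is illustrative, and the hasattr guard is an assumption for torch builds that do not expose an xpu module:

import torch


def pick_quantization_device():
    # Prefer CUDA, fall back to XPU, otherwise refuse to quantize.
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.current_device()
    raise RuntimeError("No GPU found. A GPU is needed for quantization.")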
34 changes: 16 additions & 18 deletions tests/test_quantization.py
@@ -22,14 +22,15 @@
from accelerate import Accelerator, init_empty_weights
from accelerate.test_utils import (
require_bnb,
require_cuda,
require_cuda_or_xpu,
require_huggingface_suite,
require_multi_gpu,
require_multi_device,
require_non_torch_xla,
slow,
)
from accelerate.utils.bnb import load_and_quantize_model
from accelerate.utils.dataclasses import BnbQuantizationConfig
from accelerate.utils.memory import clear_device_cache


class BitsAndBytesConfigIntegration(unittest.TestCase):
@@ -40,7 +41,7 @@ def test_BnbQuantizationConfig(self):

@require_non_torch_xla
@slow
@require_cuda
@require_cuda_or_xpu
@require_bnb
@require_huggingface_suite
class MixedInt8EmptyModelTest(unittest.TestCase):
@@ -97,8 +98,7 @@ def tearDown(self):
del self.model_fp16
del self.model_8bit

gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)

def test_memory_footprint(self):
r"""
@@ -198,7 +198,7 @@ def test_fp32_8bit_conversion(self):
)
assert model.lm_head.weight.dtype == torch.float32

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_loading_custom_device_map(self):
from bitsandbytes.nn import Int8Params
from transformers import AutoConfig, AutoModelForCausalLM
@@ -253,7 +253,7 @@ def test_cpu_gpu_loading_custom_device_map(self):
assert model_8bit.transformer.h[1].mlp.dense_4h_to_h.weight.__class__ == Int8Params
self.check_inference_correctness(model_8bit)

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_loading_custom_device_map_offload_state_dict(self):
from bitsandbytes.nn import Int8Params
from transformers import AutoConfig, AutoModelForCausalLM
@@ -310,7 +310,7 @@ def test_cpu_gpu_loading_custom_device_map_offload_state_dict(self):
assert model_8bit.transformer.h[1].mlp.dense_4h_to_h.weight.__class__ == Int8Params
self.check_inference_correctness(model_8bit)

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
from bitsandbytes.nn import Int8Params
from transformers import AutoConfig, AutoModelForCausalLM
@@ -401,12 +401,11 @@ def test_int8_serialization(self):

self.check_inference_correctness(model_8bit_from_saved)

@require_multi_gpu
@require_multi_device
def test_int8_serialization_offload(self):
r"""
Test whether it is possible to serialize a model in 8-bit and offload weights to cpu/disk
"""

from bitsandbytes.nn import Int8Params
from transformers import AutoConfig, AutoModelForCausalLM

@@ -499,7 +498,7 @@ def test_int8_serialization_shard(self):

@require_non_torch_xla
@slow
@require_cuda
@require_cuda_or_xpu
@require_bnb
@require_huggingface_suite
class MixedInt8LoaddedModelTest(unittest.TestCase):
@@ -605,7 +604,7 @@ def test_fp32_8bit_conversion(self):

@require_non_torch_xla
@slow
@require_cuda
@require_cuda_or_xpu
@require_bnb
@require_huggingface_suite
class Bnb4BitEmptyModelTest(unittest.TestCase):
@@ -736,7 +735,7 @@ def test_fp32_4bit_conversion(self):
)
assert model.lm_head.weight.dtype == torch.float32

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_loading_random_device_map(self):
from transformers import AutoConfig, AutoModelForCausalLM

@@ -789,7 +788,7 @@ def test_cpu_gpu_loading_random_device_map(self):
)
self.check_inference_correctness(model_4bit)

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_loading_custom_device_map(self):
from transformers import AutoConfig, AutoModelForCausalLM

@@ -819,7 +818,7 @@ def test_cpu_gpu_loading_custom_device_map(self):
)
self.check_inference_correctness(model_4bit)

@require_multi_gpu
@require_multi_device
def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
from transformers import AutoConfig, AutoModelForCausalLM

@@ -855,7 +854,7 @@ def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):

@require_non_torch_xla
@slow
@require_cuda
@require_cuda_or_xpu
@require_bnb
@require_huggingface_suite
class Bnb4BitTestLoadedModel(unittest.TestCase):
@@ -904,8 +903,7 @@ def tearDown(self):
del self.model_fp16
del self.model_4bit

gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)

def test_memory_footprint(self):
r"""
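Taken together, the hunks in this file swap the CUDA-only decorators for their device-agnostic counterparts and replace the manual gc.collect() / torch.cuda.empty_cache() pair with clear_device_cache. A condensed sketch of the resulting pattern, assembled from the hunks above (setUp and the actual test bodies are elided):

import unittest

from accelerate.test_utils import (
    require_bnb,
    require_cuda_or_xpu,
    require_huggingface_suite,
    require_multi_device,
    require_non_torch_xla,
    slow,
)
from accelerate.utils.memory import clear_device_cache


@require_non_torch_xla
@slow
@require_cuda_or_xpu
@require_bnb
@require_huggingface_suite
class MixedInt8EmptyModelTest(unittest.TestCase):
    def tearDown(self):
        del self.model_fp16
        del self.model_8bit
        # Works on CUDA, XPU, and the other supported backends.
        clear_device_cache(garbage_collection=True)

    @require_multi_device
    def test_cpu_gpu_loading_custom_device_map(self):
        ...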
