Merge remote-tracking branch 'origin/main' into apple-silicon-merge2

rickardp committed Jan 2, 2024
2 parents 84a6b78 + 947db7c commit d03d296
Showing 13 changed files with 46 additions and 35 deletions.
11 changes: 4 additions & 7 deletions Makefile.previous
@@ -8,9 +8,11 @@ ifeq ($(CUDA_HOME),)
endif

ifndef CUDA_VERSION
ifneq ($(MAKECMDGOALS),clean)
$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
CUDA_VERSION:=
endif
endif

NATIVE_ARCH:=$(shell (arch | sed -e s/arm64/aarch64/))
NATIVE_OS:=$(shell uname)
@@ -165,10 +167,5 @@ $(ROOT_DIR)/dependencies/cub:
cd dependencies/cub; git checkout 1.11.0

clean:
rm build/*

cleaneggs:
rm -rf *.egg*

cleanlibs:
rm ./bitsandbytes/libbitsandbytes*.so
rm -rf build/* *.egg*
rm -f bitsandbytes/libbitsandbytes*.so
8 changes: 4 additions & 4 deletions README.md
@@ -38,7 +38,7 @@ python setup.py install
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
'decapoda-research/llama-7b-hf,
'decapoda-research/llama-7b-hf',
device_map='auto',
load_in_8bit=True,
max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
@@ -80,7 +80,7 @@ out = linear(x.to(torch.float16))
Requirements: anaconda, cudatoolkit, pytorch

Hardware requirements:
- LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or older).
- LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or newer).
- 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X).

Supported CUDA versions: 10.2 - 12.0
@@ -102,7 +102,7 @@ For straight Int8 matrix multiplication with mixed precision decomposition you c
bnb.matmul(..., threshold=6.0)
```

For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://github.com/huggingface/transformers).
For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://huggingface.co/blog/hf-bitsandbytes-integration).

### Using the 8-bit Optimizers

@@ -119,7 +119,7 @@ torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP mo
```

Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so:
```
```python
# parameter tensors with less than 16384 values are optimized in 32-bit
# it is recommended to use multiples of 4096
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
6 changes: 3 additions & 3 deletions benchmarking/switchback/speed_benchmark.py
@@ -8,7 +8,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when the embedding dim is too large.

@@ -72,8 +72,8 @@ def get_time(k, fn, info_dict):
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
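The hunk above calls a `get_time(k, fn, info_dict)` helper whose definition is collapsed in this view. Purely as an illustration of what such a CUDA micro-benchmark helper usually looks like (an assumed shape, not the repository's code), it synchronizes the device around repeated calls and stores the average milliseconds under the key `k`:

```python
import time
import torch

def get_time(k, fn, info_dict, repeat=8):
    # Hypothetical stand-in for the collapsed helper: warm up once so kernel
    # compilation and caching do not skew the measurement, ...
    fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    # ... then wait for all queued kernels before stopping the clock.
    torch.cuda.synchronize()
    info_dict[k] = (time.time() - start) / repeat * 1000  # mean ms per call
```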
2 changes: 2 additions & 0 deletions bitsandbytes/__init__.py
@@ -24,4 +24,6 @@
"optim.optimizer.MockArgs": False,
}

__version__ = "0.41.3.post1"

PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
1 change: 1 addition & 0 deletions bitsandbytes/cuda_setup/env_vars.py
@@ -20,6 +20,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
"PATH", # this is for finding binaries, not libraries
"LESSOPEN", # related to the `less` command
"LESSCLOSE",
"GOOGLE_VM_CONFIG_LOCK_FILE", # Google Cloud stuff, contains root only paths
"_", # current Python interpreter
}
return env_var in ignorable
20 changes: 13 additions & 7 deletions bitsandbytes/cuda_setup/main.py
@@ -64,9 +64,10 @@ def generate_instructions(self):
self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a')
self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc')
self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.')
self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/install_cuda.sh')
self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash install_cuda.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash install_cuda.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')
self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh')
self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')

return

make_cmd = f'CUDA_VERSION={self.cuda_version_string}'
@@ -196,11 +197,13 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
try:
if path.exists():
existent_directories.add(path)
except PermissionError as pex:
# Handle the PermissionError first as it is a subtype of OSError
# https://docs.python.org/3/library/exceptions.html#exception-hierarchy
pass
except OSError as exc:
if exc.errno != errno.ENAMETOOLONG:
raise exc
except PermissionError as pex:
pass

non_existent_directories: Set[Path] = candidate_paths - existent_directories
if non_existent_directories:
@@ -214,8 +217,11 @@ def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
paths = set()
for libname in CUDA_RUNTIME_LIBS:
for path in candidate_paths:
if (path / libname).is_file():
paths.add(path / libname)
try:
if (path / libname).is_file():
paths.add(path / libname)
except PermissionError:
pass
return paths


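The reordering above matters because Python tries `except` clauses top to bottom and `PermissionError` is a subclass of `OSError`: with the `OSError` clause first, a permission failure (whose errno is not `ENAMETOOLONG`) was re-raised before the `PermissionError` handler could skip the path. A minimal self-contained sketch of the corrected ordering (illustrative only, not lifted from the repository):

```python
import errno
from pathlib import Path

def existing_dirs(candidates):
    """Keep candidate paths that exist, silently skipping unreadable ones."""
    found = set()
    for path in map(Path, candidates):
        try:
            if path.exists():
                found.add(path)
        except PermissionError:
            # Must come before the OSError clause: PermissionError is a
            # subclass of OSError, so a later handler would never be reached.
            pass
        except OSError as exc:
            # Over-long paths are tolerated; any other OS error is real.
            if exc.errno != errno.ENAMETOOLONG:
                raise
    return found

# Unreadable directories are skipped, missing ones simply do not match.
print(existing_dirs(["/usr/lib", "/nonexistent"]))
```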
2 changes: 1 addition & 1 deletion bitsandbytes/nn/__init__.py
@@ -2,5 +2,5 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
4 changes: 2 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -220,10 +220,10 @@ def set_compute_type(self, x):
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
warnings.filterwarnings('ignore', message='.*inference.')
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')

def _save_to_state_dict(self, destination, prefix, keep_vars):
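The corrected warnings refer to the 4-bit compute dtype. As a hedged example of how a caller avoids the slow float32 default (parameter names as exposed by recent bitsandbytes and transformers releases; verify against your installed versions), the compute dtype can simply be matched to the float16 inputs:

```python
import torch
import bitsandbytes as bnb

# Give the 4-bit layer a float16 compute dtype so float16 activations are not
# upcast to float32, the slow path the warning above describes.
linear = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.float16)

# When loading through transformers, the equivalent knob is bnb_4bit_compute_dtype:
# from transformers import BitsAndBytesConfig
# cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
```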
12 changes: 6 additions & 6 deletions bitsandbytes/nn/triton_based_modules.py
@@ -10,7 +10,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


class _switchback_global(torch.autograd.Function):
@@ -29,7 +29,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)

@@ -47,7 +47,7 @@ def backward(ctx, G_3D):
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
@@ -119,7 +119,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)

@@ -143,7 +143,7 @@ def backward(ctx, G_3D):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)

@@ -215,7 +215,7 @@ def forward(self, x):
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)

4 changes: 2 additions & 2 deletions bitsandbytes/triton/{int8_matmul_mixed_dequanitze.py → int8_matmul_mixed_dequantize.py}
@@ -2,7 +2,7 @@
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:

import triton
@@ -136,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
tl.atomic_add(C, acc, mask=mask)


def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
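Note the fallback pattern above: when Triton is unavailable the module still exports `int8_matmul_mixed_dequantize`, just as a stub returning `None`, so the import itself never fails. A short sketch of how calling code might guard on availability (import paths and the call signature as shown in this diff; the fp16 fallback is an assumption for illustration):

```python
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

def int8_mm_or_fallback(x_int8, w_int8, state_x, state_w, bias, x_fp16, w_fp16):
    """Use the Triton int8 kernel when present, else a plain fp16 matmul."""
    if is_triton_available():
        # Same call pattern as the benchmark and SwitchBack modules above.
        return int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x, state_w, bias)
    # Without Triton the imported symbol is only a stub returning None,
    # so fall back to an ordinary fp16 matmul (illustrative choice).
    out = x_fp16 @ w_fp16.t()
    return out + bias if bias is not None else out
```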
3 changes: 2 additions & 1 deletion how_to_use_nonpytorch_cuda.md
@@ -21,7 +21,8 @@ wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cud
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True

# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
bash install_cuda.sh 117 ~/local 1

bash cuda_install.sh 117 ~/local 1
```

## Setting the environmental variables BNB_CUDA_VERSION, and LD_LIBRARY_PATH
6 changes: 5 additions & 1 deletion setup.py
@@ -7,6 +7,9 @@

from setuptools import find_packages, setup
from setuptools.dist import Distribution
import bitsandbytes as bnb

VERSION = bnb.__version__

libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.*"))
libs = [os.path.basename(p) for p in libs]
@@ -23,13 +26,14 @@ def has_ext_modules(foo):

setup(
name=f"bitsandbytes",
version=f"0.41.3.post1",
version=VERSION,
author="Tim Dettmers",
author_email="[email protected]",
description="k-bit optimizers and matrix multiplication routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes",
install_requires=['scipy'],
packages=find_packages(),
package_data={"": libs},
long_description=read("README.md"),
2 changes: 1 addition & 1 deletion tests/test_optim.py
@@ -171,7 +171,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):

if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit
# but the paramters can diverge because they are 16-bit
# but the parameters can diverge because they are 16-bit
# the difference grows larger and larger with each update
# --> copy the state to keep weights close
p1.data = p1.data.to(p2.dtype).float()
