Skip to content

Commit

Permalink
fix flashattn dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
JegernOUTT committed Nov 1, 2023
1 parent c99e9cd commit 65efde0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y \
expect \
mpich \
libmpich-dev \
python3 python3-pip \
python3 python3-pip python3-packaging \
&& rm -rf /var/lib/{apt,dpkg,cache,log}

RUN echo "export PATH=/usr/local/cuda/bin:\$PATH" > /etc/profile.d/50-smc.sh
Expand Down Expand Up @@ -45,7 +45,7 @@ ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
COPY . /tmp/app
RUN pip install /tmp/app && rm -rf /tmp/app

ENV MAX_JOBS=8
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD="TRUE"
RUN git clone -b feat/alibi https://github.com/smallcloudai/flash-attention.git /tmp/flash-attention \
&& cd /tmp/flash-attention \
Expand Down
36 changes: 27 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@

from typing import List, Set

# Build-time configuration read from the environment:
#   SETUP_PACKAGE    — when set, restrict the build to that single package
#                      (validated against all_refact_packages further below).
#   INSTALL_OPTIONAL — "TRUE" (default) also installs optional dependencies
#                      (used by get_install_requires below).
setup_package = os.environ.get("SETUP_PACKAGE", None)
install_optional = os.environ.get("INSTALL_OPTIONAL", "TRUE")

# Setting some env variables to force flash-attention build from sources
# We can get rid of them when https://github.com/Dao-AILab/flash-attention/pull/540 is merged
os.environ["MAX_JOBS"] = "4"
os.environ["FLASH_ATTENTION_FORCE_BUILD"] = "TRUE"


@dataclass
class PyPackage:
requires: List[str] = field(default_factory=list)
optional: List[str] = field(default_factory=list)
requires_packages: List[str] = field(default_factory=list)
data: List[str] = field(default_factory=list)

Expand Down Expand Up @@ -42,8 +51,8 @@ class PyPackage:
requires=["aiohttp", "aiofiles", "cryptography", "fastapi==0.100.0", "giturlparse", "pydantic==1.10.13",
"starlette==0.27.0", "uvicorn", "uvloop", "python-multipart", "auto-gptq==0.4.2", "accelerate",
"termcolor", "torch", "transformers", "bitsandbytes", "safetensors", "peft", "triton",
"torchinfo", "mpi4py", "deepspeed==0.11.1",
"flash_attn @ git+https://github.com/smallcloudai/flash-attention@feat/alibi"],
"torchinfo", "mpi4py", "deepspeed==0.11.1"],
optional=["flash_attn @ git+https://github.com/smallcloudai/flash-attention@feat/alibi"],
requires_packages=["refact_scratchpads", "refact_scratchpads_no_gpu",
"known_models_db", "refact_data_pipeline"],
data=["webgui/static/*", "webgui/static/js/*", "webgui/static/components/modals/*", "watchdog/watchdog.d/*"]),
Expand All @@ -60,7 +69,21 @@ def find_required_packages(packages: Set[str]) -> Set[str]:
return packages


setup_package = os.environ.get("SETUP_PACKAGE", None)
def get_install_requires():
    """Build the flat ``install_requires`` list for setup().

    Collects the de-duplicated required dependencies of every selected
    package; when the INSTALL_OPTIONAL env var is "TRUE", the de-duplicated
    optional dependencies (e.g. flash-attn) are appended as well.
    """
    required = set()
    for package in setup_packages.values():
        required.update(package.requires)
    dependencies = list(required)

    if install_optional.upper() == "TRUE":
        optional = set()
        for package in setup_packages.values():
            optional.update(package.optional)
        dependencies.extend(optional)

    return dependencies


if setup_package is not None:
if setup_package not in all_refact_packages:
raise ValueError(f"Package {setup_package} not found in repo")
Expand All @@ -72,7 +95,6 @@ def find_required_packages(packages: Set[str]) -> Set[str]:
else:
setup_packages = all_refact_packages


setup(
name="refact-self-hosting",
version="1.1.0",
Expand All @@ -85,9 +107,5 @@ def find_required_packages(packages: Set[str]) -> Set[str]:
packages=find_packages(include=(
f"{name}*" for name in setup_packages
)),
install_requires=list({
required_package
for py_package in setup_packages.values()
for required_package in py_package.requires
}),
install_requires=get_install_requires(),
)

0 comments on commit 65efde0

Please sign in to comment.