
Add options for configuring quantization in CachedCausalLM.from_pretrained() #26

Closed · gabegrand wants to merge 3 commits

Conversation

gabegrand (Collaborator) commented on Jan 22, 2025

Adds several kwargs to CachedCausalLM.from_pretrained() to make quantization more configurable. Preserves the default behavior of load_in_8bit=True.

Motivation: It turns out that NVIDIA Hopper removed int8 support (bitsandbytes-foundation/bitsandbytes#599) in favor of float8 quantization. This is an issue for running hfppl on H100 GPUs, which use the Hopper architecture. More generally, with the space of LLMs and quantization schemes evolving quickly, the existing CachedCausalLM.from_pretrained() should give the user more control over quantization configuration.

As a quick fix, this PR adds the ability to pass load_in_4bit as well as to specify a custom bnb_config. It's also separately useful to be able to pass a torch_dtype.
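As a rough illustration of how these options might be used (a sketch only: the exact signature, defaults, and the model id are assumptions based on this description, not the merged API):

```python
import torch
from transformers import BitsAndBytesConfig
from hfppl import CachedCausalLM

# Default behavior is preserved: 8-bit quantization via bitsandbytes.
lm_8bit = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# On Hopper GPUs (e.g. H100), fall back to 4-bit quantization instead.
lm_4bit = CachedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_8bit=False,
    load_in_4bit=True,
)

# Or supply a custom bitsandbytes config and an explicit dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
lm_custom = CachedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    bnb_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```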

Future steps: Certain Llama models now use torch.bfloat16. However, this dtype isn't supported by numpy, so it's currently incompatible with hfppl; there are several workarounds we should explore that extend numpy to support it.
EDIT: It turns out the only issue with bfloat16 arises when we try to store logprobs in the Trie without first converting them to a numpy-friendly format. I've added calls to .float() in two places, and these seem to be sufficient to support bfloat16 models.
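For context, here is a minimal stand-alone illustration of the bfloat16/numpy incompatibility and the cast that works around it (this is not the actual hfppl code, just the failure mode and fix):

```python
import torch

logprobs = torch.randn(32, dtype=torch.bfloat16)  # e.g. next-token logprobs

# logprobs.numpy() raises a TypeError, since numpy has no bfloat16 dtype.
# Casting to float32 first makes the conversion safe:
logprobs_np = logprobs.float().cpu().numpy()
```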

gabegrand requested a review from alex-lew on Jan 22, 2025 at 19:16
gabegrand (Collaborator, Author) commented:

On closer inspection, this PR seems to be largely subsumed by the changes in #24, so I'm going to close it for now in favor of Ben's PR. We should try to merge that one in soon.

@benlebrun What do you think about updating the Trie code in the GenLM backend to make sure it casts p_llm to float before calling .numpy()? This would ensure that models returning bfloat16 don't cause errors when we move the logprobs to CPU.
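A minimal sketch of what such a cast might look like (the helper name is hypothetical; p_llm is the tensor of model logprobs referenced above, and the surrounding Trie code is assumed):

```python
import torch

def logprobs_to_numpy(p_llm: torch.Tensor):
    # Cast to float32 before .numpy(), since numpy cannot represent bfloat16.
    return p_llm.float().cpu().numpy()
```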
