Bug: Quantization Uses Excessive Buffer with MPS Backend Leading to OOM #344

Open
SanderGi opened this issue Jan 22, 2025 · 0 comments

The Problem

Computing the statistics for QAT on an M4 Max leads to the following error message:

  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1595, in _call_impl
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/ai8x.py", line 606, in _hist_hook
    hist = histogram(output.clone().detach().flatten(), bins=2048)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/ai8x.py", line 543, in histogram
    counts = torch.histc(inp, bins, min=minimum, max=maximum).cpu()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid buffer size: 1568.00 GB

Very inefficient copying of buffers between the CPU and MPS backends causes a 1568.00 GB buffer to be requested, which fails because it exceeds available memory. The copying appears to come from incomplete MPS operator implementations that fall back to the CPU.
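
A minimal, hypothetical reproducer of the pattern (not taken from ai8x-training; the hist_counts helper and tensor size are invented for illustration, and the exact failure mode depends on the PyTorch version and MPS fallback configuration):

  import torch

  def hist_counts(t: torch.Tensor, bins: int = 2048) -> torch.Tensor:
      # Same shape of computation as ai8x.histogram() in the traceback above:
      # histc over the flattened activations, with the counts pulled back to the CPU.
      return torch.histc(t, bins=bins, min=t.min().item(), max=t.max().item()).cpu()

  if torch.backends.mps.is_available():
      x = torch.randn(1_000_000, device="mps")
      try:
          hist_counts(x)          # on some torch builds this triggers the oversized buffer request
      except RuntimeError as err:
          print("MPS path failed:", err)
      hist_counts(x.cpu())        # moving the tensor to the CPU first sidesteps the MPS buffer issue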

The Solution

If we patch line 606 of ai8x.py to move the input to the CPU before computing the histogram, we avoid the inefficient buffer copying and no longer run into the OOM error: hist = histogram(output.clone().detach().flatten().cpu(), bins=2048).
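
For reference, a sketch of the patched line in context (the comments paraphrase the hook shown in the traceback; this is not an exact excerpt of ai8x.py):

  # ai8x.py, inside the statistics hook (around line 606)
  # before: hist = histogram(output.clone().detach().flatten(), bins=2048)
  # after:  move the activations to the CPU first so the MPS backend never has to
  #         allocate or copy an oversized intermediate buffer
  hist = histogram(output.clone().detach().flatten().cpu(), bins=2048)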

Let me know if it would be okay to submit a PR with this fix or if it has some implications I have not considered.

Detailed Reproduction Instructions

Tested on Apple M4 Max with 36GB RAM running Sequoia 15.1 on January 22, 2025.

  • Make sure Homebrew is installed, then brew install pyenv
  • Make sure no virtual environments are active; if any are, deactivate them
  • mkdir MAX78000 && cd MAX78000
  • brew install libomp libsndfile tcl-tk sox
  • Setup training
    • git clone --recursive https://github.com/analogdevicesinc/ai8x-training.git
    • cd ai8x-training
    • Create a virtual environment named ai8x-training with Python 3.11.8:
      • Install Python 3.11.8: pyenv install 3.11.8
      • Verify the Python version with python --version and, if it is not picked up automatically, set it with pyenv local 3.11.8
      • Create the virtual environment: python -m venv ./venv --prompt ai8x-training && echo "*" > venv/.gitignore
      • Activate the virtual environment: . ./venv/bin/activate
    • pip install -U pip wheel setuptools && pip install -r requirements.txt (a quick sanity check of the resulting torch/MPS setup is sketched after this list)
  • Run a test QAT training run with the default policies (policies/schedule.yaml and policies/qat_policy.yaml) and observe the error. If you are short on time, edit policies/qat_policy.yaml to start QAT initialization at epoch 1 (epoch 0 will not work, since train.py does not allow QAT from epoch 0)
  • Patch line 606 of ai8x.py by replacing it with hist = histogram(output.clone().detach().flatten().cpu(), bins=2048). The extra call to .cpu() is necessary with the MPS backend used on Apple Silicon.
  • Run the same test QAT training run and observe the error no longer occurs.
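
As an optional sanity check of the environment (not part of the repository), the following snippet confirms that the installed PyTorch build exposes the MPS backend before starting the training run:

  import torch

  print(torch.__version__)
  print(torch.backends.mps.is_built())      # True if this torch build includes MPS support
  print(torch.backends.mps.is_available())  # True on Apple Silicon with a compatible macOS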