The Problem
Computing the statistics for QAT on an M4 Max leads to the following error message:
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1595, in _call_impl
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/ai8x.py", line 606, in _hist_hook
    hist = histogram(output.clone().detach().flatten(), bins=2048)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Desktop/CS/ML/YADES/MAX78000/ai8x-training/ai8x.py", line 543, in histogram
    counts = torch.histc(inp, bins, min=minimum, max=maximum).cpu()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid buffer size: 1568.00 GB
Some very inefficient copying of buffers between the CPU and MPS backends appears to be happening, which leads to a 1568.00 GB buffer being requested. The allocation fails because the machine runs out of memory. The copying seems to be caused by incomplete MPS implementations that fall back to the CPU.
The Solution
If we simply patch line 606 of ai8x.py to move the input to the CPU before computing the histogram, we avoid the inefficient copying of buffers and no longer run into the OOM error:
hist = histogram(output.clone().detach().flatten().cpu(), bins=2048)
Let me know if it would be okay to submit a PR with this fix or if it has some implications I have not considered.
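To make the idea behind the patch concrete, here is a minimal sketch of a histogram helper that moves the tensor to the CPU before calling torch.histc. The name histogram_on_cpu and its return values are hypothetical and are not the code in ai8x.py; the sketch only illustrates the "move to CPU first" step under those assumptions.

```python
import torch


def histogram_on_cpu(activations: torch.Tensor, bins: int = 2048):
    """Hypothetical helper (not the ai8x.py implementation).

    Moves the activations to the CPU *before* computing the histogram,
    so torch.histc never runs on the MPS device.
    """
    # detach() drops autograd tracking; cpu() is a no-op for CPU tensors.
    flat = activations.detach().flatten().cpu().float()
    lo, hi = flat.min().item(), flat.max().item()
    counts = torch.histc(flat, bins=bins, min=lo, max=hi)
    # Bin edges matching the counts: bins + 1 evenly spaced boundaries.
    edges = torch.linspace(lo, hi, bins + 1)
    return counts, edges


if __name__ == "__main__":
    counts, edges = histogram_on_cpu(torch.arange(8.0), bins=4)
    print(counts.device, int(counts.sum().item()), edges.numel())
```

The only design choice here is the order of operations: the transfer to the CPU happens once, up front, rather than being left to a backend fallback inside torch.histc.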
Detailed Reproduction Instructions
Tested on Apple M4 Max with 36GB RAM running Sequoia 15.1 on January 22, 2025.
Make sure Homebrew is installed, then run brew install pyenv.
Make sure no virtual environments are active; if any are, deactivate them.
Run a test QAT training run with the default policies (policies/schedule.yaml and policies/qat_policy.yaml) and notice the error. If you are short on time, change policies/qat_policy.yaml to start QAT initialization at epoch 1 (0 will not work since train.py does not allow QAT from epoch 0).
Patch line 606 of ai8x.py by replacing it with hist = histogram(output.clone().detach().flatten().cpu(), bins=2048). The extra call to .cpu() is necessary with the MPS backend used on Apple Silicon.
Run the same test QAT training run and observe the error no longer occurs.
Full setup command sequence, run before the test QAT training run:
brew install pyenv
deactivate
mkdir MAX78000 && cd MAX78000
brew install libomp libsndfile tcl-tk sox
git clone --recursive https://github.com/analogdevicesinc/ai8x-training.git
cd ai8x-training
For ai8x-training with Python 3.11.8:
pyenv install 3.11.8
Check python --version and set the version if it is not set automatically:
pyenv local 3.11.8
python -m venv ./venv --prompt ai8x-training && echo "*" > venv/.gitignore
. ./venv/bin/activate
pip install -U pip wheel setuptools && pip install -r requirements.txt
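For a quicker check than a full training run, the patched behaviour can be tried in isolation with a small forward hook. This is only a sketch under assumptions: the layer, the hook name hist_hook, and the collected list are made up for illustration and are not taken from ai8x.py. The hook has the same (module, args, output) shape as the one in the traceback and moves the output to the CPU before calling torch.histc.

```python
import torch
from torch import nn

# Use MPS when available (Apple Silicon), otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

collected = []  # hypothetical store for the per-call histograms


def hist_hook(module, args, output):
    # Move the activations to the CPU first, as in the proposed patch.
    flat = output.detach().flatten().cpu()
    hist = torch.histc(
        flat, bins=2048, min=flat.min().item(), max=flat.max().item()
    )
    collected.append(hist)


layer = nn.Linear(16, 32).to(device)
handle = layer.register_forward_hook(hist_hook)
layer(torch.randn(4, 16, device=device))  # triggers the hook once
handle.remove()

print(len(collected), collected[0].numel())
```

If this runs without the buffer size error on an MPS device, it suggests that the CPU transfer in the hook is what avoids the failing allocation, although it does not replace the full QAT run described above.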