Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11
base: main
Conversation
benchmarking/benchmark_utils.py
Outdated
import pandas as pd

def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
any reason you are deleting this?
added back
import pytest
from kernels.fused_softmax import triton_softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
Thanks for adding tests!
FYI, this might not be ideal because we are not calling softmax from triton.ops like the other tests. I ran into issues with doing it that way.
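For reference, a minimal sketch of what such a parametrized test can check, assuming triton_softmax takes a 2-D CUDA tensor and should match torch.softmax over the last dimension (the test name and tolerances are illustrative, not part of this PR):

import pytest
import torch

from kernels.fused_softmax import triton_softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
def test_triton_softmax_matches_torch(input_size):
    # Assumes the kernel operates row-wise on a 2-D CUDA tensor.
    x = torch.randn(input_size, device="cuda")
    expected = torch.softmax(x, dim=-1)
    actual = triton_softmax(x, dim=-1)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)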
models/llama/llama/math_ops.py
Outdated
@@ -70,14 +71,15 @@ def attention(self, xq, keys, values, head_dim, mask):

     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:
It looks like you're trying to check the number of dimensions here, right? len(x) returns the size of the first dimension (equivalent to x.size(0)), not the number of dimensions. I think you want x.dim() or x.ndim.
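For illustration, a quick sketch of the difference (the example tensor here is made up):

import torch

x = torch.randn(4, 8)
print(len(x))     # 4  -> size of the first dimension, same as x.size(0)
print(x.numel())  # 32 -> total number of elements
print(x.ndim)     # 2  -> number of dimensions, same as x.dim()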
done
models/llama/llama/math_ops.py
Outdated
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:
+            return triton_softmax(x, dim=-1)
Why are we passing dim=-1 to these calls when we receive dim as an argument? Let's pass it through properly instead of overriding it. (Also, does the fused Triton kernel actually handle dim != -1 correctly?)
Currently it does not handle dim != -1. Looking into it (seeing how llama.cpp handles this); let me know if you have any pointers.
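One possible shape for the pass-through, assuming the fused kernel only supports softmax over the last dimension of a 2-D tensor (the fallback logic here is a sketch, not what the PR currently does):

@Profiler.profiling_decorator("softmax")
def softmax(self, x, dim):
    # Only use the fused Triton kernel in the case it supports:
    # a 2-D input reduced over its last dimension.
    if self.use_triton and x.ndim == 2 and dim in (-1, x.ndim - 1):
        return triton_softmax(x, dim=-1)
    # Otherwise fall back to PyTorch, passing the caller's dim
    # through instead of overriding it.
    return F.softmax(x, dim=dim)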
models/llama/llama/math_ops.py
Outdated
        else:
            return F.softmax(x, dim=-1)

    @Profiler.profiling_decorator("argmax")
    def argmax(self, x, dim):
        if self.use_triton:
            # TODO: change
Instead of adding a TODO to the code here, would you mind creating an issue to track it?
removed
- Rename certain functions to conform with the naming scheme
- The current Triton softmax does not handle more than 2 dimensions; this needs more investigation (probably by looking at how llama.cpp does it; see the sketch at the end of this description)
- Call dist.destroy_process_group() to remove a warning during benchmarking

Results from calling
python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path>
With no changes:
With just softmax:
With softmax and argmax:
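As a possible direction for the more-than-2-dimensions limitation noted above, one common workaround is to flatten the leading dimensions before calling the 2-D kernel and reshape afterwards. This is only a sketch (the helper name is made up, and it assumes softmax is taken over the last dimension), not necessarily how llama.cpp does it:

def softmax_lastdim(x, triton_softmax):
    # Collapse every leading dimension so the row-wise 2-D kernel can be reused,
    # then restore the original shape afterwards.
    original_shape = x.shape
    x2d = x.reshape(-1, original_shape[-1])
    out = triton_softmax(x2d, dim=-1)
    return out.reshape(original_shape)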