
[misc] feat: support rmpad/data-packing in FSDP with transformers #91

Open · wants to merge 17 commits into main

Conversation

@PeterSH6 (Collaborator) commented Jan 10, 2025

  • Set actor_rollout_ref.model.use_rmpad=True, critic.model.use_rmpad=True, and reward_model.model.use_rmpad=True to enable rmpad for the respective models. Defaults to False.
  • Use AutoModelForTokenClassification for the value and reward models instead of SequenceClassification (see the sketch below).
  • Convert the log-prob computation to log_probs_from_logits_response_rmpad.

Resolve: #53
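
As a concrete illustration of the second bullet above, a minimal sketch of loading a value/reward model with a per-token scalar head; the checkpoint name and kwargs are illustrative, not the exact code in this PR:

```python
import torch
from transformers import AutoModelForTokenClassification

# With num_labels=1, the token-classification head emits one scalar per
# token, i.e. logits of shape (batch, seqlen, 1). That is what a value or
# reward model needs, and unlike SequenceClassification it stays
# well-defined on rmpad/packed inputs where sequence boundaries move.
critic_model = AutoModelForTokenClassification.from_pretrained(
    "deepseek-ai/deepseek-llm-7b-chat",        # illustrative checkpoint
    num_labels=1,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",   # required for the rmpad path
)
```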

Comparison using DeepSeek 7B on GSM8K: about 1.7x speedup compared to no rmpad (the original case).

[Two screenshots: timing comparison with and without rmpad]

@vermouth1992 (Collaborator)

Shall we add a supported model list and raise an error if the model is not in the list?
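
A sketch of how such a check might look; the set name, its contents, and the helper are illustrative, not existing verl code:

```python
# Hypothetical allowlist check: fail fast if rmpad is requested for an
# architecture that has not passed the rmpad CI.
_SUPPORTED_RMPAD_ARCHS = {
    "LlamaForCausalLM",
    "MistralForCausalLM",
    "Qwen2ForCausalLM",
    "GemmaForCausalLM",
}

def check_rmpad_supported(hf_config):
    archs = getattr(hf_config, "architectures", None) or []
    if not any(a in _SUPPORTED_RMPAD_ARCHS for a in archs):
        raise NotImplementedError(
            f"use_rmpad=True is not supported for {archs}; "
            f"supported architectures: {sorted(_SUPPORTED_RMPAD_ARCHS)}"
        )
```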

@vermouth1992 (Collaborator)

Try to avoid using log_probs_from_logits_response_rmpad because there is an unpad op inside, and unpad is a CUDA-blocking op. Instead, we can directly use the already-unpadded input_ids from the input.
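
For illustration, computing token log-probs directly on a packed (already-unpadded) sequence needs only a shift and a gather, with no unpad call. A minimal single-sequence sketch; the function name is hypothetical, and with multiple packed sequences the shift would additionally have to respect the cu_seqlens boundaries:

```python
import torch

def logprobs_from_packed_logits(logits_rmpad, input_ids_rmpad):
    """logits_rmpad:   (total_nnz, vocab) from a forward on packed tokens.
    input_ids_rmpad:   (total_nnz,) packed token ids.
    Position i predicts token i + 1, so shift by one and gather."""
    logp = torch.log_softmax(logits_rmpad[:-1].float(), dim=-1)
    labels = input_ids_rmpad[1:].unsqueeze(-1)
    return logp.gather(dim=-1, index=labels).squeeze(-1)  # (total_nnz - 1,)
```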

@PeterSH6 (Collaborator, Author)

I think this list depends on the transformers lib. Not sure where to get such a list; I didn't find any doc about this feature in transformers.

@vermouth1992 (Collaborator) commented Jan 10, 2025

> I think this list depends on the transformers lib. Not sure where to get such a list; I didn't find any doc about this feature in transformers.

Simply add potential models to the CI. If a model passes CI, then add it to the supported list. I guess we can target:

  • Llama
  • Mistral
  • Qwen
  • Gemma

@PeterSH6 (Collaborator, Author)

> Try to avoid using log_probs_from_logits_response_rmpad because there is an unpad op inside, and unpad is a CUDA-blocking op. Instead, we can directly use the already-unpadded input_ids from the input.

Sure, I will write a new API to unpad the input_ids.
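
A sketch of what such an API could look like on top of flash_attn's bert_padding helpers; the wrapper name is hypothetical:

```python
from flash_attn.bert_padding import unpad_input

def unpad_input_ids(input_ids, attention_mask):
    """Unpad (batch, seqlen) input_ids once at the input boundary so later
    stages can stay fully packed. unpad_input expects a trailing feature
    dim, hence the unsqueeze/squeeze; older flash_attn releases return 4
    values and newer ones 5, so the extra element is absorbed with *_."""
    ids_rmpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(
        input_ids.unsqueeze(-1), attention_mask
    )
    return ids_rmpad.squeeze(-1), indices, cu_seqlens, max_seqlen
```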

@PeterSH6 (Collaborator, Author)

> Simply add potential models to the CI. If a model passes CI, then add it to the supported list.

Shall we add test_transformers.py to CI? I didn't add it, as I think the result only depends on the transformers version and the flash_attn version.

So I guess the goal of the CI is to test whether the latest transformers + flash_attn would break our implementation.
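
One way such a CI check could look: run the same model padded and packed, and compare logits on the non-pad positions. This assumes a transformers version whose flash-attention path infers packed-sequence boundaries from restarting position_ids; the function name and tolerances are illustrative:

```python
import torch

def assert_rmpad_matches_padded(model, input_ids, attention_mask, position_ids):
    # Reference: ordinary padded forward.
    ref = model(input_ids, attention_mask=attention_mask,
                position_ids=position_ids).logits
    # Packed forward: flatten the non-pad tokens into a single row and keep
    # the per-sequence position_ids so flash-attn can recover boundaries.
    keep = attention_mask.bool()
    ids_rmpad = input_ids[keep].unsqueeze(0)     # (1, total_nnz)
    pos_rmpad = position_ids[keep].unsqueeze(0)  # (1, total_nnz)
    out = model(ids_rmpad, position_ids=pos_rmpad).logits.squeeze(0)
    torch.testing.assert_close(out, ref[keep], atol=2e-2, rtol=2e-2)
```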

@vermouth1992 (Collaborator)

> Shall we add test_transformers.py to CI? I didn't add it, as I think the result only depends on the transformers version and the flash_attn version.
>
> So I guess the goal of the CI is to test whether the latest transformers + flash_attn would break our implementation.

After this PR, we should set a minimum version of transformers.
