
Changes made in Unsloth and openInstruct to get a successful Online DPO run #1494

Open
pluesclues opened this issue Jan 2, 2025 · 6 comments


@pluesclues

pluesclues commented Jan 2, 2025

Alright, so as promised in the Unsloth Reddit post https://www.reddit.com/r/LocalLLaMA/comments/1hqkeyn/comment/m4rbtto/?utm_source=share&utm_medium=web3x&utm_name=web3xcss&utm_term=1&utm_content=share_button, I will highlight the changes I made in my fork of the AllenAI open-instruct repo (https://github.com/pluesclues/us_open-instruct) and in my fork of the Unsloth repo (https://github.com/pluesclues/unsloth/tree/main) to get things working. The changes were minimal overall; I tried to change as little code as possible so they would be easy to integrate. Let's start with the changes I made in Unsloth, since they are much simpler than the ones in the open-instruct repo. I am going to highlight three main things I had to focus on to make Unsloth compatible with the open-instruct repo.

DISCLAIMER: I GOT THIS WORKING MAINLY WITH THE LLAMA MODELS. THESE CHANGES CAN ALSO BE APPLIED TO THE OTHER MODELS (although I should write better code to do this).

DISCLAIMER 2: Apologies, the TL;DR Reddit dataset contains inappropriate text at times; I will try to censor it.

Let's start with the changes I made in Unsloth:

  1. https://github.com/pluesclues/unsloth/blob/main/unsloth/kernels/fast_lora.py

    In fast_lora.py, I actually addressed this issue:

    Lora downcasting issue #320

    It can be fixed by adding `with torch.amp.autocast('cuda', dtype=torch.bfloat16):` (or `torch.float16`, depending on your system) above all of the matrix multiplication computations to enable mixed precision. I do not know why it doesn't work when you instead wrap the backward call:

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        accelerator.backward(loss)
    

    The changes were applied to:

    LoRA_MLP:

    https://github.com/pluesclues/unsloth/blob/389b98f4860ab007f02af27258bd68d594749a66/unsloth/kernels/fast_lora.py#L116C8-L116C63

    LoRA_QKV:

    https://github.com/pluesclues/unsloth/blob/389b98f4860ab007f02af27258bd68d594749a66/unsloth/kernels/fast_lora.py#L275

    LoRA_W:

    https://github.com/pluesclues/unsloth/blob/389b98f4860ab007f02af27258bd68d594749a66/unsloth/kernels/fast_lora.py#L392

    That solves the LoRA downcasting issue, at least when you call `loss.backward()` on a custom loss computed with torch functions.

  2. The second change I want to highlight is quite simple, and it is in https://github.com/pluesclues/unsloth/blob/main/unsloth/models/llama.py

These lines https://github.com/pluesclues/unsloth/blob/4705906536f8aa1a10143a3cfa814ddd50f05bdc/unsloth/models/llama.py#L1507-L1539
preserve the original forward functions for the Llama models. This is because AutoModelForSequenceClassification is not compatible with Unsloth's overwritten forward functions, so the originals need to be kept in variables and set and reset whenever you calculate rewards during your RL updates, since your models must generate responses and get rewards during training. (I will highlight these changes when going into the AllenAI repo.)
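To illustrate the dtype mismatch behind the LoRA downcasting fix in item 1, here is a CPU-only sketch; it is not Unsloth's `fast_lora.py` kernel. A bfloat16 activation multiplied by float32 LoRA weights fails outside autocast and succeeds inside it. The helper name `lora_delta` and the shapes are hypothetical; on a GPU you would pass `device_type="cuda"` (and `torch.bfloat16` or `torch.float16`) as in the author's fix.

```python
import torch


def lora_delta(x, lora_A, lora_B, scaling, device_type="cpu", dtype=torch.bfloat16):
    """Compute the LoRA update (x @ A^T @ B^T) * scaling under autocast.

    Hypothetical helper: autocast casts both matmul inputs to `dtype`,
    so bf16 activations and fp32 LoRA weights can be multiplied.
    """
    with torch.autocast(device_type=device_type, dtype=dtype):
        return (x @ lora_A.t()) @ lora_B.t() * scaling


x = torch.randn(4, 16, dtype=torch.bfloat16)  # activations in bf16
lora_A = torch.randn(8, 16)                   # rank-8 LoRA weights in fp32
lora_B = torch.zeros(32, 8)

# Without autocast, mixing bf16 and fp32 in a matmul raises a RuntimeError.
try:
    x @ lora_A.t()
    mixed_ok = True
except RuntimeError:
    mixed_ok = False

out = lora_delta(x, lora_A, lora_B, scaling=2.0)
print(mixed_ok, out.dtype, tuple(out.shape))
```

The design point is that the cast happens at the matmul itself, which is where the author placed the `autocast` context in the LoRA kernels.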

OK, THAT'S ALL THE UNSLOTH CHANGES. Next are all of the changes made in AllenAI open-instruct, which will need to be transferred to TRL. We will start with the initialization of the models; it mostly stays the same, except for the reward model.

  3. The policy and reference policy are initialized the same way as in the Unsloth notebook. Since this initializes two tokenizers, you only really need one of them. https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/unsloth_online_dpo.py#L255-L315 You also have to set the tokenizer's padding side to right and add the pad token to the tokenizer.

I also had to use the set and reset functions from my Unsloth changes to initialize the reward model.

https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/unsloth_online_dpo.py#L347-L357
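The fork does this with explicit set and reset calls around reward-model use. A minimal alternative sketch, assuming you have kept references to the original `forward` functions, is a context manager that installs them and always restores the previous ones. `swapped_methods`, `ToyModel`, and `original_forward` are hypothetical stand-ins, not part of Unsloth or Transformers.

```python
from contextlib import contextmanager

_MISSING = object()  # marks "no attribute defined directly on the class"


@contextmanager
def swapped_methods(replacements):
    """Temporarily install methods on classes, restoring the previous ones on exit.

    replacements: dict mapping (cls, method_name) -> function to install.
    The finally block restores the previously installed (e.g. Unsloth-patched)
    methods even if the body raises.
    """
    saved = {(cls, name): cls.__dict__.get(name, _MISSING) for cls, name in replacements}
    try:
        for (cls, name), fn in replacements.items():
            setattr(cls, name, fn)
        yield
    finally:
        for (cls, name), previous in saved.items():
            if previous is _MISSING:
                delattr(cls, name)  # it was inherited, so drop our override
            else:
                setattr(cls, name, previous)


# Toy stand-ins: "patched" plays Unsloth's fast forward, "original" plays the
# Hugging Face forward that AutoModelForSequenceClassification expects.
class ToyModel:
    def forward(self, x):
        return "patched"


def original_forward(self, x):
    return "original"


model = ToyModel()
with swapped_methods({(ToyModel, "forward"): original_forward}):
    inside = model.forward(1)  # reward computation would happen here
after = model.forward(1)
```

Reading from `cls.__dict__` rather than `getattr` keeps descriptors such as `staticmethod` intact when they are restored.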

I also put the policy and reference policy into training mode before entering the loop.

https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/unsloth_online_dpo.py#L456-L457

4: OK, this is where the most important changes are, and they have to do with generation. I will try to highlight all of the linked functions and where the flow starts, which is in this file: https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/unsloth_online_dpo.py#L459-L467. TRL also unwraps the model for generation, and that function remains the same, so I am going to go over unsloth_batch_generation and its dependencies:

https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/model_utils.py#L473-L504

I call FastLanguageModel.for_inference(model) before generating from the batches and switch back with FastLanguageModel.for_training(model) after the function is done generating.
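A small sketch of the same toggle wrapped in a context manager, so training mode is restored even if generation raises. `inference_mode_for` is a hypothetical helper; the two callables are meant to be `FastLanguageModel.for_inference` and `FastLanguageModel.for_training`, but the check below uses recording lambdas so it runs without Unsloth.

```python
from contextlib import contextmanager


@contextmanager
def inference_mode_for(model, to_inference, to_training):
    """Switch `model` to inference for the body, then back to training.

    Hypothetical helper: the finally block guarantees to_training runs
    even when generation fails partway through a batch.
    """
    to_inference(model)
    try:
        yield model
    finally:
        to_training(model)


# Toy check with stand-in callables that record the mode switches.
calls = []
with inference_mode_for(
    "policy",
    lambda m: calls.append(("inference", m)),
    lambda m: calls.append(("training", m)),
):
    calls.append(("generate", "policy"))
```

With Unsloth installed, usage would look like `with inference_mode_for(model, FastLanguageModel.for_inference, FastLanguageModel.for_training): ...`.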

Now for the details of https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/model_utils.py#L447-L470. Notice that I only set max_new_tokens=53. If you set only min_new_tokens=53, you get weird responses, and if you set both to 53, the generation never produces an EOS token.
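Why setting both to 53 rules out EOS: `min_new_tokens` works by suppressing the EOS logit until that many tokens have been generated (Transformers implements this as a logits processor). The sketch below is a simplified illustration of that mechanism, not the library's code; `suppress_eos` and `eos_possible` are hypothetical names.

```python
import torch


def suppress_eos(logits, num_generated, min_new_tokens, eos_token_id):
    """Return logits with EOS masked out while num_generated < min_new_tokens."""
    if num_generated < min_new_tokens:
        logits = logits.clone()
        logits[:, eos_token_id] = float("-inf")
    return logits


def eos_possible(min_new_tokens, max_new_tokens):
    """EOS can be sampled at step t (t tokens already generated) only if
    t >= min_new_tokens, and steps run over t = 0 .. max_new_tokens - 1."""
    return any(t >= min_new_tokens for t in range(max_new_tokens))


logits = torch.zeros(2, 10)
masked = suppress_eos(logits, num_generated=5, min_new_tokens=53, eos_token_id=3)
probs = torch.softmax(masked, dim=-1)  # EOS probability is exactly 0 here
```

With `min_new_tokens == max_new_tokens`, every sampling step is still under the minimum, so `eos_possible(53, 53)` is `False`.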

One problem with setting only max_new_tokens=53 is that the Unsloth model pads everything after the first EOS token with more EOS tokens. That by itself is fine, but for batch generation the query_responses lengths won't match when you call return torch.cat(query_responses, 0) to return the batch generations.

Also note that I should not have hard-coded 53 for max_new_tokens. What these lines https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/model_utils.py#L497-L502 do is: if the response was not 53 tokens long, pad it with more EOS tokens up to 53. (You can change 53 to whatever max_new_tokens your dataset needs; 53 is used for the TL;DR dataset https://huggingface.co/datasets/trl-internal-testing/tldr-preference-sft-trl-style.) This makes all the shapes the same, and the padding is consistent with how Unsloth's generate function pads.
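A parameterized sketch of that padding step. Unlike the hard-coded version, it takes `max_new_tokens` and the EOS id as arguments and reuses each chunk's dtype and device instead of `"cuda:0"`. The function name is hypothetical, and the usage assumes a 255-token context (308 - 53) together with the `--stop_token_id 128001` from the run command later in this thread.

```python
import torch


def pad_responses_with_eos(chunks, context_length, max_new_tokens, eos_token_id):
    """Right-pad each [batch, context_length + n] chunk with EOS up to
    context_length + max_new_tokens columns, then concatenate on dim 0."""
    target = context_length + max_new_tokens
    padded = []
    for chunk in chunks:
        missing = target - chunk.shape[1]
        if missing > 0:
            pad = torch.full(
                (chunk.shape[0], missing),
                eos_token_id,
                dtype=chunk.dtype,
                device=chunk.device,
            )
            chunk = torch.cat((chunk, pad), dim=1)
        padded.append(chunk)
    return torch.cat(padded, dim=0)


# Four generation chunks of 64 rows with different widths.
chunks = [torch.zeros(64, w, dtype=torch.long) for w in (308, 299, 304, 308)]
batch = pad_responses_with_eos(chunks, context_length=255, max_new_tokens=53,
                               eos_token_id=128001)
```

Without the padding, `torch.cat` on these chunks fails because their second dimensions differ.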

I'll explain more clearly what happens if you do not do this. Typically, when you generate for a batch split into 4 chunks, you will see the following shapes for the 4 chunks (assume they are all tensors):

[64, 308]
[64, 299]
[64, 304]
[64, 308]

To concatenate the batches they must have the same shape, so we pad with EOS tokens until the response is 53 tokens long. However, I am not sure whether Unsloth supports batch generation natively with this generate function, in which case this might not actually be a problem. It is also imperative that each response contains an EOS token, since that accounts for most of the reward given to a response; Online DPO will not work unless there is at least one EOS token in the generation.
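Since the EOS token matters this much, a cheap check is to count which generated responses actually contain one. This is a hypothetical helper, not part of either fork. It looks only at the response part, so EOS tokens in the prompt are ignored; assuming a chunk only stops early once every row has emitted EOS, the padding above never adds EOS to a row that lacked one.

```python
import torch


def responses_with_eos(query_responses, context_length, eos_token_id):
    """Boolean mask: True for rows whose generated part contains EOS."""
    responses = query_responses[:, context_length:]
    return (responses == eos_token_id).any(dim=1)


qr = torch.tensor([[5, 6, 7, 9, 9],
                   [5, 6, 7, 8, 1]])
mask = responses_with_eos(qr, context_length=2, eos_token_id=9)
```

Logging `mask.float().mean()` during training is one way to notice early if the policy stops emitting EOS.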

I used this generation function:

import torch

@torch.no_grad()
def unsloth_generate_text(model, queries, tokenizer, pad_token_id, generation_config):
    # NOTE: max_length, context_length, and tokenizer are not used below;
    # max_new_tokens is hard-coded to 53 for the TL;DR dataset.
    max_length = generation_config.max_length

    # Get the context length from the input queries
    context_length = queries.shape[1]

    # Mask out padding and replace pad ids with 0 before passing them to the model
    attention_mask = queries != pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    # Sample output sequences (prompt + response)
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=53,
        use_cache=True,
        do_sample=True,
        top_k=0,
        top_p=1.0,
        temperature=0.7,
    )

    return outputs

Example with both max_new_tokens and min_new_tokens set to 53 (no EOS token):

TL;DR: I have to cut contact with my ex's friends or it'll hurt me. How to do it without hurting them? Is it the right thing to do? Is it healthy? Or am I being a b***? Thanks! :) :) :)

Example with only min_new_tokens set to 53:

I feel like I have to break contact with these girls because I'm not sure if I want to keep up the friendship. But also because I don't want to hurt them. Will they accept or will it hurt? How do I make it work? Thanks for reading. :) -f/22. :D ^_^ :D<|end_of_text|>

Example with only max_new_tokens set to 53 (apologies, this is not generated from the same prompt; this is what happens during training loops, but either way it does not change the logic I implemented):

\nTL;DR: I don't like Halloween and I don't allow my son to trick-or-treat, but everyone else insists that I'm forcing him to miss out on something and I don't feel like I'm doing anything wrong.<|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|><|end_of_text|>'

  5. OK, this is the last change. Typically in the RL trainers in TRL, to save time when computing the KL divergence between policies, the logits (`logitss`) are returned from the generate function. The problem is that, given the same prompt, Unsloth produces different logits during generation than in a forward pass, due to 4-bit precision I believe. I talked with @danielhanchen about this in this Reddit post: https://www.reddit.com/r/unsloth/comments/1f90cgo/generation_instability_between_the_forward_probs/. This is fixed by just using the forward function and not storing the output logits from generation:

https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/unsloth_online_dpo.py#L473-L476

This keeps the KL stable and makes it actually start at 0, since the ref_policy and policy should be identical at the start of the DPO run.
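A sketch of why forward-pass logits give a KL of exactly 0 at the start: if the policy and reference produce identical logits for the same `query_responses`, their per-token log-probs match and the log-ratio vanishes. The slicing mirrors the forward-logits snippet in the follow-up comment (`logits[:, context_length - 1 : -1]`, divided by `temperature + 1e-7`). `response_logprobs` is a hypothetical name, and the summed log-ratio is one simple sampled KL estimate, assumed here for illustration.

```python
import torch


def response_logprobs(logits, query_responses, context_length, temperature):
    """Per-token log-probs of the generated tokens from a forward pass.

    logits: [batch, seq_len, vocab] from model(query_responses).
    The logit at position t predicts token t + 1, hence the shift by one.
    """
    logits = logits[:, context_length - 1 : -1] / (temperature + 1e-7)
    responses = query_responses[:, context_length:]
    logprobs = torch.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)


torch.manual_seed(0)
batch, seq_len, vocab, context_length = 2, 8, 11, 5
query_responses = torch.randint(0, vocab, (batch, seq_len))
logits = torch.randn(batch, seq_len, vocab)  # stand-in for model(...).logits

policy_lp = response_logprobs(logits, query_responses, context_length, 0.7)
ref_lp = response_logprobs(logits.clone(), query_responses, context_length, 0.7)
kl = (policy_lp - ref_lp).sum(dim=1)  # per-sequence log-ratio estimate
```

If the policy's log-probs instead came from generation-time logits that differ from the reference's forward pass, `kl` would be nonzero before any training step.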

I tried my best to highlight all the changes; please let me know if anything is confusing. I will write about where to put the open-instruct changes into TRL in a comment below. I look forward to getting this integrated into Unsloth as soon as possible and possibly making a notebook for it.

@shimmyshimmer
Collaborator

Thank you so much for this we'll take a look and review it :)

@pluesclues
Author

pluesclues commented Jan 2, 2025

Apologies, I should also mention this change in the get_reward function: https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/model_utils.py#L187-L216. I also use the set and reset functions from here https://github.com/pluesclues/unsloth/blob/389b98f4860ab007f02af27258bd68d594749a66/unsloth/models/llama.py#L1507-L1535 so that the model properly computes a reward during the forward pass, and then I reset the function pointers to Unsloth's functions after it has used the AutoModelForSequenceClassification functions.

@pluesclues
Author

OK, so I will say where some changes should be made, based on the file I edited in the AllenAI open-instruct repo, going through them in the order I listed in the issue above. The changes I made in Unsloth do not need to be transferred into TRL, but everything from 3-5 in the OP does. I will also discuss what could potentially be implemented natively in Unsloth. I may try to implement these changes myself soon and get a successful run (apologies, the semester is starting soon and I have quite a few other obligations).

1: This is more of a general note for anyone reading this besides @shimmyshimmer and @danielhanchen: it's important that you train your own reward model, preferably on pairwise preferences (DPO style) like the Hugging Face RM trainer already does, and check that its classification accuracy is good. You should also SFT your model on a dataset and check that it produces EOS tokens correctly and gives responses similar to the SFT dataset's answers for the same prompts. Online DPO will not work with arbitrary policies and reward models.
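For the reward-model check, "classification accuracy" on preference pairs is usually the fraction of pairs where the chosen response scores higher than the rejected one. The sketch below computes that and the standard pairwise (Bradley-Terry) loss, -log sigmoid(r_chosen - r_rejected). The function names and reward values are hypothetical, and treating this loss as what "DPO style" means here is an assumption.

```python
import torch
import torch.nn.functional as F


def preference_accuracy(chosen_rewards, rejected_rewards):
    """Fraction of pairs where the reward model scores chosen above rejected."""
    return (chosen_rewards > rejected_rewards).float().mean().item()


def pairwise_loss(chosen_rewards, rejected_rewards):
    """Bradley-Terry style loss: mean of -log sigmoid(r_chosen - r_rejected)."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


chosen = torch.tensor([1.2, 0.3, -0.5, 2.0])
rejected = torch.tensor([0.4, 0.9, -1.0, 1.5])
acc = preference_accuracy(chosen, rejected)  # 3 of 4 pairs ranked correctly
loss = pairwise_loss(chosen, rejected)
```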

2: One big change I would make is in the args: pass an argument like use_unsloth, and when it is set, skip the way the OnlineDPOTrainer initializes the model and initialize it the way Unsloth does instead.

Put something like:

if use_unsloth:

    max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
    dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


    # Create Unsloth FastLanguageModel instances and turn them into PEFT models
    model, _ = FastLanguageModel.from_pretrained(
      model_name =  model_config.model_name_or_path, # "unsloth/tinyllama" for 16bit loading
      max_seq_length = max_seq_length,
      dtype = dtype,
      load_in_4bit = load_in_4bit,
      #token  = "" 
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha = 16,
        lora_dropout = 0, # Supports any, but = 0 is optimized 1e-7
        bias = "none",    # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
        random_state = 3407,
        use_rslora = False,  # We support rank stabilized LoRA
        loftq_config = None, # And LoftQ
    )

    ref_model, tokenizer = FastLanguageModel.from_pretrained(
      model_name =  model_config.model_name_or_path, # "unsloth/tinyllama" for 16bit loading
      max_seq_length = max_seq_length,
      dtype = dtype,
      load_in_4bit = load_in_4bit,
      #token = "" 

    )

    ref_model = FastLanguageModel.get_peft_model(
        ref_model,
        r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha = 16,
        lora_dropout = 0, # Supports any, but = 0 is optimized 1e-7
        bias = "none",    # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
        random_state = 3407,
        use_rslora = False,  # We support rank stabilized LoRA
        loftq_config = None, # And LoftQ
    )

    # Pad from the right
    tokenizer.padding_side = "right"
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
    tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]

in here:

https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/examples/scripts/dpo_online.py#L84-L100

If you want to use the peft_config module in the OnlineDPOTrainer instead, you would do something like:

if use_unsloth:

    max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
    dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


    # Create Unsloth FastLanguageModel instances and turn them into PEFT models
    model, _ = FastLanguageModel.from_pretrained(
      model_name =  model_config.model_name_or_path, # "unsloth/tinyllama" for 16bit loading
      max_seq_length = max_seq_length,
      dtype = dtype,
      load_in_4bit = load_in_4bit,
      #token  = "" 
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha = 16,
        lora_dropout = 0, # Supports any, but = 0 is optimized 1e-7
        bias = "none",    # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
        random_state = 3407,
        use_rslora = False,  # We support rank stabilized LoRA
        loftq_config = None, # And LoftQ
    )

    ref_model, tokenizer = FastLanguageModel.from_pretrained(
      model_name =  model_config.model_name_or_path, # "unsloth/tinyllama" for 16bit loading
      max_seq_length = max_seq_length,
      dtype = dtype,
      load_in_4bit = load_in_4bit,
      #token = "" 

    )

    ref_model = FastLanguageModel.get_peft_model(
        ref_model,
        r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha = 16,
        lora_dropout = 0, # Supports any, but = 0 is optimized 1e-7
        bias = "none",    # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
        random_state = 3407,
        use_rslora = False,  # We support rank stabilized LoRA
        loftq_config = None, # And LoftQ
    )

    # Pad from the right
    tokenizer.padding_side = "right"
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
    tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
else:
    ...  # everything else in the file (the existing non-Unsloth initialization)

Around these lines https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/trl/trainer/online_dpo_trainer.py#L130-L216

3: Just make sure to call FastLanguageModel_reset_functions before line 89 here and FastLanguageModel_set_functions after line 94: https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/examples/scripts/dpo_online.py#L89-L94

I would also recommend putting the model and the ref policy into training mode with FastLanguageModel.for_training here:

https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/trl/trainer/online_dpo_trainer.py#L396C1-L397C1

4: It does not seem like they do batch generation here, but if you were to apply batch generation, you would copy my code into these lines and make sure to call FastLanguageModel.for_inference before generating and FastLanguageModel.for_training afterwards, like this:

import torch
from unsloth import FastLanguageModel

@torch.no_grad()
def unsloth_batch_generation(
    model: torch.nn.Module,
    queries: torch.Tensor,
    local_rollout_forward_batch_size: int,
    tokenizer,
    pad_token_id: int,
    generation_config: dict,
):
    # Hard-coded for the TL;DR dataset; must match max_new_tokens in unsloth_generate_text
    max_new_tokens = 53
    context_length = queries.shape[1]
    query_responses = []
    FastLanguageModel.for_inference(model)
    for i in range(0, queries.shape[0], local_rollout_forward_batch_size):
        query = queries[i : i + local_rollout_forward_batch_size]
        query_response = unsloth_generate_text(
            model,
            query,
            tokenizer,
            pad_token_id,
            generation_config,
        )
        query_responses.append(query_response)
    FastLanguageModel.for_training(model)
    # Right-pad each chunk with EOS so all chunks have context_length + max_new_tokens
    # columns; otherwise torch.cat below fails on mismatched widths.
    for i in range(len(query_responses)):
        missing = context_length + max_new_tokens - query_responses[i].shape[1]
        if missing > 0:
            pad_tensor = torch.full(
                (query_responses[i].shape[0], missing),
                tokenizer.eos_token_id,
                dtype=query_responses[i].dtype,
                device=query_responses[i].device,
            )
            query_responses[i] = torch.cat((query_responses[i], pad_tensor), dim=1)

    return torch.cat(query_responses, 0)

Otherwise, if we aren't doing batch generation, we leave it the same.

5: Just recompute the logits with the forward function instead of using the outputs from generation. Copy this:

                #logits = logitss[i : i + args.local_rollout_forward_batch_size]
                model_output = forward(model, query_response, tokenizer.pad_token_id)
                logits = model_output.logits[:, context_length - 1 : -1]
                logits /= args.temperature + 1e-7

to here:

https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/trl/trainer/online_dpo_trainer.py#L429-L431

(Seems like the online DPO trainer in hugging face already does this but for its own reasons.)

6: I would put this function into the online_dpo_trainer.py file in TRL and call it whenever you get rewards:

https://github.com/pluesclues/us_open-instruct/blob/5375f58e2b893554da018c9c6be472ce0d1ed220/open_instruct/model_utils.py#L187-L216

This function replaces (overwrites) the one here:
https://github.com/huggingface/trl/blob/763738f457f283270772ac9bd5b3e4027fd424d5/trl/trainer/online_dpo_trainer.py#L500-L503
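A reward function of this kind typically runs the backbone and score head over the full `query_responses`, then reads the score at the last non-pad token of each response. The sketch below shows only that gathering step on a precomputed `[batch, seq_len]` score tensor. It is a hypothetical sketch of the common pattern, not the exact function linked above, and it assumes responses are padded with `pad_token_id` after the stop token.

```python
import torch


def first_true_indices(bools):
    """Index of the first True in each row, or the row length if there is none."""
    row_len = bools.size(-1)
    candidates = row_len * (~bools).long() + torch.arange(row_len, device=bools.device)
    return candidates.min(dim=-1).values


def reward_at_last_token(scores, query_responses, pad_token_id, context_length):
    """Pick each row's score at its last non-pad token.

    scores: [batch, seq_len] per-token outputs of the reward head.
    """
    responses = query_responses[:, context_length:]
    last_idx = first_true_indices(responses == pad_token_id) - 1 + context_length
    return scores[torch.arange(scores.size(0), device=scores.device), last_idx]


qr = torch.tensor([[7, 7, 5, 6, 0, 0],
                   [7, 7, 5, 6, 8, 9]])
scores = torch.arange(12, dtype=torch.float32).view(2, 6)
rewards = reward_at_last_token(scores, qr, pad_token_id=0, context_length=2)
```

Here the first row's response ends at index 3 and the second row has no padding, so its reward is read from the final position (index 5).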

@pluesclues
Author

Apologies, I should mention that there are a ton of hard-coded model and dataset names in the script, as I honestly got a bit lazy with figuring out the API. Regardless, if you pull the AllenAI open-instruct repo and integrate the Unsloth changes into it, I ran the script with:

python open_instruct/unsloth_online_dpo.py \
    --dataset_name "trl-internal-testing/tldr-preference-sft-trl-style" \
    --dataset_train_split "train" \
    --dataset_eval_split "test" \
    --sft_messages_key "messages" \
    --model_name_or_path "keithdrexel/unsloth-llama-3.2-1b-tldr-unsloth-nobnb" \
    --reward_model_path "/home/kt828/open-instruct/keithdrexel/Llama-3.2-1b_reward_tldr" \
    --chat_template "simple_concat_with_space" \
    --learning_rate 3e-6 \
    --total_episodes 250000 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 8 \
    --max_token_length 256 \
    --max_prompt_token_lenth 256 \
    --num_train_epochs 1 \
    --min_response_length 28 \
    --beta 0.1 \
    --output_dir "models/rm/rm_sentiment_1b" \
    --with_tracking \
    --push_to_hub true \
    --stop_token_id 128001

My reward model was saved in an unusual way, so you have to load it like this:

    reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
        "keithdrexel/reward_modeling__meta-llama_Llama-3.2-1B",
        revision = "reward_modeling__1__1728309120", 
        num_labels=1,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_cache=False,
    )

Hope this helps.

@shimmyshimmer
Collaborator

Thanks a lot @pluesclues! We'll probably be putting out a release just for Reward Modelling/Reinforcement Learning soon, and this addition would be stellar!

@pluesclues
Author

That's great! I will try to work on the TRL notebook when my semester starts. However, I am not sure whether many of these changes could be added in Unsloth rather than by editing TRL itself. I will do it regardless, but there may be merit in making an RL-specific generate function that includes the padding changes as well as the specific generation settings I used. For the KL divergence issue I mentioned above, though, I think you may have to edit TRL regardless, so that it uses the forward logits of the sequences produced by batch generation.
