Bug fixes (#1288)
* Fix TRL

* Update mistral.py

* Patch processing_class

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Installation guide (#1165)

* chore: update chat_templates.py (#1166)

orginal -> original

* Disable Flex Attention

* Update tokenizer_utils.py

* Update _utils.py

* n_items

* Update cross_entropy_loss.py

* Fix DPO, ORPO

* Update _utils.py

* Update _utils.py

* fix/transformers-unpack (#1180)

* Fix DPO, ORPO (#1177)

* Fix TRL

* Update mistral.py

* Patch processing_class

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Installation guide (#1165)

* chore: update chat_templates.py (#1166)

orginal -> original

* Disable Flex Attention

* Update tokenizer_utils.py

* Update _utils.py

* n_items

* Update cross_entropy_loss.py

* Fix DPO, ORPO

* Update _utils.py

---------

Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>

* Add warning for missing Unpack and KwargsForCausalLM in older Transformers versions

---------

Co-authored-by: Daniel Han <[email protected]>
Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>
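For context, the Unpack/KwargsForCausalLM warning mentioned above could look roughly like the following sketch. The import location and helper name are assumptions for illustration, not unsloth's actual code:

```python
import logging

logger = logging.getLogger(__name__)

def warn_if_missing_kwargs_support():
    # Newer transformers versions type forward() kwargs via Unpack and a
    # KwargsForCausalLM TypedDict; older versions lack both symbols.
    # (Assumed location: transformers.models.llama.modeling_llama.)
    try:
        from transformers.models.llama import modeling_llama
    except ImportError:
        return  # transformers not installed
    missing = [name for name in ("Unpack", "KwargsForCausalLM")
               if not hasattr(modeling_llama, name)]
    if missing:
        logger.warning(
            "Unsloth: your transformers version lacks "
            + ", ".join(missing)
            + "; consider upgrading transformers."
        )
```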

* Update cross_entropy_loss.py

* Update _utils.py

* Update _utils.py

* Do not upcast lm_head and embeddings to float32 (#1186)

* Cleanup upcast logs (#1188)

* Fix/phi-longrope (#1193)

* Enhance rotary embedding handling in LlamaAttention and LongRopeRotaryEmbedding

* Typo

* Improve rotary embedding handling in LlamaAttention to prevent errors with short KV cache

* Update llama.py

* Update llama.py

---------

Co-authored-by: Daniel Han <[email protected]>
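The phi-longrope merge above hardens rotary-embedding handling so a short cached cos/sin table no longer errors on longer sequences. A minimal sketch of that general technique, with illustrative names rather than the actual LongRopeRotaryEmbedding code:

```python
import torch

class RotaryCache:
    # Illustrative RoPE cache that grows on demand instead of erroring
    # when a sequence exceeds the cached length.
    def __init__(self, dim: int, base: float = 10000.0, max_seq_len: int = 2048):
        self.dim, self.base = dim, base
        self.max_seq_len = 0
        self._extend(max_seq_len)

    def _extend(self, seq_len: int):
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(seq_len).float()
        freqs = torch.outer(t, inv_freq)         # (seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        self.cos, self.sin = emb.cos(), emb.sin()
        self.max_seq_len = seq_len

    def get(self, seq_len: int):
        # Lazily re-extend the cache rather than indexing out of range.
        if seq_len > self.max_seq_len:
            self._extend(seq_len)
        return self.cos[:seq_len], self.sin[:seq_len]
```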

* Update transformers

* Unk token issues

* Update _utils.py

* Fix pad token

* Update llama.py

* Typo

* ignored labels

* Revert "ignored labels"

This reverts commit 9d07be0.

* More patching

* Update _utils.py

* Update _utils.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Feat/all tmp (#1219)

* Update save.py

Check whether path is in /tmp dir for Kaggle environment

* Update save.py

Move temporary_location to /tmp in Kaggle

* Enhance Kaggle environment support in save and tokenizer utilities

---------

Co-authored-by: dendarrion <[email protected]>
Co-authored-by: Erland366 <[email protected]>
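The Kaggle changes above redirect temporary save buffers to /tmp, since Kaggle's working directory has a limited disk quota. A hedged sketch of the idea; the helper name and environment-variable check are illustrative:

```python
import os

def resolve_temporary_location(path: str = "_unsloth_temporary_saved_buffers") -> str:
    # KAGGLE_KERNEL_RUN_TYPE is commonly set inside Kaggle kernels.
    is_kaggle = "KAGGLE_KERNEL_RUN_TYPE" in os.environ
    # Route large temporary buffers to /tmp on Kaggle, where the
    # working directory's quota is easy to exhaust.
    if is_kaggle and not os.path.realpath(path).startswith("/tmp"):
        return os.path.join("/tmp", path)
    return path
```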

* Bug fixes

* Update pyproject.toml

* Update _utils.py

* Update __init__.py

* Update __init__.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Tied weights

* Revert "Tied weights"

This reverts commit 8090b7c.

* Tied weights

* Utils

* CE Loss patching

* Update __init__.py

* Update __init__.py

* Patching

* Update cross_entropy_loss.py

* CE Loss

* Update _utils.py

* Update _utils.py

* CE Loss

* Update _utils.py

* Update _utils.py

* Layernorm

* Update _utils.py

* Update _utils.py

* Post patch

* Update _utils.py

* Update llama.py

* Update _utils.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* typing

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* int64

* Update _utils.py

* Update cross_entropy_loss.py

* constexpr

* constexpr

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* CE

* Update cross_entropy_loss.py

* Update _utils.py

* Update llama.py

* Update _utils.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update utils.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* typing

* Update rope_embedding.py

* types

* Disable compiling

* Update _utils.py

* Update _utils.py

* Forward hook

* Update _utils.py

* Update llama.py

* Update _utils.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Update pyproject.toml

* Update _utils.py

* Update llama.py

* CE Loss

* Update cross_entropy_loss.py

* Update _utils.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Fix: cast logits to float32 in cross_entropy_forward to prevent errors (#1254)

* Fix: cast logits to float32 in cross_entropy_forward to prevent errors

* Update cross_entropy_loss.py

---------

Co-authored-by: Daniel Han <[email protected]>
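The fix above upcasts logits before the loss. A minimal sketch of the shape of that change, written against plain PyTorch rather than unsloth's actual Triton kernel:

```python
import torch
import torch.nn.functional as F

def cross_entropy_forward(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast before the loss: float16/bfloat16 logits lose precision in
    # the log-softmax reduction, so compute cross entropy in float32.
    logits = logits.float()
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index = -100,  # standard "masked label" convention
    )
```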

* Throw error when inferencing longer than max_position_embeddings (#1236)

* Throw error when inferencing longer than max_position_embeddings without rope scaling

* Update llama.py

---------

Co-authored-by: Daniel Han <[email protected]>
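A hedged sketch of the length guard this commit describes; the attribute names follow Hugging Face model configs, and the helper itself is illustrative:

```python
def check_sequence_length(seq_len: int, config) -> None:
    # Positions past max_position_embeddings are outside the trained
    # range unless RoPE scaling extends them.
    if getattr(config, "rope_scaling", None) is None and \
            seq_len > config.max_position_embeddings:
        raise ValueError(
            f"Unsloth: input length {seq_len} exceeds max_position_embeddings "
            f"({config.max_position_embeddings}) and no RoPE scaling is set."
        )
```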

* CLI now handles user input strings for dtype correctly (#1235)

Co-authored-by: root <[email protected]>
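A plausible sketch of the CLI dtype fix: normalize user-supplied strings to torch dtypes before use. The mapping and helper name are illustrative, not the actual CLI code:

```python
import torch

# Accepted spellings for a --dtype style argument.
DTYPE_MAP = {
    "float16":  torch.float16,  "fp16": torch.float16,
    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
    "float32":  torch.float32,  "fp32": torch.float32,
    "none": None,
}

def parse_dtype(value):
    # Pass through None and real torch dtypes; convert strings.
    if value is None or isinstance(value, torch.dtype):
        return value
    try:
        return DTYPE_MAP[value.strip().lower()]
    except KeyError:
        raise ValueError(f"Unknown dtype string: {value!r}")
```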

* Update flex_attention.py

* Update _utils.py

* Update _utils.py

* Update flex_attention.py

* Update flex_attention.py

* Update loader.py

* Update loader.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* Update _utils.py

* Update cross_entropy_loss.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* triton_cast

* Update utils.py

* Qwen 2.5 Coder

* Fix/export mistral (#1281)

* Enhance install_python_non_blocking to handle protobuf installation and process management

* Revert "Enhance install_python_non_blocking to handle protobuf installation and process management"

This reverts commit f09974b.

* Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266

* Revert "Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266"

This reverts commit 9fc1307.

* Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266

* Update __init__.py

---------

Co-authored-by: Daniel Han <[email protected]>

* DOC Update - Update README.md with os.environ in example (#1269)

* Update README.md with os.environ in example

Added os.environ to the example to avoid device conflicts; at least in a Jupyter notebook, this lets a user pick a specific GPU in a multi-GPU setup.
Currently the unsloth init checks all GPUs and takes the first in order, which can be an issue when some GPUs are already in use yet still appear in the list. Setting this environment variable manually avoids that.
A small change, but a time saver for anyone who copies the tutorials as-is.

* Update README.md

---------

Co-authored-by: Daniel Han <[email protected]>

* fix/get_chat_template (#1246)

* Refactor `get_chat_template` to support a custom system message. This is meant to fix the Ollama tokenizer chat template.

* Remove type hinting

* Update chat_templates.py

---------

Co-authored-by: Daniel Han <[email protected]>

* fix/sft-trainer (#1276)

* Add patch for SFTTrainer to maintain backward compatibility with TRL changes

* Update trainer.py

* Update trainer.py

* Refactor trainer patch to maintain backward compatibility with TRL changes

* Update trainer.py

* Refactor trainer.py to exclude non-convertible trainers from backward compatibility patch

---------

Co-authored-by: Daniel Han <[email protected]>
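A rough sketch of the backward-compatibility idea behind this merge: remap the old tokenizer= keyword to TRL's newer processing_class= when wrapping a trainer's __init__. Illustrative only, as a sketch; the real patch in unsloth/trainer.py differs in detail:

```python
import inspect

def _make_backwards_compatible(trainer_cls):
    original_init = trainer_cls.__init__
    params = inspect.signature(original_init).parameters
    takes_processing_class = "processing_class" in params

    def new_init(self, *args, **kwargs):
        # Old callers pass tokenizer=...; newer TRL expects processing_class=.
        if takes_processing_class and "tokenizer" in kwargs:
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        original_init(self, *args, **kwargs)

    trainer_cls.__init__ = new_init
    return trainer_cls
```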

* Update __init__.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update tokenizer_utils.py

---------

Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>
Co-authored-by: Edd <[email protected]>
Co-authored-by: Datta Nimmaturi <[email protected]>
Co-authored-by: dendarrion <[email protected]>
Co-authored-by: Erland366 <[email protected]>
Co-authored-by: Edwin Fennell <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Uday Girish Maradana <[email protected]>
10 people authored Nov 14, 2024
1 parent d8ff860 commit 0de5457
Showing 5 changed files with 211 additions and 20 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -299,6 +299,9 @@ DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as
We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional: set GPU device ID

from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
```
7 changes: 7 additions & 0 deletions unsloth/__init__.py
@@ -31,6 +31,10 @@
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
# Thank you for your understanding and we appreciate it immensely!

# Fixes https://github.com/unslothai/unsloth/issues/1266
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

if "CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
devices = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -172,3 +176,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
97 changes: 89 additions & 8 deletions unsloth/chat_templates.py
@@ -39,6 +39,7 @@
train_on_responses_only,
)
CHAT_TEMPLATES = {}
DEFAULT_SYSTEM_MESSAGE = {}

# =========================================== Unsloth
# Unsloth efficient template leverages from Zephyr
@@ -48,7 +49,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'You are a helpful assistant to the user\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
@@ -80,6 +81,7 @@

unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
pass

# =========================================== Zephyr
@@ -116,6 +118,7 @@

zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
pass

# =========================================== ChatML
@@ -153,6 +156,7 @@

chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
pass

# =========================================== Mistral-1
@@ -193,6 +197,7 @@

mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
pass

# =========================================== Llama-2
@@ -234,6 +239,7 @@

llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
pass

# =========================================== Vicuna
@@ -244,7 +250,7 @@
"{{ messages[0]['content'] + ' ' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
"{{ '{system_message}' + ' ' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
@@ -273,6 +279,7 @@

vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
pass

# =========================================== Vicuna Old
@@ -283,7 +290,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
@@ -315,6 +322,10 @@

vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."

CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
pass

# =========================================== Alpaca multi turn
@@ -325,7 +336,7 @@
"{{ messages[0]['content'] + '\n\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
"{{ '{system_message}' + '\n\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
@@ -362,6 +373,7 @@

alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
pass

# =========================================== Gemma
@@ -372,7 +384,7 @@
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
"{% set loop_messages = messages[2:] %}"\
"{% set messages = messages[2:] %}"\
"{% endif %}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
@@ -407,6 +419,7 @@

gemma_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
pass

# =========================================== Gemma with ChatML instead
@@ -437,6 +450,7 @@
"<|im_end|>",
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
pass

# =========================================== Gemma 2
@@ -446,12 +460,14 @@
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2

# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
pass

# =========================================== Llama-3
@@ -491,7 +507,12 @@
'''

llama3_template_eos_token = "eos_token"

CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3

CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
pass


@@ -532,8 +553,13 @@

phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3

CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5

CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
pass

# =========================================== Llama-3.1
@@ -573,7 +599,7 @@
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- set system_message = "{system_message}" %}
{%- endif %}
{#- System message + builtin tools #}
@@ -729,7 +755,10 @@

llama31_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates

CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
pass


@@ -751,7 +780,7 @@
{%- if messages[0][\'role\'] == \'system\' %}
{{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
{%- else %}
{{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}
{{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }}
{%- endif %}\n{%- endif %}\n{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
{{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
@@ -847,10 +876,53 @@
'''

qwen25_template_eos_token = "eos_token"
qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # Default system message in Qwen 2.5

CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # Default system message in Qwen 2.5

CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # Default system message in Qwen 2.5

CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # Default system message in Qwen 2.5
pass

def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
    system_message_pattern = r"\{system_message\}"

    # For predefined templates, check if default system message exists
    default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
    if default_system_message is None:
        if system_message is not None:
            logger.warning_once(
                f"Unsloth: You tried to change the system message for {type_chat_template}, "
                "but it doesn't have a default system message. "
                "You need to manually add the system message in your data."
            )
        return template, system_message
    pass

    # For custom templates
    if type_chat_template is None:
        has_placeholder = re.search(system_message_pattern, template) is not None

        if has_placeholder:
            if system_message is None:
                raise ValueError("Unsloth: You need to provide a system message for custom templates.")
            new_template = re.sub(system_message_pattern, system_message, template)
            return new_template, system_message

        return template, system_message
    pass

    # For predefined templates with default system message
    message_to_use = system_message if system_message is not None else default_system_message
    new_template = re.sub(system_message_pattern, message_to_use, template)

    return new_template, message_to_use
pass


@@ -886,14 +958,20 @@ def get_chat_template(
    old_padding_side = tokenizer.padding_side

    same_padding_token = False

    type_chat_template = None

    if type(chat_template) in (list, tuple,):
        # For changing system message later
        # Since it's not supported yet, we will raise an error first!
        type_chat_template = chat_template[0].lower()
        chat_template, stop_word = chat_template
        assert(type(chat_template) is str)
        assert(type(stop_word) is str)
        ollama_modelfile = None

    elif type(chat_template) is str:
        # For changing system message later
        type_chat_template = chat_template.lower()

        chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]

Expand Down Expand Up @@ -1052,6 +1130,9 @@ def get_chat_template(
    else:
        chat_template = new_chat_template
    pass

    chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)

    tokenizer.chat_template = chat_template

    # Also fix up other tokens
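Taken together, these chat_templates.py changes let callers override the {system_message} placeholder when selecting a predefined template. A usage sketch, assuming get_chat_template now accepts a system_message keyword (as its call to _change_system_message above suggests); the model name is just for illustration:

```python
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct-bnb-4bit")
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen-2.5",                     # any key of CHAT_TEMPLATES
    system_message = "You are a terse assistant.",  # replaces {system_message}
)
messages = [{"role": "user", "content": "Hello!"}]
print(tokenizer.apply_chat_template(messages, tokenize = False))
```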
4 changes: 2 additions & 2 deletions unsloth/tokenizer_utils.py
@@ -586,10 +586,10 @@ def load_correct_tokenizer(


def _fix_chat_template(chat_template):
    endfor = "{% endfor %}"
    endfor = "{% endif %}"
    where = chat_template.find(endfor)
    if where == -1:
        endfor = "{%- endfor %}"
        endfor = "{%- endif %}"
        where = chat_template.find(endfor)
        if where == -1:
            return chat_template
Expand Down