Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into haojun/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhhjjj committed May 2, 2024
2 parents 1df2792 + 693628e commit da7cf7a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
30 changes: 20 additions & 10 deletions examples/config_tiny_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,26 @@
tp_linear_async_communication=True,
)

tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1)
tokens = TokensArgs(sequence_length=256, train_steps=15, micro_batch_size=2, batch_accumulation_per_replica=1)

dataset = PretrainDatasetsArgs(
hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", text_column_name="completion"
)
data_stages = [
DatasetStageArgs(
name="Stable Training Stage",
start_training_step=1,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
seed=seed,
),
),
DatasetStageArgs(
name="Annealing Phase",
start_training_step=10,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
seed=seed,
),
),
]

checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)
Expand All @@ -99,12 +114,7 @@
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data_stages=[
DatasetStageArgs(
name="Stable Training Stage", start_training_step=1, data=DataArgs(dataset=dataset, seed=seed)
),
DatasetStageArgs(name="Annealing Phase", start_training_step=10, data=DataArgs(dataset=dataset, seed=seed)),
],
data_stages=data_stages,
profiler=None,
)

Expand Down
8 changes: 4 additions & 4 deletions examples/config_tiny_llama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: completion
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
Expand All @@ -22,9 +22,9 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: completion
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
Expand Down
7 changes: 5 additions & 2 deletions run_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,12 @@ def main():
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left" # TODO @nouamane: do we want this?
dummy_inputs = [
# "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"The future of AI is",
"Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"def fib(n)",
# "This film was probably inspired by Godzilla",
'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.',
"Advancements in technology will lead to",
"Tomorrow's world is shaped by",
]

outputs = decode_text(
Expand Down
8 changes: 8 additions & 0 deletions src/nanotron/generation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def decode_text(

p2p = model.p2p

# replicate input for n_samples times when using TOP_P or TOP_K samplers, in order to get diverse results
if generation_config and generation_config.n_samples:
if sampler_type != SamplerType.TOP_P and sampler_type != SamplerType.TOP_K:
raise ValueError("Only support n_samples for TOP_P and TOP_K sampler")
input_iter = [
GenerationInput(text=input.text) for input in input_iter for _ in range(generation_config.n_samples)
]

# That's annoying but I need this as soon as there's a change communication "cross"
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
Expand Down
10 changes: 3 additions & 7 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,13 +656,9 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
rank=0,
)
else:
log_rank(
f"Setting max_position_embeddings to {self.config.tokens.sequence_length}. Previous value was {self.model_config.max_position_embeddings}.",
logger=logger,
level=logging.INFO,
rank=0,
)
self.model_config.max_position_embeddings = self.config.tokens.sequence_length
assert (
self.config.tokens.sequence_length == self.model_config.max_position_embeddings
), "The tokenizer's sequence length does not match the model's maximum position embeddings."

log_rank("Config:\n" + pformat(self.config), logger=logger, level=logging.INFO, rank=0)
log_rank("Model Config:\n" + pformat(self.model_config), logger=logger, level=logging.INFO, rank=0)
Expand Down

0 comments on commit da7cf7a

Please sign in to comment.