update scripts
parent 8845e94f91
commit 1e0c860c8c
@@ -44,6 +44,7 @@ def calculate_lr(
     template: str = "default",
     cutoff_len: int = 1024,  # i.e. maximum input length during training
     is_mistral: bool = False,  # mistral model uses a smaller learning rate,
+    packing: bool = False,
 ):
     r"""
     Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
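For context, packing concatenates several tokenized samples into one sequence of at most cutoff_len tokens, so batches carry fewer pad tokens and the valid-token ratio changes. The following is a minimal greedy-packing sketch under that assumption; the function name and structure are illustrative, not the project's implementation:

    # Illustrative greedy packing (assumed behavior, not the repo's code):
    # concatenate tokenized samples until cutoff_len is reached, then start
    # a new packed sequence.
    def pack_sequences(samples, cutoff_len=1024):
        packed, current = [], []
        for token_ids in samples:
            if current and len(current) + len(token_ids) > cutoff_len:
                packed.append(current)
                current = []
            current.extend(token_ids[:cutoff_len])  # truncate oversized samples
        if current:
            packed.append(current)
        return packed

    # pack_sequences([[1, 2, 3], [4, 5], [6] * 1020], cutoff_len=1024)
    # -> two packed sequences instead of three padded ones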
@@ -57,6 +58,7 @@ def calculate_lr(
         dataset_dir=dataset_dir,
         template=template,
         cutoff_len=cutoff_len,
+        packing=packing,
         output_dir="dummy_dir",
         overwrite_cache=True,
     )
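With the parameter threaded into the training arguments, the flag can be toggled from the command line. Assuming the script dispatches calculate_lr through fire.Fire like the repository's other scripts, a run might look like the following; the model path and dataset name are placeholders:

    # hypothetical invocation, assuming fire.Fire(calculate_lr) at module level
    python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --packing True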
@@ -69,7 +71,7 @@ def calculate_lr(
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
     else:
-        raise NotImplementedError
+        raise NotImplementedError("Stage is not supported: {}.".format(stage))

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0
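The valid/total token counters feed the learning-rate estimate promised in the docstring. A minimal sketch of the square-root batch-size rule, assuming LLaMA's published 7B/13B pretraining hyper-parameters (3e-4 learning rate at roughly 4M tokens per batch); the constants and names below are assumptions, not a copy of the script:

    import math

    BASE_LR = 3e-4       # assumed baseline: LLaMA 7B/13B pretraining learning rate
    BASE_BS = 4_000_000  # assumed baseline: ~4M tokens per batch in LLaMA pretraining

    def scaled_lr(valid_tokens, total_tokens, batch_size, cutoff_len=1024, is_mistral=False):
        # fraction of tokens whose label is not IGNORE_INDEX
        valid_ratio = valid_tokens / total_tokens
        batch_valid_len = cutoff_len * batch_size * valid_ratio
        lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr scales with sqrt(batch tokens)
        return lr / 6.0 if is_mistral else lr  # mistral reportedly prefers a smaller lr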
@@ -98,7 +98,7 @@ def cal_ppl(
             tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
         )
     else:
-        raise NotImplementedError
+        raise NotImplementedError("Stage is not supported: {}.".format(stage))

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
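In cal_ppl, the reduction="none" criterion yields a per-token loss that can be averaged per sample before exponentiating. A minimal sketch of that computation under the usual causal-LM shift; tensor names and the helper are illustrative, not taken from the script:

    import torch

    def sequence_perplexity(logits, labels, ignore_index=-100):
        # shift so tokens at position i are predicted from positions < i
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        criterion = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
        # (batch, seq-1) per-token losses; ignored positions contribute zero
        token_loss = criterion(shift_logits.transpose(1, 2), shift_labels)
        valid = (shift_labels != ignore_index).float()
        mean_loss = (token_loss * valid).sum(dim=-1) / valid.sum(dim=-1)
        return torch.exp(mean_loss)  # per-sample perplexity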