From 38af076a75c33da26d641780820694e4b7342d92 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 20 Jan 2024 19:25:22 +0800 Subject: [PATCH] support longlora for main branch --- src/llmtuner/extras/patches/llama_patch.py | 302 +++++++++------------ src/llmtuner/model/adapter.py | 4 + src/llmtuner/model/patcher.py | 27 +- src/llmtuner/train/dpo/trainer.py | 19 +- src/llmtuner/webui/components/train.py | 9 +- src/llmtuner/webui/locales.py | 10 + src/llmtuner/webui/runner.py | 1 + 7 files changed, 168 insertions(+), 204 deletions(-) diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py index a9f5da28..e7d9e48f 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -3,222 +3,166 @@ import torch import torch.nn as nn from typing import Optional, Tuple from transformers.utils import logging -from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - -try: - from transformers.models.llama.modeling_llama import repeat_kv -except ImportError: - print("Please upgrade `transformers`.") - -from ..packages import is_flash_attn2_available - - -if is_flash_attn2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore - from flash_attn.bert_padding import pad_input, unpad_input # type: ignore +from transformers.models.llama.modeling_llama import ( + Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv +) logger = logging.get_logger(__name__) # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -class LlamaShiftShortAttention(LlamaAttention): +def llama_torch_attn_forward( + self: "LlamaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + **kwargs +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if past_key_value is not None: # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - past_key_value = (key_states, value_states) if use_cache else None - - if getattr(self, "num_key_value_groups"): - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if getattr(self.config, "group_size_ratio", None) and self.training: # shift - groupsz = int(q_len * getattr(self.config, "group_size_ratio")) - assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) - num_groups = q_len // groupsz - def shift(state: torch.Tensor) -> torch.Tensor: - state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) - state = torch.cat(( - state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) - ), dim=2) - return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) - - query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) - if attention_mask is not None: - attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat(( + state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) + ), dim=2) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) if attention_mask is not None: - attn_weights = attn_weights + attention_mask + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if getattr(self.config, "group_size_ratio", None) and self.training: # shift back - attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) - attn_output = torch.cat(( - attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) - )) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) + attn_output = attn_output.transpose(1, 2).contiguous() - if not output_attentions: - attn_weights = None + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat(( + attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) + )) - return attn_output, attn_weights, past_key_value + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value -class LlamaFlashAttention2(LlamaAttention): +# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def llama_flash_attn_forward( + self: "LlamaFlashAttention2", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + **kwargs +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - output_attentions = False + bsz, q_len, _ = hidden_states.size() - bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if past_key_value is not None: # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - past_key_value = (key_states, value_states) if use_cache else None + query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - # cast to half precision - input_dtype = query_states.dtype - if input_dtype == torch.float32: - logger.warning_once("The input hidden states seems to be silently casted in float32.") - query_states = query_states.to(self.config.torch_dtype) - key_states = key_states.to(self.config.torch_dtype) - value_states = value_states.to(self.config.torch_dtype) + dropout_rate = self.attention_dropout if self.training else 0.0 - if getattr(self, "num_key_value_groups", None): - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - - if getattr(self.config, "group_size_ratio", None) and self.training: # shift - groupsz = int(q_len * getattr(self.config, "group_size_ratio")) - assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) - num_groups = q_len // groupsz - def shift(state: torch.Tensor) -> torch.Tensor: - state = torch.cat(( - state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) - ), dim=2) - return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) - - query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) - if attention_mask is not None: - attention_mask = attention_mask.reshape(bsz * num_groups, groupsz) + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + def shift(state: torch.Tensor) -> torch.Tensor: + state = torch.cat(( + state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) + ), dim=2) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) if attention_mask is not None: - logger.warning_once("Padded sequences are less efficient in FlashAttention.") - # -q_len: assumes left padding when q_len != kv_len - unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:]) - unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) - unpadded_v, _, _, _ = unpad_input(value_states, attention_mask) - attn_output_unpad = flash_attn_varlen_func( - unpadded_q, - unpadded_k, - unpadded_v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=True, - ) - attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True - ) + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) - if getattr(self.config, "group_size_ratio", None) and self.training: # shift back - attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) - attn_output = torch.cat(( - attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) - )) + attn_output: torch.Tensor = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat(( + attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) + )) - if not output_attentions: - attn_weights = None + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value -# Disable the transformation of the attention mask in LlamaModel as flash attention -# takes a boolean padding_mask. Fills in the past kv length for use in forward. -def _prepare_decoder_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: torch.Tensor, - inputs_embeds: torch.Tensor, - past_key_values_length: int -) -> torch.Tensor: - if attention_mask is not None and torch.all(attention_mask): - return None # This uses the faster call when training with full samples - - return attention_mask +def apply_llama_patch() -> None: + LlamaAttention.forward = llama_torch_attn_forward + LlamaFlashAttention2.forward = llama_flash_attn_forward diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 83a63b96..a7e9b9fe 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -1,4 +1,5 @@ import torch +import inspect from typing import TYPE_CHECKING from transformers.integrations import is_deepspeed_zero3_enabled from peft import PeftModel, TaskType, LoraConfig, get_peft_model @@ -108,6 +109,9 @@ def init_adapter( if model_args.use_unsloth: from unsloth import FastLlamaModel, FastMistralModel # type: ignore unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} + if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters: + unsloth_peft_kwargs["loftq_config"] = {} + if getattr(model.config, "model_type", None) == "llama": model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) elif getattr(model.config, "model_type", None) == "mistral": diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index a3e88f3e..5af9f5af 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -15,6 +15,7 @@ from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES from ..extras.logging import get_logger from ..extras.misc import get_current_device, infer_optim_dtype from ..extras.packages import is_flash_attn2_available +from ..extras.patches.llama_patch import apply_llama_patch if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer @@ -23,7 +24,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama +SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): @@ -39,26 +40,25 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke Resize token embeddings. """ if is_deepspeed_zero3_enabled(): - import deepspeed - with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None): - current_embedding_size = model.get_input_embeddings().weight.size(0) + import deepspeed # type: ignore + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: current_embedding_size = model.get_input_embeddings().weight.size(0) + if len(tokenizer) > current_embedding_size: if not isinstance(model.get_output_embeddings(), torch.nn.Linear): logger.warning("Current model does not support resizing token embeddings.") return model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) - if is_deepspeed_zero3_enabled(): - import deepspeed - params = [model.get_input_embeddings().weight] - if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: - params.append(model.get_output_embeddings().weight) - context = deepspeed.zero.GatheredParameters(params, modifier_rank=0) - else: - context = nullcontext() - with context: + with context_maybe_zero3: new_embedding_size = model.get_input_embeddings().weight.size(0) num_new_tokens = new_embedding_size - current_embedding_size _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) @@ -136,6 +136,7 @@ def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None: def _configure_longlora(config: "PretrainedConfig") -> None: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: setattr(config, "group_size_ratio", 0.25) + apply_llama_patch() logger.info("Using shift short attention with group_size_ratio=1/4.") else: logger.warning("Current model does not support shift short attention.") diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index b5a44f5e..b8d59c8e 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -1,4 +1,5 @@ import torch +from contextlib import nullcontext from collections import defaultdict from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union from transformers import BatchEncoding, Trainer @@ -93,7 +94,8 @@ class CustomDPOTrainer(DPOTrainer): all_logps = self.get_batch_logps( all_logits, batch["labels"], - average_log_prob=False + average_log_prob=False, + label_pad_token_id=self.label_pad_token_id, ) batch_size = batch["input_ids"].size(0) // 2 chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) @@ -118,20 +120,19 @@ class CustomDPOTrainer(DPOTrainer): ) = self.concatenated_forward(model, batch) with torch.no_grad(): if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - ) = self.concatenated_forward(self.model, batch) + ref_model = self.model + ref_context = self.accelerator.unwrap_model(self.model).disable_adapter() else: + ref_model = self.ref_model + ref_context = nullcontext() + + with ref_context: ( reference_chosen_logps, reference_rejected_logps, _, _, - ) = self.concatenated_forward(self.ref_model, batch) + ) = self.concatenated_forward(ref_model, batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss( policy_chosen_logps, diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 08e861f0..5689f7ad 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -95,7 +95,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: with gr.Row(): dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) - reward_model = gr.Dropdown(scale=3, allow_custom_value=True) + dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1) + reward_model = gr.Dropdown(scale=2, allow_custom_value=True) refresh_btn = gr.Button(scale=1) refresh_btn.click( @@ -105,8 +106,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: queue=False ) - input_elems.update({dpo_beta, reward_model}) - elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn)) + input_elems.update({dpo_beta, dpo_ftx, reward_model}) + elem_dict.update(dict( + rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn + )) with gr.Row(): cmd_preview_btn = gr.Button() diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 9ba08e25..60778b67 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -421,6 +421,16 @@ LOCALES = { "info": "DPO 损失函数中 beta 超参数大小。" } }, + "dpo_ftx": { + "en": { + "label": "DPO-ftx weight", + "info": "The weight of SFT loss in the DPO-ftx." + }, + "zh": { + "label": "DPO-ftx 权重", + "info": "DPO-ftx 中 SFT 损失的权重大小。" + } + }, "reward_model": { "en": { "label": "Reward model", diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 5d8efbfb..48c22214 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -146,6 +146,7 @@ class Runner: if args["stage"] == "dpo": args["dpo_beta"] = get("train.dpo_beta") + args["dpo_ftx"] = get("train.dpo_ftx") if get("train.val_size") > 1e-6 and args["stage"] != "ppo": args["val_size"] = get("train.val_size")