tiny fix
This commit is contained in:
parent
60114179eb
commit
c9b3870adb
|
@ -18,6 +18,7 @@ import os
|
|||
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoConfig
|
||||
|
||||
from llamafactory.train.tuner import run_exp
|
||||
|
@ -28,7 +29,7 @@ BASE = 2 # gemm (add + mul)
|
|||
|
||||
def compute_model_flops(
|
||||
model_name_or_path: str,
|
||||
batch_size: int,
|
||||
total_batch_size: int,
|
||||
seq_length: int,
|
||||
include_backward: bool = True,
|
||||
include_recompute: bool = False,
|
||||
|
@ -48,7 +49,7 @@ def compute_model_flops(
|
|||
|
||||
# mlp module
|
||||
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
|
||||
mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
|
||||
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
|
||||
|
||||
# attn projector module
|
||||
q_flops_per_token = BASE * hidden_size * hidden_size
|
||||
|
@ -56,15 +57,15 @@ def compute_model_flops(
|
|||
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
|
||||
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
|
||||
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
|
||||
attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
|
||||
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
|
||||
|
||||
# attn sdpa module
|
||||
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
|
||||
sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
|
||||
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
|
||||
|
||||
# embedding module
|
||||
embedding_flops_per_token = hidden_size * vocab_size
|
||||
embedding_flops = batch_size * seq_length * embedding_flops_per_token
|
||||
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
|
||||
if tie_word_embeddings is False:
|
||||
embedding_flops *= 2
|
||||
|
||||
|
@ -85,17 +86,19 @@ def compute_model_flops(
|
|||
return total_flops
|
||||
|
||||
|
||||
def compute_device_flops() -> float:
|
||||
def compute_device_flops(world_size: int) -> float:
|
||||
r"""
|
||||
Calculates the FLOPs of the device capability per second.
|
||||
"""
|
||||
device_name = torch.cuda.get_device_name()
|
||||
device_count = torch.cuda.device_count()
|
||||
if "H100" in device_name or "H800" in device_name:
|
||||
return 989 * 1e12 * device_count
|
||||
return 989 * 1e12 * world_size
|
||||
elif "A100" in device_name or "A800" in device_name:
|
||||
return 312 * 1e12 * device_count
|
||||
return 312 * 1e12 * world_size
|
||||
elif "V100" in device_name:
|
||||
return 125 * 1e12 * device_count
|
||||
return 125 * 1e12 * world_size
|
||||
elif "4090" in device_name:
|
||||
return 98 * 1e12 * device_count
|
||||
return 98 * 1e12 * world_size
|
||||
else:
|
||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
||||
|
||||
|
@ -140,10 +143,16 @@ def calculate_mfu(
|
|||
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
|
||||
result = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
total_batch_size = batch_size * world_size
|
||||
mfu_value = (
|
||||
result["train_steps_per_second"]
|
||||
* compute_model_flops(model_name_or_path, batch_size, seq_length)
|
||||
/ compute_device_flops()
|
||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||
/ compute_device_flops(world_size)
|
||||
)
|
||||
print("MFU: {:.2f}%".format(mfu_value * 100))
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
import inspect
|
||||
from functools import partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -38,48 +38,51 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UnslothGradientCheckpointing(torch.autograd.Function):
|
||||
r"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
"""
|
||||
def get_unsloth_gradient_checkpointing_func() -> Callable:
|
||||
class UnslothGradientCheckpointing(torch.autograd.Function):
|
||||
r"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(
|
||||
ctx: "torch.autograd.Function",
|
||||
forward_function: "torch.Module",
|
||||
hidden_states: "torch.Tensor",
|
||||
*args: Union["torch.Tensor", Any],
|
||||
) -> "torch.Tensor":
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(
|
||||
ctx: "torch.autograd.Function",
|
||||
forward_function: "torch.Module",
|
||||
hidden_states: "torch.Tensor",
|
||||
*args: Union["torch.Tensor", Any],
|
||||
) -> "torch.Tensor":
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
return output
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad_(True)
|
||||
with torch.enable_grad():
|
||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad_(True)
|
||||
with torch.enable_grad():
|
||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
||||
|
||||
torch.autograd.backward(output, grad_output)
|
||||
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
||||
torch.autograd.backward(output, grad_output)
|
||||
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
||||
|
||||
return UnslothGradientCheckpointing.apply
|
||||
|
||||
|
||||
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
|
||||
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
|
||||
r"""
|
||||
Only applies gradient checkpointing to trainable layers.
|
||||
"""
|
||||
|
||||
@wraps(gradient_checkpointing_func)
|
||||
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
|
@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
|
|||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
if hasattr(gradient_checkpointing_func, "__self__"): # fix test case
|
||||
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
|
||||
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
|
|||
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
|
||||
if use_unsloth_gc:
|
||||
gradient_checkpointing_func = UnslothGradientCheckpointing.apply
|
||||
gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
|
||||
else:
|
||||
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue