This commit is contained in:
hiyouga 2024-09-08 23:18:08 +08:00
parent 60114179eb
commit c9b3870adb
2 changed files with 59 additions and 47 deletions

View File

@ -18,6 +18,7 @@ import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
@ -28,7 +29,7 @@ BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
batch_size: int,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
@ -48,7 +49,7 @@ def compute_model_flops(
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
@ -56,15 +57,15 @@ def compute_model_flops(
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = batch_size * seq_length * embedding_flops_per_token
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
@ -85,17 +86,19 @@ def compute_model_flops(
return total_flops
def compute_device_flops() -> float:
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
device_name = torch.cuda.get_device_name()
device_count = torch.cuda.device_count()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * device_count
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * device_count
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * device_count
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * device_count
return 98 * 1e12 * world_size
else:
raise NotImplementedError("Device not supported: {}.".format(device_name))
@ -140,10 +143,16 @@ def calculate_mfu(
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, batch_size, seq_length)
/ compute_device_flops()
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print("MFU: {:.2f}%".format(mfu_value * 100))

View File

@ -21,7 +21,7 @@
import inspect
from functools import partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
@ -38,48 +38,51 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
return UnslothGradientCheckpointing.apply
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix test case
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func
@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
gradient_checkpointing_kwargs = {"use_reentrant": True}
if use_unsloth_gc:
gradient_checkpointing_func = UnslothGradientCheckpointing.apply
gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)