# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


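# One multiply-accumulate in a GEMM counts as two floating-point operations,
# so an (m, k) @ (k, n) matmul contributes 2 * m * k * n FLOPs.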
BASE = 2  # gemm (add + mul)


def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of the model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

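    # A gated (SwiGLU-style) feed-forward block has three hidden_size x intermediate_size
    # projections per layer: gate, up and down.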
    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

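    # The Q and O projections are full hidden_size x hidden_size GEMMs; the K and V
    # projections are scaled by num_key_value_heads / num_attention_heads to account
    # for grouped-query attention (the ratio is 1 for plain multi-head attention).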
    # attn projector module
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

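    # Scaled dot-product attention performs two matmuls per layer and sequence
    # (Q @ K^T and scores @ V), each costing BASE * hidden_size * seq_length^2 FLOPs,
    # hence the factor of 2; the softmax itself is ignored.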
    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

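    # The LM head is a hidden_size x vocab_size projection per token; the count is
    # doubled when the input embedding is not tied to the output projection.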
    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
    if tie_word_embeddings is False:
        embedding_flops *= 2

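    # The backward pass costs roughly twice the forward pass (coefficient 1 + 2 = 3);
    # full activation recomputation adds one extra forward pass (coefficient 4).
    # FlashAttention recomputes attention during the backward pass, so its sdpa FLOPs
    # are added once more below.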
    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops

    if include_flashattn:
        total_flops += sdpa_flops

    return total_flops


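# The values below are approximate per-device peak dense half-precision throughputs
# in FLOPs per second; the aggregate scales linearly with the number of devices.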
def compute_device_flops(world_size: int) -> float:
    r"""
    Calculates the aggregate peak FLOPs per second across all devices.
    """
    device_name = torch.cuda.get_device_name()
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * world_size
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * world_size
    elif "V100" in device_name:
        return 125 * 1e12 * world_size
    elif "4090" in device_name:
        return 98 * 1e12 * world_size
    else:
        raise NotImplementedError("Device not supported: {}.".format(device_name))


def calculate_mfu(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 1024,
    num_steps: int = 100,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
    unsloth_gc: bool = False,
) -> None:
    r"""
    Calculates the MFU for the given model and hyper-parameters.
    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
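    # Run a short pretraining benchmark through LLaMA-Factory's run_exp on the bundled
    # c4_demo dataset, then read the measured train_steps_per_second from all_results.json.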
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "use_unsloth_gc": unsloth_gc,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "logging_strategy": "no",
        "save_strategy": "no",
        "save_only_model": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": num_steps,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)

    run_exp(args)
    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
        result = json.load(f)

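    # Without gradient accumulation, the global batch per optimizer step is the
    # per-device batch size times the data-parallel world size.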
    if dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    total_batch_size = batch_size * world_size
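    # MFU = achieved FLOPs per second / peak FLOPs per second, where the achieved rate
    # is the measured steps per second times the model FLOPs of one training step.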
    mfu_value = (
        result["train_steps_per_second"]
        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
        / compute_device_flops(world_size)
    )
    print("MFU: {:.2f}%".format(mfu_value * 100))


if __name__ == "__main__":
    fire.Fire(calculate_mfu)