support activation offloading via unsloth gc

hiyouga 2024-09-08 01:22:19 +08:00
parent 54c6905937
commit fb72a3adb0
3 changed files with 58 additions and 7 deletions


@@ -109,6 +109,7 @@ def calculate_mfu(
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
@@ -119,6 +120,7 @@ def calculate_mfu(
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,

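For orientation, the second hunk above simply threads the new flag into the dictionary of training arguments that calculate_mfu assembles. The sketch below is an illustrative reconstruction only; every value not visible in these hunks is a placeholder, not something taken from this commit.

# illustrative reconstruction of the argument dict built in the hunk above;
# values marked "placeholder" are NOT from this commit
args = {
    "flash_attn": "auto",                     # placeholder
    "disable_gradient_checkpointing": False,  # mirrors disable_gc
    "enable_liger_kernel": False,             # mirrors liger_kernel
    "use_unsloth_gc": True,                   # new key added by this commit
    "stage": "pt",
    "do_train": True,
    "finetuning_type": "full",                # placeholder
}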

@@ -215,6 +215,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
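Since the arguments class follows the HfArgumentParser dataclass style (as the field(...) metadata pattern suggests), the new field surfaces as a --use_unsloth_gc command-line option. A minimal, self-contained sketch of that mechanism, using a stand-in dataclass rather than the project's real ModelArguments:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoModelArguments:
    # stand-in for ModelArguments, reduced to the field added in this commit
    use_unsloth_gc: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
    )


parser = HfArgumentParser(DemoModelArguments)
(demo_args,) = parser.parse_args_into_dataclasses(args=["--use_unsloth_gc", "True"])
print(demo_args.use_unsloth_gc)  # True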


@@ -1,8 +1,10 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
# and the Unsloth library.
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +21,7 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
@@ -36,8 +38,45 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.nn.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
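To make the round trip concrete, here is an illustrative usage sketch of the class above; it is not part of the commit. A toy block returning a one-element tuple (which the (output,) unpacking in backward expects) is run through UnslothGradientCheckpointing: the input is offloaded to CPU during forward, then copied back and recomputed during backward. A CUDA device is assumed.

import torch
from torch import nn

class ToyBlock(nn.Module):
    # mimics a decoder layer that returns a one-element tuple
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return (torch.relu(self.proj(hidden_states)),)

block = ToyBlock().cuda()
x = torch.randn(2, 16, device="cuda", requires_grad=True)

# forward keeps only a CPU copy of x; backward moves it back and re-runs the block
(y,) = UnslothGradientCheckpointing.apply(block, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])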
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
@@ -52,9 +91,12 @@ def _gradient_checkpointing_enable(
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
if use_unsloth_gc:
gradient_checkpointing_func = UnslothGradientCheckpointing.apply
else:
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
@@ -97,7 +139,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
else:
# use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
gradient_checkpointing_enable = partial(
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
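The rebinding in the last hunk relies on a small but useful pattern: functools.partial pre-fills the new use_unsloth_gc keyword, and types.MethodType attaches the resulting callable to the model so that self is still passed as the first argument. A self-contained sketch of the same pattern with throwaway names:

from functools import partial
from types import MethodType

class Widget:
    pass

def configure(self, gradient_checkpointing_kwargs=None, use_unsloth_gc=False):
    # stand-in for _gradient_checkpointing_enable; just reports what it received
    print(type(self).__name__, gradient_checkpointing_kwargs, use_unsloth_gc)

w = Widget()
w.configure = MethodType(partial(configure, use_unsloth_gc=True), w)
w.configure(gradient_checkpointing_kwargs={"use_reentrant": True})
# prints: Widget {'use_reentrant': True} True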