support activation offloading via unsloth gc
This commit is contained in:
parent
54c6905937
commit
fb72a3adb0
|
@ -109,6 +109,7 @@ def calculate_mfu(
|
|||
deepspeed_stage: int = 0,
|
||||
disable_gc: bool = False,
|
||||
liger_kernel: bool = False,
|
||||
unsloth_gc: bool = False,
|
||||
) -> float:
|
||||
r"""
|
||||
Calculates MFU for given model and hyper-params.
|
||||
|
@ -119,6 +120,7 @@ def calculate_mfu(
|
|||
"flash_attn": flash_attn,
|
||||
"disable_gradient_checkpointing": disable_gc,
|
||||
"enable_liger_kernel": liger_kernel,
|
||||
"use_unsloth_gc": unsloth_gc,
|
||||
"stage": "pt",
|
||||
"do_train": True,
|
||||
"finetuning_type": finetuning_type,
|
||||
|
|
|
@ -215,6 +215,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
|||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
use_unsloth_gc: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
|
||||
)
|
||||
enable_liger_kernel: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's Transformers and PEFT library.
|
||||
# This code is inspired by the HuggingFace's Transformers and PEFT library,
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
|
||||
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
|
||||
# and the Unsloth library.
|
||||
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -19,7 +21,7 @@
|
|||
import inspect
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -36,8 +38,45 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UnslothGradientCheckpointing(torch.autograd.Function):
|
||||
r"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(
|
||||
ctx: "torch.autograd.Function",
|
||||
forward_function: "torch.Module",
|
||||
hidden_states: "torch.Tensor",
|
||||
*args: Union["torch.Tensor", Any],
|
||||
) -> "torch.Tensor":
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad_(True)
|
||||
with torch.enable_grad():
|
||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
||||
|
||||
torch.autograd.backward(output, grad_output)
|
||||
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
||||
|
||||
|
||||
def _gradient_checkpointing_enable(
|
||||
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||
self: "PreTrainedModel",
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
|
||||
use_unsloth_gc: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates gradient checkpointing for the current model.
|
||||
|
@ -52,9 +91,12 @@ def _gradient_checkpointing_enable(
|
|||
if gradient_checkpointing_kwargs is None:
|
||||
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
|
||||
if use_unsloth_gc:
|
||||
gradient_checkpointing_func = UnslothGradientCheckpointing.apply
|
||||
else:
|
||||
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||
|
||||
def custom_gradient_checkpointing_func(func, *args, **kwargs):
|
||||
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
|
@ -97,7 +139,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
|
|||
else:
|
||||
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
|
||||
gradient_checkpointing_enable = partial(
|
||||
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
|
||||
)
|
||||
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
|
Loading…
Reference in New Issue