add test case

hiyouga 2024-09-08 01:40:49 +08:00
parent fb72a3adb0
commit 52a06efaf8
2 changed files with 28 additions and 12 deletions


@@ -19,7 +19,7 @@
 # limitations under the License.
 import inspect
-from functools import partial
+from functools import partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
@@ -73,6 +73,25 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
         return (None, hidden_states.grad) + (None,) * len(ctx.args)


+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+    r"""
+    Only applies gradient checkpointing to trainable layers.
+    """
+
+    @wraps(gradient_checkpointing_func)
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+        module: "torch.nn.Module" = func.__self__
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    return custom_gradient_checkpointing_func
+
+
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
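The wrapper added above matters because hidden states entering a checkpointed layer normally have requires_grad=False; under re-entrant checkpointing, no autograd graph is then recorded for that layer, so trainable parameters inside it (for example LoRA adapters on an otherwise frozen model) would receive no gradients. Marking the floating-point inputs as requiring grad before the checkpoint call fixes that, while fully frozen layers are left untouched and effectively run without building a graph. A minimal sketch of that effect with plain torch.utils.checkpoint follows; the Linear layer and tensors are illustrative only, not part of the commit.

# Illustrative only: shows why inputs must require grad under re-entrant
# checkpointing; `layer` stands in for a partially trainable transformer block.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)  # hidden states: requires_grad is False by default

out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False -> backward through this layer is impossible

x.requires_grad_(True)  # what custom_gradient_checkpointing_func does for trainable layers
out = checkpoint(layer, x, use_reentrant=True)
out.sum().backward()
print(layer.weight.grad is not None)  # True -> gradients reach the layer's parameters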
@@ -96,22 +115,13 @@ def _gradient_checkpointing_enable(
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
-        if any(param.requires_grad for param in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return gradient_checkpointing_func(func, *args, **kwargs)
+    gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)

     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
         logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
-        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)


 def _fp32_forward_post_hook(
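The branch kept in this hunk probes which transformers gradient-checkpointing API the model implements: legacy models expose _set_gradient_checkpointing(module, value=...) and have to be toggled through self.apply(...), while newer ones accept enable= plus a gradient_checkpointing_func. Below is a self-contained sketch of the same inspect.signature probe; old_style and new_style are throwaway stubs for illustration, not transformers code.

# Standalone illustration of the old-vs-new GC format check.
import inspect

def old_style(module, value=False):  # legacy _set_gradient_checkpointing signature
    ...

def new_style(enable=True, gradient_checkpointing_func=None):  # current signature
    ...

print("value" in inspect.signature(old_style).parameters)  # True  -> old GC format
print("value" in inspect.signature(new_style).parameters)  # False -> new GC format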


@@ -51,6 +51,12 @@ def test_checkpointing_disable():
         assert getattr(module, "gradient_checkpointing") is False


+def test_unsloth_gradient_checkpointing():
+    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
+    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
+
+
 def test_upcast_layernorm():
     model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
     for name, param in model.named_parameters():
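The new test asserts on module._gradient_checkpointing_func.__self__.__name__. The intent is that, with use_unsloth_gc=True, the function ultimately dispatched is UnslothGradientCheckpointing.apply; in recent PyTorch, autograd.Function.apply is a classmethod, so the bound method's __self__ is the Function class itself and its __name__ is the class name. A small standalone sketch of that introspection follows; DemoCheckpointing is a stand-in class, not the one from this repository, and the PyTorch 2.x behavior is an assumption stated in the comments.

# Stand-in Function subclass to show what __self__.__name__ resolves to;
# assumes PyTorch 2.x, where autograd.Function.apply is a classmethod.
import torch

class DemoCheckpointing(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden_states):
        return hidden_states

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

func = DemoCheckpointing.apply
print(func.__self__ is DemoCheckpointing)  # True: apply is bound to the class
print(func.__self__.__name__)              # "DemoCheckpointing"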