From 4ccad4f50d6396f95dafc8c44e7ddfbe52bf3c67 Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 14:37:44 +0800 Subject: [PATCH 1/8] make lomo installable --- lomo_optim/__init__.py | 5 + lomo_optim/adalomo.py | 334 +++++++++++++++++++++++++++++++++++++ lomo_optim/lomo.py | 370 +++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 28 ++++ 4 files changed, 737 insertions(+) create mode 100644 lomo_optim/__init__.py create mode 100644 lomo_optim/adalomo.py create mode 100644 lomo_optim/lomo.py create mode 100644 pyproject.toml diff --git a/lomo_optim/__init__.py b/lomo_optim/__init__.py new file mode 100644 index 0000000..d5be4ce --- /dev/null +++ b/lomo_optim/__init__.py @@ -0,0 +1,5 @@ +from .adalomo import AdaLomo +from .lomo import Lomo + +__version__ = "0.1.0" +__all__ = ["Lomo", "AdaLomo"] diff --git a/lomo_optim/adalomo.py b/lomo_optim/adalomo.py new file mode 100644 index 0000000..e8ebef3 --- /dev/null +++ b/lomo_optim/adalomo.py @@ -0,0 +1,334 @@ +import math + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +try: + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +except ImportError: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +from transformers.utils import logging + + +class AdaLomo(Optimizer): + """ + 一个自定义的优化器类AdaLomo,用于在分布式训练中的梯度更新。 + + 该类实现两个梯度更新函数 :meth:`fuse_update` 和 :meth:`fuse_update_zero3`,分别用于非ZeRO和ZeRO模式下的梯度更新。 + + :param model: 待优化的模型 + :param lr: 学习率,默认值为1e-3 + :param eps: 正则化系数。eps[0]防止梯度平方太小,eps[1]用于在根据参数的RMS放缩学习率时防止步长太大 + :param clip_threshold: 归一化update矩阵时的阈值 + :param decay_rate: 梯度平方移动平均的衰减率 + :param clip_grad_norm: 梯度裁剪的范数阈值 + + .. note:: + + clip_grad_norm须为正数 + :param clip_grad_value: 梯度裁剪的值域阈值 + :param weight_decay: 权重衰减系数,默认值为0.0 + :param loss_scale: 损失缩放系数,可以用来提高训练精度,但是太大可能会导致nan + """ + + def __init__( + self, + model, + lr=1e-3, + loss_scale=2**10, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + clip_grad_norm=None, + clip_grad_value=None, + weight_decay=0.0, + ): + self.model = model + self.lr = lr + self.clip_grad_norm = clip_grad_norm + self.clip_grad_value = clip_grad_value + self.weight_decay = weight_decay + self.loss_scale = loss_scale + if self.weight_decay > 0.0: + self.do_weight_decay = True + else: + self.do_weight_decay = False + self.eps = eps + self.step_num = 0 + self.decay_rate = decay_rate + self.clip_threshold = clip_threshold + + # for grad norm + if self.clip_grad_norm is not None and self.clip_grad_norm <= 0: + raise ValueError( + f"clip_grad_norm should be positive, got {self.clip_grad_norm}." 
+ ) + self.gather_norm = False + self.grad_norms = [] + self.clip_coef = None + + # check if zero3 is enabled + self.zero3_enabled = is_deepspeed_zero3_enabled() + if self.zero3_enabled: # zero3 is enabled + self.grad_func = self.fuse_update_zero3() + else: + self.grad_func = self.fuse_update() + + self.exp_avg_sq = {} + self.exp_avg_sq_row = {} + self.exp_avg_sq_col = {} + + # register hook function, which will be called through the backward process + for n, p in self.model.named_parameters(): + if len(p.ds_shape) == 1: + self.exp_avg_sq[n] = torch.zeros( + p.ds_shape[0], dtype=torch.float32 + ).cuda() + else: + self.exp_avg_sq_row[n] = torch.zeros( + p.ds_shape[0], dtype=torch.float32 + ).cuda() + self.exp_avg_sq_col[n] = torch.zeros( + p.ds_shape[1], dtype=torch.float32 + ).cuda() + + if p.requires_grad: + p.register_hook(self.grad_func) + defaults = dict( + lr=lr, + eps=eps, + weight_decay=weight_decay, + clip_grad_norm=clip_grad_norm, + clip_grad_value=clip_grad_value, + ) + super(AdaLomo, self).__init__(self.model.parameters(), defaults) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = ( + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def fuse_update(self): + """ + 在非ZeRO模式下更新模型参数的梯度。 + + :return: func,一个闭包函数,用于更新模型参数的梯度 + """ + + def func(x): + """ + 闭包函数,用于更新模型参数的梯度。 + """ + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if p.requires_grad and p.grad is not None: + grad_fp32 = p.grad.to(torch.float32) + p.grad = None + if self.loss_scale: + grad_fp32.div_(self.loss_scale) + if self.gather_norm: + # we adopt two backward pass for gradient norm computation and parameter update, respectively. 
+ self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + else: + # grad clip or norm + if ( + self.clip_grad_value is not None + and self.clip_grad_value > 0 + ): + # Clipping gradients by their value + grad_fp32.clamp_( + min=-self.clip_grad_value, max=self.clip_grad_value + ) + if ( + self.clip_grad_norm is not None + and self.clip_grad_norm > 0 + and self.clip_coef is not None + ): + # Normalize the gradient according to its norm (computed in another pass) + grad_fp32.mul_(self.clip_coef) + + beta2t = 1.0 - math.pow(self.step_num, self.decay_rate) + update = (grad_fp32**2) + self.eps[0] + + if len(p.data.shape) > 1: + self.exp_avg_sq_row[n].mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + self.exp_avg_sq_col[n].mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + update = self._approx_sq_grad( + self.exp_avg_sq_row[n], self.exp_avg_sq_col[n] + ) + update.mul_(grad_fp32) + else: + self.exp_avg_sq[n].mul_(beta2t).add_( + update, alpha=1.0 - beta2t + ) + update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32) + + update.div_( + (self._rms(update) / self.clip_threshold).clamp_( + min=1.0 + ) + ) + + p_fp32 = p.data.to(torch.float32) + p_rms = torch.norm(p_fp32, 2.0) / math.sqrt(p.numel()) + lr = self.lr + param_scale = max(self.eps[1], p_rms) + lr = lr * param_scale + + if self.do_weight_decay: + p_fp32.mul_(1.0 - lr * self.weight_decay) + p_fp32.add_(update, alpha=-lr) + p.data.copy_(p_fp32) + + return x + + return func + + def fuse_update_zero3(self): + """ + 在ZeRO模式下更新模型参数的梯度。 + + :return: func,一个闭包函数,用于更新模型参数的梯度。 + """ + + def func(x): + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if p.grad is not None: + torch.distributed.all_reduce( + p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False + ) + + grad_fp32 = p.grad.to(torch.float32) + p.grad = None + if self.loss_scale: + grad_fp32.div_(self.loss_scale) + + if self.gather_norm: + # we adopt two backward pass for gradient norm computation and parameter update, respectively. 
+ self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + else: # update param + partition_size = p.ds_tensor.numel() + start = partition_size * self.dp_rank + end = min(start + partition_size, grad_fp32.numel()) + + if self.clip_grad_value is not None: + # Clipping gradients by their value + grad_fp32.clamp_( + min=-self.clip_grad_value, max=self.clip_grad_value + ) + if ( + self.clip_grad_norm is not None + and self.clip_grad_norm > 0 + and self.clip_coef is not None + ): + # Normalize the gradient according to its norm (computed in another pass) + grad_fp32.mul_(self.clip_coef) + + beta2t = 1.0 - math.pow(self.step_num, self.decay_rate) + update = (grad_fp32**2) + self.eps[0] # 改成addcmul_ + + if len(p.ds_shape) > 1: + self.exp_avg_sq_row[n].mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + self.exp_avg_sq_col[n].mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + update = self._approx_sq_grad( + self.exp_avg_sq_row[n], self.exp_avg_sq_col[n] + ) + update.mul_(grad_fp32) + else: + self.exp_avg_sq[n].mul_(beta2t).add_( + update, alpha=1.0 - beta2t + ) + update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32) + + update.div_( + (self._rms(update) / self.clip_threshold).clamp_( + min=1.0 + ) + ) + + one_dim_update = update.view(-1) + partitioned_update = one_dim_update.narrow( + 0, start, end - start + ) + param_fp32 = p.ds_tensor.to(torch.float32) + partitioned_p = param_fp32.narrow(0, 0, end - start) + + p_rms = torch.norm(partitioned_p, 2.0) ** 2 + dist.all_reduce(p_rms, op=torch.distributed.ReduceOp.SUM) + p_rms = (p_rms / p.ds_numel).sqrt() + + lr = self.lr + param_scale = max(self.eps[1], p_rms) + lr = lr * param_scale + + if self.do_weight_decay: + partitioned_p.mul_(1.0 - lr * self.weight_decay) + partitioned_p.add_(partitioned_update, alpha=-lr) + p.ds_tensor[: end - start] = partitioned_p + + return x + + return func + + def fused_backward(self, loss, lr): + """ + 执行一步反向传播并更新模型的梯度。 + + :param loss: 模型的loss值 + :param lr: 学习率 + """ + self.lr = lr + if self.loss_scale: + loss = loss * self.loss_scale + self.step_num += 1 + loss.backward() + # update the last parameter since the last parameter in the computaiton graph is not ready when calling hook functions + # the argument of grad_func is just a placeholder, and it can be anything. + self.grad_func(0) + + def grad_norm(self, loss): + """ + 计算梯度的范数。 + + :param loss: 模型的loss值 + """ + self.gather_norm = True + self.grad_norms = [] + if self.loss_scale: + loss = loss * self.loss_scale + loss.backward(retain_graph=True) + # update the last parameter since the last parameter in the computaiton graph is not ready when calling hook functions + # the argument of grad_func is just a placeholder, and it can be anything. + self.grad_func(0) + + with torch.no_grad(): + # The norm is computed over all gradients together, as if they were + # concatenated into a single vector. Gradients are modified in-place. 
+ self.grad_norms = torch.stack(self.grad_norms) + + total_norm = torch.norm(self.grad_norms, 2.0) + self.clip_coef = float(self.clip_grad_norm) / (total_norm + 1e-6) + self.clip_coef = torch.clamp(self.clip_coef, max=1.0) + self.gather_norm = False diff --git a/lomo_optim/lomo.py b/lomo_optim/lomo.py new file mode 100644 index 0000000..c2f877a --- /dev/null +++ b/lomo_optim/lomo.py @@ -0,0 +1,370 @@ +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +try: + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +except ImportError: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +from transformers.utils import logging + +logger = logging.get_logger() + + +class Lomo(Optimizer): + """ + 一个自定义的优化器类Lomo,用于在分布式训练中的梯度更新。 + + 该类实现两个梯度更新函数 :meth:`fuse_update` 和 :meth:`fuse_update_zero3`,分别用于非ZeRO和ZeRO模式下的梯度更新。 + + :param model: 待优化的模型 + :param lr: 学习率,默认值为1e-3 + :param clip_grad_norm: 梯度裁剪的范数阈值 + + .. note:: + + clip_grad_norm须为正数 + :param zero3_enabled: 是否开启了 zero3 + :param clip_grad_value: 梯度裁剪的值域阈值 + :param loss_scale_args: 用于初始化 :class:`DynamicLossScaler` 的参数 + """ + + def __init__( + self, + model, + lr=1e-3, + clip_grad_norm=None, + clip_grad_value=None, + weight_decay=0.0, + loss_scale_args={}, + ): + self.model = model + self.lr = lr + self.clip_grad_norm = clip_grad_norm + self.clip_grad_value = clip_grad_value + self.loss_scaler = None + self.loss_scale_args = loss_scale_args + self.weight_decay = weight_decay + if self.weight_decay > 0.0: + self.do_weight_decay = True + else: + self.do_weight_decay = False + + # for grad norm + if self.clip_grad_norm is not None and self.clip_grad_norm <= 0: + raise ValueError( + f"clip_grad_norm should be positive, got {self.clip_grad_norm}." + ) + self.gather_norm = False + self.grad_norms = [] + self.clip_coef = None + + # check if zero3 is enabled + self.zero3_enabled = is_deepspeed_zero3_enabled() + if self.zero3_enabled: # zero3 is enabled + self.grad_func = self.fuse_update_zero3() + else: + self.grad_func = self.fuse_update() + self.first_backward = True # check bf16 or fp16 in the first backward + + # register hook function, which will be called through the backward process + for n, p in self.model.named_parameters(): + if p.requires_grad: + p.register_hook(self.grad_func) + defaults = dict( + lr=lr, clip_grad_norm=clip_grad_norm, clip_grad_value=clip_grad_value + ) + super(Lomo, self).__init__(self.model.parameters(), defaults) + + def fuse_update(self): + """ + 在非ZeRO模式下更新模型参数的梯度。 + + :return: func,一个闭包函数,用于更新模型参数的梯度 + """ + + def func(x): + """ + 闭包函数,用于更新模型参数的梯度。 + """ + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if p.requires_grad and p.grad is not None: + if self.loss_scaler and ( + self.loss_scaler.has_overflow_serial + or self.loss_scaler._has_inf_or_nan(p.grad) + ): + # if the overflow is detected, drop the gradient + p.grad = None + self.loss_scaler.has_overflow_serial = True + break + grad_fp32 = p.grad.to(torch.float32) + p.grad = None + if self.loss_scaler: + grad_fp32.div_(self.loss_scaler.loss_scale) + if self.gather_norm: + # we adopt two backward pass for gradient norm compuation and parameter update, respectively. 
+ self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + else: + if ( + self.clip_grad_value is not None + and self.clip_grad_value > 0 + ): + # Clipping gradients by their value + grad_fp32.clamp_( + min=-self.clip_grad_value, max=self.clip_grad_value + ) + if ( + self.clip_grad_norm is not None + and self.clip_grad_norm > 0 + and self.clip_coef is not None + ): + # Normalize the gradient according to its norm (computed in another pass) + grad_fp32.mul_(self.clip_coef) + p_fp32 = p.data.to(torch.float32) + if self.do_weight_decay: + p_fp32.mul_(1.0 - self.lr * self.weight_decay) + p_fp32.add_(grad_fp32, alpha=-self.lr) + p.data.copy_(p_fp32) + + return x + + return func + + def fuse_update_zero3(self): + """ + 在ZeRO模式下更新模型参数的梯度。 + + :return: func,一个闭包函数,用于更新模型参数的梯度。 + """ + + def func(x): + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if p.grad is not None: + torch.distributed.all_reduce( + p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False + ) + if self.loss_scaler and ( + self.loss_scaler.has_overflow_serial + or self.loss_scaler._has_inf_or_nan(p.grad) + ): + # if the overflow is detected, drop the gradient + p.grad = None + self.loss_scaler.has_overflow_serial = True + break + + grad_fp32 = p.grad.to(torch.float32) + p.grad = None + param_fp32 = p.ds_tensor.to(torch.float32) + if self.loss_scaler: + grad_fp32.div_(self.loss_scaler.loss_scale) + + if self.gather_norm: + # we adopt two backward pass for gradient norm compuation and parameter update, respectively. + self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + else: # update param + one_dim_grad_fp32 = grad_fp32.view(-1) + partition_size = p.ds_tensor.numel() + start = partition_size * dist.get_rank() + end = min(start + partition_size, grad_fp32.numel()) + partitioned_grad_fp32 = one_dim_grad_fp32.narrow( + 0, start, end - start + ) + + if self.clip_grad_value is not None: + # Clipping gradients by their value + partitioned_grad_fp32.clamp_( + min=-self.clip_grad_value, max=self.clip_grad_value + ) + if ( + self.clip_grad_norm is not None + and self.clip_grad_norm > 0 + and self.clip_coef is not None + ): + # Normalize the gradient according to its norm (computed in another pass) + partitioned_grad_fp32.mul_(self.clip_coef) + + partitioned_p = param_fp32.narrow(0, 0, end - start) + if self.do_weight_decay: + partitioned_p.mul_(1.0 - self.lr * self.weight_decay) + partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr) + p.ds_tensor[: end - start] = partitioned_p + return x + + return func + + def fused_backward(self, loss, lr): + """ + 执行一步反向传播并更新模型的梯度。 + + :param loss: 模型的loss值 + :param lr: 学习率 + """ + if self.first_backward: + self.first_backward = False + if loss.dtype == torch.float16: + self.loss_scaler = DynamicLossScaler(**self.loss_scale_args) + if self.clip_grad_norm is None: + self.clip_grad_norm = 1.0 + logger.warning( + "Loss scale is recommended to be used with grad norm to get better performance. " + "Set grad norm to 1.0." + ) + self.lr = lr + # Users need call grad_norm themselves and then call backward_step + if ( + self.clip_grad_norm is not None + and self.clip_grad_norm > 0 + and self.clip_coef is None + ): + raise ValueError( + "clip_grad_norm is not None, but clip_coef is None. " + "Please call optimizer.grad_norm() before optimizer.fused_backward()." 
+ ) + if self.loss_scaler: + loss = loss * self.loss_scaler.loss_scale + loss.backward() + # update the last parameter since the last parameter in the computaiton graph is not ready when calling hook functions + # the argument of grad_func is just a placeholder, and it can be anything. + self.grad_func(0) + + def grad_norm(self, loss): + """ + 计算梯度的范数。 + + :param loss: 模型的loss值 + """ + if self.first_backward: + self.first_backward = False + if loss.dtype == torch.float16: + self.loss_scaler = DynamicLossScaler(**self.loss_scale_args) + + self.gather_norm = True + self.grad_norms = [] + if self.loss_scaler: + self.loss_scaler.has_overflow_serial = False + loss = loss * self.loss_scaler.loss_scale + loss.backward(retain_graph=True) + # update the last parameter since the last parameter in the computaiton graph is not ready when calling hook functions + # the argument of grad_func is just a placeholder, and it can be anything. + self.grad_func(0) + + if self.loss_scaler and self.loss_scaler.has_overflow_serial: + self.loss_scaler.update_scale(overflow=True) + with torch.no_grad(): # clear gradients + for n, p in self.model.named_parameters(): + p.grad = None + return + + with torch.no_grad(): + # The norm is computed over all gradients together, as if they were + # concatenated into a single vector. Gradients are modified in-place. + self.grad_norms = torch.stack(self.grad_norms) + + total_norm = torch.norm(self.grad_norms, 2.0) + self.clip_coef = float(self.clip_grad_norm) / (total_norm + 1e-6) + self.clip_coef = torch.clamp(self.clip_coef, max=1.0) + self.gather_norm = False + + +class DynamicLossScaler: + """ + 动态loss缩放器,用于在训练过程中动态调整loss的缩放比例。 + + :param init_scale: 初始缩放比例 + :param scale_factor: 缩放因子 + :param scale_window: + :param min_scale: 最小缩放比例,默认为1 + :param delayed_shift: 延迟移位,默认为1 + :param consecutive_hysteresis: 是否启用连续的滞后效应,默认为False。如果是True,在处理梯度溢出时会滞后 :attr:`delayed_shift` 个迭代周期。 + :param raise_error_at_min_scale: 最小缩放比例时是否抛出异常,默认为True + :param dtype: 数据类型,默认为torch.half + """ + + def __init__( + self, + init_scale=2**32, + scale_factor=2.0, + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False, + raise_error_at_min_scale=True, + dtype=torch.half, + ): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + self.raise_error_at_min_scale = raise_error_at_min_scale + self.dtype = dtype + self.has_overflow_serial = False + + @property + def loss_scale(self): + return self.cur_scale + + # `x` is a torch.Tensor + def _has_inf_or_nan(self, x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. 
+ if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum in [float("inf"), -float("inf")] or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: + raise Exception( + "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run." + ) + else: + next_scale = max(self.cur_scale / self.scale_factor, self.min_scale) + if torch.distributed.get_rank() == 0: + overflow_msg = f"[LOMO] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." + if self.dtype == torch.half: + overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}" + print(overflow_msg) + self.cur_scale = next_scale + else: + if torch.distributed.get_rank() == 0: + overflow_msg = f"[LOMO] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." + if self.dtype == torch.half: + overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis - 1}" + print(overflow_msg) + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + if torch.distributed.get_rank() == 0: + hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" + print(hysteresis_msg) + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fda9baa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "lomo-optim" +authors = [ + {name = "Kai Lv", email = "klv21@m.fudan.edu.cn"}, +] +description = "LOMO: LOw-Memory Optimization" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] +dependencies = [ + "torch", "transformers" +] +dynamic = ["version"] + +[project.urls] +Homepage = "https://github.com/OpenLMLab/LOMO" +Documentation = "https://openlmlab-collie.readthedocs.io/zh-cn/latest/api/generated/collie.optim.Lomo.html" +Repository = "https://github.com/OpenLMLab/LOMO.git" + +[tool.setuptools] +packages = ["lomo_optim"] From 51f515043539ef1051170e4abc775802c228303e Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 06:50:02 +0000 Subject: [PATCH 2/8] add installation & ignore egg-info --- .gitignore | 1 + README.md | 17 ++++++++++++++++- README_ZH.md | 17 ++++++++++++++++- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 485dee6..8f9ba3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea +lomo_optim.egg-info diff --git a/README.md b/README.md index a52eb80..0a7919d 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,21 @@ and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv LOMO and AdaLomo are integrated in [CoLLiE](https://github.com/OpenLMLab/collie) library, which supports 
Collaborative Training of Large Language Models in an Efficient Way. +# Usage + +Install `lomo-optim` from PyPI using pip. + +```bash +pip install lomo-optim +``` + +Then, import `Lomo` or `AdaLomo`. + +```python +from lomo_optim import Lomo +from lomo_optim import AdaLomo +``` + # LOMO: LOw-Memory Optimization In this work, we propose a new optimizer, **LO**w-Memory **O**ptimization (**LOMO**), which fuses the gradient computation and the parameter update in one step to reduce memory usage. @@ -38,4 +53,4 @@ The code for AdaLomo is in [adalomo](adalomo) folder. journal={arXiv preprint arXiv:2306.09782}, year={2023} } -``` \ No newline at end of file +``` diff --git a/README_ZH.md b/README_ZH.md index 262a95b..cd095d0 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -6,6 +6,21 @@ LOMO和AdaLomo已经集成到了 [CoLLiE](https://github.com/OpenLMLab/collie) (Collaborative Training of Large Language Models in an Efficient Way) 中。 +# 使用方法 + +使用 pip 从 PyPI 安装 `lomo-optim` 包。 + +```bash +pip install lomo-optim +``` + +然后,从 `lomo_optim` 中导入 `Lomo` 或 `AdaLomo` + +```python +from lomo_optim import Lomo +from lomo_optim import AdaLomo +``` + # LOMO: LOw-Memory Optimization 在这个工作中,我们提出了一个新的优化器,**LO**w-Memory **O**ptimization (**LOMO**),它将梯度计算和参数更新融合在一步中,以减少内存使用。 @@ -38,4 +53,4 @@ AdaLomo的代码在 [adalomo](adalomo) 文件夹中。 journal={arXiv preprint arXiv:2306.09782}, year={2023} } -``` \ No newline at end of file +``` From b60496adc1dbbc98a68b8ce89ed8d5487f18fa69 Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 07:29:05 +0000 Subject: [PATCH 3/8] ignore build & dist --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8f9ba3c..9688ef3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .idea lomo_optim.egg-info +dist +build From 05b636e98e79af6b53d39a83831ab0b4ceeec410 Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 07:29:19 +0000 Subject: [PATCH 4/8] version --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fda9baa..d460492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence" ] dependencies = [ - "torch", "transformers" + "transformers", "torch" ] dynamic = ["version"] @@ -26,3 +26,6 @@ Repository = "https://github.com/OpenLMLab/LOMO.git" [tool.setuptools] packages = ["lomo_optim"] + +[tool.setuptools.dynamic] +version = {attr = "lomo_optim.__version__"} From e845f87e0bd318c755a12a8aece398d3317a9b1a Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 07:33:29 +0000 Subject: [PATCH 5/8] ignore pycache --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9688ef3..c8ba4e7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ lomo_optim.egg-info dist build +__pycache__ From b06a924de7bc2b3c171870f80d86cf1feb8a6c2f Mon Sep 17 00:00:00 2001 From: KYLN24 <1296845690@qq.com> Date: Wed, 6 Mar 2024 07:33:39 +0000 Subject: [PATCH 6/8] add build&publish workflow --- .github/workflows/python-publish.yml | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/python-publish.yml diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..9e74726 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,39 @@ +# This workflow will upload a Python Package using 
Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install the dependencies + run: | + python -m pip install --upgrade pip + pip install build twine + - name: Build and publish + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m build --wheel + twine upload dist/* From cd41adbd8e29c55d8cdb0cfa5d8e3683b1659b02 Mon Sep 17 00:00:00 2001 From: Kai Lv <39761308+KaiLv69@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:25:08 +0800 Subject: [PATCH 7/8] update readme --- README.md | 10 +++++----- README_ZH.md | 6 ++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 0a7919d..f516a40 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,8 @@ This is the implementation for [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/pdf/2310.10195.pdf). -LOMO and AdaLomo are integrated in [CoLLiE](https://github.com/OpenLMLab/collie) library, which supports Collaborative Training of Large Language Models in an Efficient Way. - -# Usage - -Install `lomo-optim` from PyPI using pip. +LOMO and AdaLomo are integrated in [CoLLiE](https://github.com/OpenMOSS/collie) library, which supports Collaborative Training of Large Language Models in an Efficient Way. +You can also install `lomo-optim` from PyPI using pip. ```bash pip install lomo-optim @@ -20,6 +17,9 @@ from lomo_optim import Lomo from lomo_optim import AdaLomo ``` +The usage of `Lomo` and `AdaLomo` is similar but not the same as PyTorch's optimizers +([example](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469)). + # LOMO: LOw-Memory Optimization In this work, we propose a new optimizer, **LO**w-Memory **O**ptimization (**LOMO**), which fuses the gradient computation and the parameter update in one step to reduce memory usage. diff --git a/README_ZH.md b/README_ZH.md index cd095d0..30afcf8 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -5,10 +5,7 @@ 论文 [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf) 和 [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/pdf/2310.10195.pdf) 的实现. 
LOMO和AdaLomo已经集成到了 [CoLLiE](https://github.com/OpenLMLab/collie) (Collaborative Training of Large Language Models in an Efficient Way) 中。 - -# 使用方法 - -使用 pip 从 PyPI 安装 `lomo-optim` 包。 +也可以使用 pip 从 PyPI 安装 `lomo-optim` 包。 ```bash pip install lomo-optim @@ -20,6 +17,7 @@ pip install lomo-optim from lomo_optim import Lomo from lomo_optim import AdaLomo ``` +`Lomo`和`AdaLomo`的使用方法与PyTorch的优化器类似,但不完全相同([示例](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469))。 # LOMO: LOw-Memory Optimization From 4a5c12fccafd548b0eab681699e6df772e743fe6 Mon Sep 17 00:00:00 2001 From: Kai Lv <39761308+KaiLv69@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:28:10 +0800 Subject: [PATCH 8/8] update readme --- README.md | 1 + README_ZH.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index f516a40..0042dfc 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ from lomo_optim import AdaLomo The usage of `Lomo` and `AdaLomo` is similar but not the same as PyTorch's optimizers ([example](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469)). +We recommend to use `AdaLomo` without `gradnorm` to get better performance and higher throughput. # LOMO: LOw-Memory Optimization diff --git a/README_ZH.md b/README_ZH.md index 30afcf8..cae486a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -18,6 +18,7 @@ from lomo_optim import Lomo from lomo_optim import AdaLomo ``` `Lomo`和`AdaLomo`的使用方法与PyTorch的优化器类似,但不完全相同([示例](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469))。 +推荐使用`AdaLomo`并且不加`gradnorm`来获得更好的性能同时维持更高的吞吐量。 # LOMO: LOw-Memory Optimization
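Both READMEs note that `Lomo` and `AdaLomo` are similar to, but not drop-in replacements for, PyTorch optimizers, and point to the CoLLiE trainer for canonical usage. For reference, the sketch below shows the two-pass pattern that `Lomo` enforces when `clip_grad_norm` is set: `grad_norm()` gathers the gradient norm in one backward pass, then `fused_backward()` applies the fused update in a second. This is a minimal illustration only; the checkpoint name, toy batch, step count, and learning rate are placeholders and are not part of this patch.

```python
# Minimal usage sketch for the optimizers added by this patch (non-ZeRO path).
# The "gpt2" checkpoint, toy batch, and fixed learning rate are illustrative
# assumptions; see the CoLLiE trainer linked in the README for real usage.
from transformers import AutoModelForCausalLM, AutoTokenizer

from lomo_optim import Lomo  # AdaLomo is driven the same way

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

lr = 1e-3
# Lomo registers backward hooks on every trainable parameter at construction,
# so the optimizer must be created after the model is fully built.
optimizer = Lomo(model, lr=lr, clip_grad_norm=1.0)

batch = tokenizer("LOMO fuses gradient computation and the parameter update.",
                  return_tensors="pt")

for step in range(10):
    loss = model(**batch, labels=batch["input_ids"]).loss
    # Two backward passes: grad_norm() collects the global gradient norm
    # (backward with retain_graph=True); fused_backward() then re-runs backward
    # and lets the hooks apply the clipped update to the parameters in place.
    optimizer.grad_norm(loss)
    optimizer.fused_backward(loss, lr)
```

When `clip_grad_norm` is left as `None`, the `grad_norm()` call can be dropped and `fused_backward()` alone performs the update; this is the higher-throughput configuration the English README recommends for `AdaLomo`.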