ONNX exporter (#12242)
* ONNX exporter and tutorial for llama embedding

Signed-off-by: Onur Yilmaz <[email protected]>

* Quantization added

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* reorg the onnx exporter

Signed-off-by: Onur Yilmaz <[email protected]>

* Fixing the issue with dimension input

Signed-off-by: Onur Yilmaz <[email protected]>

* Moved code from notebook to nemo

Signed-off-by: Onur Yilmaz <[email protected]>

* Replaced prints with logger

Signed-off-by: Onur Yilmaz <[email protected]>

* docstrings added

Signed-off-by: Onur Yilmaz <[email protected]>

* Fixing small flake8 issue

Signed-off-by: Onur Yilmaz <[email protected]>

* Remove ptq for nw

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* Addressing the feedback

Signed-off-by: Onur Yilmaz <[email protected]>

* Added trt flag for version compatible trt engine

Signed-off-by: Onur Yilmaz <[email protected]>

* Update notebook for version compatible trt engine

Signed-off-by: Onur Yilmaz <[email protected]>

* move trt import

Signed-off-by: Onur Yilmaz <[email protected]>

* style fix

Signed-off-by: Onur Yilmaz <[email protected]>

---------

Signed-off-by: Onur Yilmaz <[email protected]>
Signed-off-by: oyilmaz-nvidia <[email protected]>
Signed-off-by: Onur Yilmaz <[email protected]>
Co-authored-by: oyilmaz-nvidia <[email protected]>
oyilmaz-nvidia authored Mar 4, 2025
1 parent 8773f99 commit 110f80b
Showing 6 changed files with 866 additions and 3 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
100644 → 100755
@@ -47,6 +47,7 @@
    Gemma2Model,
)
from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM
from nemo.collections.llm.gpt.model.hf_llama_embedding import get_llama_bidirectional_hf_model
from nemo.collections.llm.gpt.model.hyena import (
    Hyena1bConfig,
    Hyena7bARCLongContextConfig,
@@ -222,6 +223,7 @@
    "transformer_engine_full_layer_spec",
    "local_layer_spec",
    "HFAutoModelForCausalLM",
    "get_llama_bidirectional_hf_model",
    "HyenaTestConfig",
    "Hyena1bConfig",
    "HyenaNV1bConfig",
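
The hunk above re-exports the new helper through the collection's public namespace (note the added `__all__` entry). As a minimal illustration of the resulting import path, assuming a NeMo installation that includes this commit:

from nemo.collections.llm.gpt.model import get_llama_bidirectional_hf_model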
140 changes: 139 additions & 1 deletion nemo/collections/llm/gpt/model/hf_llama_embedding.py
100644 → 100755
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union
import os
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoModel, AutoTokenizer
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
@@ -188,3 +191,138 @@ def forward(
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class LlamaBidirectionalHFAdapter(torch.nn.Module):
    """Wraps a Text embedding model with pooling and normalization."""

    def __init__(
        self,
        model: torch.nn.Module,
        normalize: bool,
        pooling_module: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.model = model
        self.normalize = normalize
        self.pooling_module = pooling_module

    @property
    def device(self) -> torch.device:
        """Returns the device"""

        return self.model.device

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        dimensions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Inference for the adapted Llama model"""

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if token_type_ids is not None:
            inputs["token_type_ids"] = token_type_ids
        outputs = self.model(**inputs)
        hidden_states = outputs["last_hidden_state"].to(torch.float32)
        embeddings = self.pooling_module(hidden_states, inputs["attention_mask"])

        if dimensions is not None:
            if not torch.all(dimensions > 0):
                raise ValueError("Dimensions must be positive")

            fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device)

            clipped_dimensions = torch.where(
                dimensions < embeddings.shape[1],
                dimensions,
                torch.tensor(embeddings.shape[1], device=embeddings.device),
            )

            embeddings = embeddings.masked_fill(
                torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1),
                fill_value,
            )[:, : dimensions.max()]

        if self.normalize:
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings


class Pooling(torch.nn.Module):
    """Pooling layer for the adapter."""

    def __init__(self, pooling_mode: str):
        super().__init__()
        self.pooling_mode = pooling_mode

    def forward(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Forward function of the Pooling layer."""

        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

        pool_type = self.pooling_mode
        if pool_type == "avg":
            epsilon = 1e-9  # A small value to avoid division by zero
            emb = last_hidden.sum(dim=1) / (attention_mask.sum(dim=1)[..., None] + epsilon)
        elif pool_type == "cls":  # tokenizer padding right
            emb = last_hidden[:, 0]
        elif pool_type == "cls__left":  # tokenizer padding left
            seq_idxs = (1 - attention_mask).sum(dim=1)
            batch_size = last_hidden.shape[0]
            batch_idxs = torch.arange(batch_size, device=last_hidden.device)
            emb = last_hidden[batch_idxs, seq_idxs]
        elif pool_type == "last":  # tokenizer padding left
            emb = last_hidden[:, -1]
        elif pool_type == "last__right":  # tokenizer padding right
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
        else:
            raise ValueError(f"pool_type {pool_type} not supported")

        return emb


def get_llama_bidirectional_hf_model(
    model_name_or_path: Union[str, os.PathLike[str]],
    normalize: bool,
    pooling_mode: Optional[Literal["avg", "cls", "last"]] = None,
    torch_dtype: Optional[Union[torch.dtype, str]] = None,
    trust_remote_code: bool = False,
):
    """Returns the adapter for the Llama bidirectional HF model."""

    # check that the tokenizer matches the requirements of the pooling mode
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
    pooling_mode = pooling_mode or "avg"
    if pooling_mode == "last" and tokenizer.padding_side == "right":
        pooling_mode = "last__right"  # type: ignore
    if pooling_mode == "cls" and tokenizer.padding_side == "left":
        pooling_mode = "cls__left"  # type: ignore

    # load the model
    model = AutoModel.from_pretrained(
        model_name_or_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code
    ).eval()

    # configure pooling
    pooling_module = Pooling(pooling_mode=pooling_mode)

    # NV-Embed-v1 model has a separate embedding model and a built-in pooling module
    if (
        model.__class__.__name__ == "NVEmbedModel"
        and hasattr(model, "latent_attention_model")
        and hasattr(model, "embedding_model")
    ):
        pooling_module = model.latent_attention_model
        model = model.embedding_model

    adapted_model = LlamaBidirectionalHFAdapter(model=model, normalize=normalize, pooling_module=pooling_module)
    return adapted_model, tokenizer
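
For orientation, a rough usage sketch follows. It is not the ONNX exporter added by this commit (that code lives in the PR's other changed files, which are not shown above); it only illustrates, with a placeholder checkpoint path and output file name, how the adapter returned by get_llama_bidirectional_hf_model might be traced with the stock torch.onnx.export API.

import torch

from nemo.collections.llm.gpt.model import get_llama_bidirectional_hf_model

# Placeholder checkpoint path; substitute a real Llama-based embedding model.
model, tokenizer = get_llama_bidirectional_hf_model(
    "path/to/llama-embedding-checkpoint",
    normalize=True,
    pooling_mode="avg",
)

# Build a small example batch to trace the adapter.
batch = tokenizer(["hello world"], return_tensors="pt")
example_inputs = (batch["input_ids"], batch["attention_mask"])

# Export with dynamic batch and sequence dimensions.
torch.onnx.export(
    model,
    example_inputs,
    "llama_embedding.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["embeddings"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "embeddings": {0: "batch"},
    },
)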