NIM supporting changes for nemo.export for NeMo 2.0 (part II) (#11669)
* Remove trt_compile from __init__ as it triggers imports from nemo.utils

Signed-off-by: Jan Lasek <[email protected]>

* Get tokenizer for NeMo 2 from model.yaml using local SP or HF classes

Signed-off-by: Jan Lasek <[email protected]>

* Apply isort and black reformatting

Signed-off-by: janekl <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>
2 people authored and BoxiangW committed Dec 23, 2024
1 parent 054bd46 commit fc54cee
Showing 2 changed files with 52 additions and 7 deletions.
2 changes: 0 additions & 2 deletions nemo/export/__init__.py
@@ -11,5 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from nemo.export.tensorrt_lazy_compiler import trt_compile
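
With the re-export removed, code that previously relied on "from nemo.export import trt_compile" would now import the compiler from its defining module instead. A minimal sketch (the import is exactly the removed line above; the comment is editorial):

# Importing from the module directly avoids the nemo.utils imports that
# loading nemo/export/__init__.py used to trigger.
from nemo.export.tensorrt_lazy_compiler import trt_compile
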
57 changes: 52 additions & 5 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -38,6 +38,13 @@
 from nemo.export.tarutils import TarPath, ZarrPathStore
 from nemo.export.tiktoken_tokenizer import TiktokenTokenizer

+try:
+    from nemo.lightning import io
+
+    HAVE_NEMO2 = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_NEMO2 = False
+
 LOGGER = logging.getLogger("NeMo")


@@ -289,14 +296,54 @@ def copy_tokenizer_files(config, out_dir):
     return config


+def get_tokenizer_from_nemo2_context(model_context_dir: Path):
+    """
+    Retrieve tokenizer configuration from NeMo 2.0 context and instantiate the tokenizer.
+
+    Args:
+        model_context_dir (Path): Path to the model context directory.
+
+    Returns:
+        The instantiated tokenizer (various classes possible).
+    """
+
+    if HAVE_NEMO2:
+        # Use NeMo tokenizer loaded from the NeMo 2.0 model context
+        tokenizer_spec = io.load_context(model_context_dir, subpath="model.tokenizer")
+        return build_tokenizer(tokenizer_spec)
+    else:
+        # Use local nemo.export SentencePieceTokenizer implementation
+        # or directly a HuggingFace tokenizer based on the model config
+        with (model_context_dir / "model.yaml").open("r") as stream:
+            model_config = yaml.safe_load(stream)
+
+        tokenizer_config = model_config["tokenizer"]
+        target_class = tokenizer_config["_target_"]
+        tokenizer_module = "nemo.collections.common.tokenizers."
+        assert target_class.startswith(tokenizer_module)
+        target_class = target_class.removeprefix(tokenizer_module)
+
+        if target_class == "sentencepiece_tokenizer.SentencePieceTokenizer":
+            tokenizer = SentencePieceTokenizer(
+                model_path=str(model_context_dir / tokenizer_config["model_path"]),
+                special_tokens=tokenizer_config.get("special_tokens", None),
+                legacy=tokenizer_config.get("legacy", False),
+            )
+        elif target_class == "huggingface.auto_tokenizer.AutoTokenizer":
+            tokenizer = AutoTokenizer.from_pretrained(
+                str(model_context_dir / tokenizer_config["pretrained_model_name"])
+            )
+        else:
+            raise ValueError(f"Unsupported tokenizer type: {tokenizer_module}{target_class}.")
+
+        return tokenizer
+
+
 def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer:
     """Loads the tokenizer from the decoded NeMo weights dir."""
     tokenizer_dir_or_path = Path(tokenizer_dir_or_path)
     if (tokenizer_dir_or_path / "nemo_context").exists():
-        from nemo.lightning import io
-
-        tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer")
-        return build_tokenizer(tokenizer_spec)
+        return get_tokenizer_from_nemo2_context(tokenizer_dir_or_path / "nemo_context")
     elif os.path.exists(os.path.join(tokenizer_dir_or_path, "vocab.json")):
         vocab_path = tokenizer_dir_or_path / "vocab.json" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
         tokenizer_config = {"library": "tiktoken", "vocab_file": str(vocab_path)}
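
For context, a minimal usage sketch of the updated loader (the checkpoint path is hypothetical; the function names come from the diff above). When a nemo_context subdirectory exists, get_tokenizer now delegates to get_tokenizer_from_nemo2_context, which uses io.load_context when nemo_toolkit>=2.0.0 is installed and otherwise falls back to parsing model.yaml with the local SentencePieceTokenizer or a HuggingFace AutoTokenizer:

from pathlib import Path

from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenizer

# Hypothetical path to an unpacked NeMo 2.0 checkpoint that contains a
# "nemo_context" subdirectory with model.yaml and the tokenizer assets.
checkpoint_dir = Path("/tmp/extracted_nemo2_ckpt")

# Dispatches to get_tokenizer_from_nemo2_context(checkpoint_dir / "nemo_context").
tokenizer = get_tokenizer(checkpoint_dir)
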
@@ -476,7 +523,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
                 elif k == "activation_func":
                     nemo_model_config["activation"] = v["_target_"].rsplit('.', 1)[-1]
         else:
-            from nemo.lightning import io
+            assert HAVE_NEMO2, "nemo_toolkit>=2.0.0 is required to load the model context."

             config = io.load_context(io_folder, subpath="model.config")
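
The change above replaces a local "from nemo.lightning import io" with a check on the module-level HAVE_NEMO2 flag, so a missing nemo.lightning now fails early with a descriptive AssertionError instead of an ImportError inside load_nemo_model. An illustrative, standalone sketch of the pattern (the function name is hypothetical, not from the commit):

# Feature-detect the optional dependency once, at import time.
try:
    from nemo.lightning import io

    HAVE_NEMO2 = True
except (ImportError, ModuleNotFoundError):
    HAVE_NEMO2 = False


def load_model_context(io_folder):
    # Fail with a clear message at the point where the dependency is needed.
    assert HAVE_NEMO2, "nemo_toolkit>=2.0.0 is required to load the model context."
    return io.load_context(io_folder, subpath="model.config")
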
