diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index 708d66307dd3..ce58e1d7fdbb 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -218,7 +218,7 @@ You may find FastConformer variants of cache-aware streaming models under ``/scripts/export.py`` is being used: -`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True` +`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --export-config cache_support=True` .. _LSTM-Transducer_model: @@ -299,7 +299,7 @@ Similar example configs for FastConformer variants of Hybrid models can be found Note Hybrid models are being exported as RNNT (encoder and decoder+joint parts) by default. To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export. Or, if ``/scripts/export.py`` is being used: -`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc` +`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --export-config decoder_type=ctc` .. _Conformer-HAT_model: diff --git a/docs/source/core/export.rst b/docs/source/core/export.rst index f54daffe9c9c..202099b13d66 100644 --- a/docs/source/core/export.rst +++ b/docs/source/core/export.rst @@ -207,7 +207,7 @@ An example can be found in ``/nemo/collections/asr/models/rnnt_mo Here is example on now `set_export_config()` call is being tied to command line arguments in ``/scripts/export.py`` : .. code-block:: Python - python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc + python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --export-config decoder_type=ctc Exportable Model Code ~~~~~~~~~~~~~~~~~~~~~ diff --git a/nemo/collections/tts/models/base.py b/nemo/collections/tts/models/base.py index 8ef147b9b145..fe19ae75a3b3 100644 --- a/nemo/collections/tts/models/base.py +++ b/nemo/collections/tts/models/base.py @@ -68,6 +68,18 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': list_of_models.extend(subclass_models) return list_of_models + def set_export_config(self, args): + for k in ['enable_volume', 'enable_ragged_batches']: + if k in args: + self.export_config[k] = bool(args[k]) + args.pop(k) + if 'num_speakers' in args: + self.export_config['num_speakers'] = int(args['num_speakers']) + args.pop('num_speakers') + if 'emb_range' in args: + raise Exception('embedding range is not user-settable') + super().set_export_config(args) + class Vocoder(ModelPT, ABC): """ diff --git a/scripts/export.py b/scripts/export.py index 4b21bc4ffd73..8fa44bb305f9 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -63,7 +63,7 @@ def get_args(argv): parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") parser.add_argument( - "--config", + "--export-config", metavar="KEY=VALUE", nargs='+', help="Set a number of key-value pairs to model.export_config dictionary " @@ -142,8 +142,14 @@ def nemo_export(argv): if args.cache_support: model.set_export_config({"cache_support": "True"}) - if args.config: - kv = dict(map(lambda s: s.split('='), args.config)) + if args.export_config: + kv = {} + for key_value in args.export_config: + lst = key_value.split("=") + if len(lst) != 2: + raise Exception("Use correct format for --export_config: k=v") + k, v = lst + kv[k] = v model.set_export_config(kv) autocast = nullcontext diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 05b23e6afb1b..67f016b0c2af 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -54,8 +54,7 @@ def radtts_model(): model = RadTTSModel(cfg=cfg.model) app_state.is_model_being_restored = False model.eval() - model.export_config['enable_ragged_batches'] = True - model.export_config['enable_volume'] = True + model.set_export_config({'enable_ragged_batches': 'True', 'enable_volume': 'True'}) return model