Skip to content

Commit

Permalink
Make Texar-TF compatible with Texar-PyTorch (#183)
Browse files Browse the repository at this point in the history
* Lazily load Texar-TF modules so these and TensorFlow will not be loaded when user only imports texar.torch.
* Fix cyclic dependencies between texar.hyperparams and texar.utils.
* Fix library imports that are not compatible with lazy loading.
  • Loading branch information
ZhitingHu authored Jul 29, 2019
2 parents 71d3a9e + 4705fa9 commit 4e5deef
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 22 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

### New features

### Feature improvements

* Use lazy import to be compatible with [texar-pytorch](https://github.com/asyml/texar-pytorch). ([#183](https://github.com/asyml/texar/pull/183))

### Fixes

## [v0.2.1](https://github.com/asyml/texar/releases/tag/v0.2.1) (2019-07-28)

### New features

* Add support for GPT-2 345M model in [examples/gpt-2](https://github.com/asyml/texar/tree/master/examples/gpt-2). ([#156](https://github.com/asyml/texar/pull/156))
* Add BERT modules, including `texar.modules.BERTEncoder` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#texar.modules.BertEncoder)) and `texar.modules.BERTClassifier` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#bertclassifier)). ([#167](https://github.com/asyml/texar/pull/167))

Expand Down
109 changes: 95 additions & 14 deletions texar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,98 @@

# pylint: disable=wildcard-import

from texar.version import VERSION as __version__

from texar.module_base import *
from texar.hyperparams import *
from texar.context import *
from texar import modules
from texar import core
from texar import losses
from texar import models
from texar import data
from texar import evals
from texar import agents
from texar import run
from texar import utils
import sys

if sys.version_info.major < 3:
    # PY 2.x: import everything eagerly, because Texar-PyTorch cannot be
    # installed under Python 2, so there is no `texar.torch` to protect from
    # an eager TensorFlow import.
    from texar.version import VERSION as __version__

    from texar.module_base import *
    from texar.hyperparams import *
    from texar.context import *
    from texar import modules
    from texar import core
    from texar import losses
    from texar import models
    from texar import data
    from texar import evals
    from texar import agents
    from texar import run
    from texar import utils
else:
    # Lazily load Texar-TF modules upon usage. This is to ensure that Texar-TF
    # and TensorFlow will not be imported if the user only requires
    # Texar-PyTorch modules from `texar.torch`.
    #
    # Due to the lazy loading mechanism, it is now impossible to write
    # `from texar import <module>` within library code (i.e., code that will be
    # accessible from the `texar` module). Please use the following workarounds
    # instead:
    #
    # 1. To import a class / function that is directly accessible from `texar`,
    #    import them from their containing modules. For instance:
    #
    #      `from texar import HParams`
    #          -> `from texar.hyperparams import HParams`
    #      `from texar import ModuleBase`
    #          -> `from texar.module_base import ModuleBase`
    # 2. To import a module that is directly accessible from `texar`, use the
    #    `import ... as` syntax. For instance:
    #
    #      `from texar import utils` -> `import texar.utils as utils`
    #      `from texar import context` -> `import texar.context as context`

    import importlib

    # Sub-modules re-exported as plain attributes of `texar`
    # (`texar.modules`, `texar.core`, ...).
    __import_modules__ = [
        "modules", "core", "losses", "models", "data", "evals",
        "agents", "run", "utils",
    ]
    # Modules whose public names are star-imported into the `texar` namespace.
    __import_star_modules__ = ["module_base", "hyperparams", "context"]

    def _import_all():
        """Eagerly import every Texar-TF sub-module into this namespace.

        Invoked the first time an attribute other than ``torch`` is looked up
        on the dummy module, i.e. when the user actually uses Texar-TF.
        Mutates ``globals()`` of the real `texar` module in place.
        """
        from texar.version import VERSION
        globals()["__version__"] = VERSION

        for module_name in __import_star_modules__:
            # Emulate `from ... import *`: honor `__all__` when the module
            # defines it, otherwise fall back to all non-underscore names.
            module = importlib.import_module("." + module_name, package="texar")
            try:
                variables = module.__all__
            except AttributeError:
                variables = [name for name in module.__dict__
                             if not name.startswith("_")]
            globals().update({
                name: module.__dict__[name] for name in variables})

        for module_name in __import_modules__:
            # Emulate `from ... import module`.
            module = importlib.import_module("." + module_name, package="texar")
            globals()[module_name] = module

    class _DummyTexarBaseModule:
        # Stand-in object registered as the `texar` module so that attribute
        # access can be intercepted and the heavy Texar-TF / TensorFlow
        # imports deferred until actually needed.
        # Credit: https://stackoverflow.com/a/7668273/4909228
        def __getattr__(self, name):
            if name in globals():
                # Shortcut to global names.
                return globals()[name]
            if name == "torch":
                # To use `texar.torch`, Texar-TF and TensorFlow should not be
                # imported.
                module = importlib.import_module(".torch", package="texar")
                globals()["torch"] = module
                return module

            # The user tries to access Texar-TF modules, so we load all modules
            # at this point, and restore the registered `texar` module.
            sys.modules[__name__] = __module__
            _import_all()
            return globals()[name]

    # Save the real `texar` module as `__module__`, and replace the
    # system-wide registered module with our dummy module.
    __module__ = sys.modules[__name__]
    sys.modules[__name__] = _DummyTexarBaseModule()
3 changes: 2 additions & 1 deletion texar/hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import copy
import json

from texar.utils.dtypes import is_callable

__all__ = [
"HParams"
Expand Down Expand Up @@ -212,6 +211,8 @@ def _parse(hparams, # pylint: disable=too-many-branches, too-many-statements
else:
parsed_hparams[name] = HParams(value, value)

from texar.utils.dtypes import is_callable

# Parse hparams
for name, value in hparams.items():
if name not in default_hparams:
Expand Down
2 changes: 1 addition & 1 deletion texar/models/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import division
from __future__ import print_function

from texar import HParams
from texar.hyperparams import HParams

# pylint: disable=too-many-arguments

Expand Down
2 changes: 1 addition & 1 deletion texar/models/seq2seq/seq2seq_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from texar.losses.mle_losses import sequence_sparse_softmax_cross_entropy
from texar.data.data.paired_text_data import PairedTextData
from texar.core.optimization import get_train_op
from texar import HParams
from texar.hyperparams import HParams
from texar.utils import utils
from texar.utils.variables import collect_trainable_variables

Expand Down
2 changes: 1 addition & 1 deletion texar/modules/berts/berts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import division
from __future__ import print_function

from texar import ModuleBase
from texar.module_base import ModuleBase
from texar.modules.berts import bert_utils

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion texar/modules/encoders/bert_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from texar.core import layers
from texar.modules.encoders.transformer_encoders import TransformerEncoder
from texar.modules.embedders import WordEmbedder, PositionEmbedder
from texar import HParams
from texar.hyperparams import HParams
from texar.modules.berts import BertBase, bert_utils
from texar.modules.encoders import EncoderBase

Expand Down
2 changes: 1 addition & 1 deletion texar/modules/encoders/transformer_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from texar.modules.encoders.encoder_base import EncoderBase
from texar.modules.encoders.multihead_attention import MultiheadAttentionEncoder
from texar.modules.networks.networks import FeedForwardNetwork
from texar import utils
import texar.utils as utils
from texar.utils.shapes import shape_list
from texar.utils.mode import is_train_mode

Expand Down
2 changes: 1 addition & 1 deletion texar/utils/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import tensorflow as tf

from texar import context
import texar.context as context

__all__ = [
"maybe_global_mode",
Expand Down
1 change: 0 additions & 1 deletion texar/utils/transformer_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import tensorflow as tf

from texar import context

# pylint: disable=too-many-arguments, invalid-name, no-member

Expand Down

0 comments on commit 4e5deef

Please sign in to comment.