Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hankcs committed Dec 29, 2021
1 parent 1cbcdfe commit e8044b2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 4 additions & 2 deletions hanlp/common/keras_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tensorflow as tf

import hanlp.utils
from hanlp_common.io import save_json,load_json
from hanlp_common.io import save_json, load_json
from hanlp.callbacks.fine_csv_logger import FineCSVLogger
from hanlp.common.component import Component
from hanlp.common.transform_tf import Transform
Expand Down Expand Up @@ -255,7 +255,8 @@ def build_optimizer(self, optimizer, **kwargs):
if isinstance(optimizer, (str, dict)):
custom_objects = {'AdamWeightDecay': AdamWeightDecay}
optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(optimizer,
module_objects=vars(tf.keras.optimizers),
module_objects=vars(
tf.keras.optimizers),
custom_objects=custom_objects)
self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
return optimizer
Expand Down Expand Up @@ -437,6 +438,7 @@ def predict(self, data: Any, batch_size=None, **kwargs):
for output in self.predict_batch(batch, inputs=inputs, **kwargs):
results.append(output)
num_samples += samples_in_batch
self.transform.cleanup()

if flat:
return results[0]
Expand Down
15 changes: 14 additions & 1 deletion hanlp/common/transform_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kw
self.output_types = None
self.output_shapes = None
self.padding_values = None
# Fix tf memory leak: https://github.com/tensorflow/tensorflow/issues/37653#issuecomment-1000517720
self.py_func_set_to_cleanup = set()

@abstractmethod
def fit(self, trn_path: str, **kwargs) -> int:
Expand Down Expand Up @@ -170,6 +172,9 @@ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_s
padding_values]), 'Your create_types_shapes_values returns None, which is not allowed'
# if not callable(samples):
# samples = Transform.generator_to_callable(samples)
if not hasattr(tf.compat.v1.get_default_graph(), '_py_funcs_used_in_graph'):
tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = []
py_func_set_before = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph)
dataset = tf.data.Dataset.from_generator(samples, output_types=output_types, output_shapes=output_shapes)
if cache:
logger.debug('Dataset cache enabled')
Expand Down Expand Up @@ -197,6 +202,8 @@ def mapper(X, Y):
return X, Y

dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
py_func_set_after = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph) - py_func_set_before
self.py_func_set_to_cleanup |= py_func_set_after
return dataset

@abstractmethod
Expand Down Expand Up @@ -237,7 +244,8 @@ def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
return [repr(x) for x in X]

def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
batch=None) -> Iterable:
return [repr(y) for y in Y]

def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]],
Expand Down Expand Up @@ -295,3 +303,8 @@ def input_truth_output_to_str(self, input, truth, output):
"""
return '\t'.join([input, truth, output]) + '\n'

def cleanup(self):
new_py_funcs = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph) - self.py_func_set_to_cleanup
tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = list(new_py_funcs)
self.py_func_set_to_cleanup = set()

0 comments on commit e8044b2

Please sign in to comment.