fix: TF 2.0~2.3 compatibility issues
Wybxc committed Oct 25, 2022
1 parent d3b571d commit 6b7c39b
Showing 3 changed files with 147 additions and 108 deletions.
9 changes: 8 additions & 1 deletion .gitignore
@@ -160,4 +160,11 @@ cython_debug/
#.idea/

.vscode/
*.h5
*.h5

# Packing Files
/python/*
!/python/README.md
/python-*/*
*.tar
*.zip
213 changes: 113 additions & 100 deletions backend.py
@@ -1,5 +1,5 @@
import collections
import gc
import contextlib
import os
import time
from pathlib import Path
@@ -15,21 +15,37 @@
os.environ["TF_KERAS"] = "1" if tf_version_tuple >= (2, 4, 0) else "0"
from sznlp.my_bert4keras.backend import keras

if tf_version_tuple >= (2, 0, 0):
    tf.compat.v1.disable_eager_execution()
if os.environ["TF_KERAS"] != "1" and tf_version_tuple >= (2, 0, 0):
    tf = tf.compat.v1
    tf.disable_v2_behavior()
    keras.backend.get_session = tf.Session


@contextlib.contextmanager
def use_graph(graph=None, session=None):
    if os.environ["TF_KERAS"] == "1":
        # TF 2 (tf.keras): no static graph is used
        yield None, None
    else:
        if graph is None:
            graph = tf.get_default_graph()
        if session is None:
            session = keras.backend.get_session()
        with graph.as_default():
            with session.as_default():
                yield graph, session

    def get_session():
        return tf.Session()

else:

    def get_session():
        return keras.backend.get_session()
def set_scope():
    if os.environ["TF_KERAS"] != "1":
        backend = getattr(keras.backend, "tensorflow_backend", None)
        scope = getattr(backend, "_SYMBOLIC_SCOPE", None)
        if scope:
            scope.value = True


# Detect the GPU type
device_type = tf.test.gpu_device_name()
device_type = str(tf.test.gpu_device_name())
is_gpu_avaiable = bool(device_type)
device_type = device_type.split(":")[1] if is_gpu_avaiable else "CPU"

@@ -40,7 +56,7 @@ def get_session():
        tf.config.experimental.set_memory_growth(gpu, True)

Model = collections.namedtuple(
    "Model", ["path", "config", "vocab", "model", "graph", "session"]
    "Model", ["path", "config", "vocab", "model", "graph", "session", "cpu_mode"]
)
model_paths = []

@@ -53,12 +69,15 @@ def refresh_models():
refresh_models()


def load_model(model_path, model, config_path, vocab_path):
def load_model(
    model_path, model, config_path, vocab_path, cpu_mode=not is_gpu_avaiable
):
    if (
        model
        and model.path == model_path
        and model.config == config_path
        and model.vocab == vocab_path
        and model.cpu_mode == cpu_mode
    ):
        return model

@@ -68,62 +87,61 @@ def load_model(model_path, model, config_path, vocab_path):
print(f"Loading model from {model_path}")
print(f"GPU available: {is_gpu_avaiable}")

    graph = tf.get_default_graph()
    sess = get_session()
    with sess.as_default():
        with graph.as_default():
            tokenizer = Tokenizer(vocab_path, do_lower_case=True)
            if is_gpu_avaiable:
                from sznlp.misaka_models import Misaka
                from sznlp.tools import seq2seq_Generate

                misaka = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka),
                    with_lm=True,
                    return_keras_model=False,
                )

                misaka.model.load_weights(model_path, by_name=True)
                encoder = misaka.encoder
                decoder = misaka.decoder
                outputs = [
                    keras.layers.Lambda(lambda x: x[:, -1:])(output)
                    for output in decoder.outputs
                ]
                decoder = keras.models.Model(decoder.inputs, outputs)

                seq2seq = seq2seq_Generate(encoder, decoder, tokenizer)
            else:
                from sznlp.cache_predict import (
                    Misaka_decoder_cache,
                    Misaka_encoder,
                    Seq2SeqGenerate_Cache,
                )

                decoder = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka_decoder_cache),
                    with_lm=True,
                    return_keras_model=True,
                )

                encoder = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka_encoder),
                    with_lm=True,
                    return_keras_model=True,
                )

                decoder.load_weights(model_path, by_name=True)
                encoder.load_weights(model_path, by_name=True)

                seq2seq = Seq2SeqGenerate_Cache(
                    encoder, decoder, tokenizer, skip_token="氼"
                )

    model = Model(model_path, config_path, vocab_path, seq2seq, graph, sess)
    print("Model loaded. ")
    set_scope()

    with use_graph() as (graph, session):
        tokenizer = Tokenizer(vocab_path, do_lower_case=True)
        if not cpu_mode:
            from sznlp.misaka_models import Misaka
            from sznlp.tools import seq2seq_Generate

            misaka = build_transformer_model(
                config_path=config_path,
                model=cast(Any, Misaka),
                with_lm=True,
                return_keras_model=False,
            )

            misaka.model.load_weights(model_path, by_name=True)
            encoder = misaka.encoder
            decoder = misaka.decoder
            outputs = [
                keras.layers.Lambda(lambda x: x[:, -1:])(output)
                for output in decoder.outputs
            ]
            decoder = keras.models.Model(decoder.inputs, outputs)

            seq2seq = seq2seq_Generate(encoder, decoder, tokenizer)
        else:
            from sznlp.cache_predict import (
                Misaka_decoder_cache,
                Misaka_encoder,
                Seq2SeqGenerate_Cache,
            )

            decoder = build_transformer_model(
                config_path=config_path,
                model=cast(Any, Misaka_decoder_cache),
                with_lm=True,
                return_keras_model=True,
            )

            encoder = build_transformer_model(
                config_path=config_path,
                model=cast(Any, Misaka_encoder),
                with_lm=True,
                return_keras_model=True,
            )

            decoder.load_weights(model_path, by_name=True)
            encoder.load_weights(model_path, by_name=True)

            seq2seq = Seq2SeqGenerate_Cache(encoder, decoder, tokenizer, skip_token="氼")

    model = Model(
        model_path, config_path, vocab_path, seq2seq, graph, session, cpu_mode
    )
    print("Model loaded. ")
    return model


@@ -136,44 +154,39 @@ def generate(
    batch_size=32,
    repeat_punish=0.99,
    step_callback=None,
    cpu_mode=not is_gpu_avaiable,
):
    if not model:
        return ["模型加载中,请稍候..."], 0.0

    if (2, 0, 0) <= tf_version_tuple < (2, 4, 0):
        keras.backend.tensorflow_backend._SYMBOLIC_SCOPE.value = True  # type: ignore

    start_time = time.time()

    with model.session.as_default():
        with model.graph.as_default():
            if is_gpu_avaiable:
                result = model.model.writer(
                    [text.replace("\n", "氼")],  # the input text, newlines encoded as "氼"
                    nums=nums,  # how many texts to generate per prompt
                    k=topp,  # search window
                    batch_size=batch_size,
                    max_len=max_len,  # maximum length
                    iter_data_num=400,  # how many prompts to process at once
                    mode="topp",
                    iter_max_num=0,  # retries for repeated sentences (larger = slower, fewer repeats)
                    step_callback=step_callback,
                )  # checks for repeated decoding
            else:
                result = model.model.writer(
                    [text.replace("\n", "氼")],  # the input text, newlines encoded as "氼"
                    nums=nums,  # how many texts to generate
                    k=topp,
                    batch_size=batch_size,
                    max_len=max_len,
                    repeat_punish=repeat_punish,
                    step_callback=step_callback,
                )  # checks for repeated decoding
    set_scope()

    with use_graph(model.graph, model.session):
        if not cpu_mode:
            result = model.model.writer(
                [text.replace("\n", "氼")],  # the input text, newlines encoded as "氼"
                nums=nums,  # how many texts to generate per prompt
                k=topp,  # search window
                batch_size=batch_size,
                max_len=max_len,  # maximum length
                iter_data_num=400,  # how many prompts to process at once
                mode="topp",
                iter_max_num=0,  # retries for repeated sentences (larger = slower, fewer repeats)
                step_callback=step_callback,
            )  # checks for repeated decoding
        else:
            result = model.model.writer(
                [text.replace("\n", "氼")],  # the input text, newlines encoded as "氼"
                nums=nums,  # how many texts to generate
                k=topp,
                batch_size=batch_size,
                max_len=max_len,
                repeat_punish=repeat_punish,
                step_callback=step_callback,
            )  # checks for repeated decoding
    result = [s.replace("\n", "\n\n") for s in result]
    generated = ["\n\n".join(result[i].split("氼")) for i in range(nums)]
    time_consumed = time.time() - start_time
    return generated, time_consumed


def cleanup_memory():
    keras.backend.clear_session()
    gc.collect()
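
For context: on TF 2.0~2.3 the standalone Keras backend is selected, and the commit falls back to TF1-style graph/session management, which `use_graph` wraps. A minimal sketch of that pattern, independent of this repository (the tiny matmul graph is a stand-in for the real Misaka encoder/decoder):

```python
import numpy as np
import tensorflow as tf

# With standalone Keras, the commit disables v2 behavior and works with an
# explicit graph plus an explicit session, TF1-style.
tf.compat.v1.disable_v2_behavior()

graph = tf.compat.v1.get_default_graph()
session = tf.compat.v1.Session(graph=graph)

# Build the graph once inside the graph/session contexts (load_model's job).
with graph.as_default(), session.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 4))
    w = tf.compat.v1.get_variable("w", shape=(4, 2))
    y = tf.matmul(x, w)  # toy stand-in for the model
    session.run(tf.compat.v1.global_variables_initializer())

# Later calls must re-enter the *same* graph/session pair, which is what
# use_graph(model.graph, model.session) guarantees in generate(). Under
# tf.keras (TF >= 2.4) use_graph yields (None, None) and eager mode is used.
with graph.as_default(), session.as_default():
    out = session.run(y, feed_dict={x: np.ones((1, 4), np.float32)})
print(out.shape)  # (1, 2)
```

The `_SYMBOLIC_SCOPE` flag that `set_scope` sets appears to be a thread-local in standalone Keras 2.2/2.3; Streamlit runs the script in a worker thread rather than the thread that imported Keras, so the flag has to be re-set there before any graph work, which is why both `load_model` and `generate` call `set_scope()` first.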
33 changes: 26 additions & 7 deletions webui.py
@@ -7,7 +7,6 @@
import streamlit as st

from backend import (
    cleanup_memory,
    device_type,
    generate,
    is_gpu_avaiable,
@@ -42,6 +41,7 @@ def init_state(name, default=None):
init_state("current_model", None)
init_state("outputs", [])
init_state("time_consumed", 0)
init_state("cpu_mode", not is_gpu_avaiable)

with st.sidebar:
    model_path = st.selectbox(
@@ -55,7 +55,6 @@ def init_state(name, default=None):
        refresh_models()
    if right.button("重新加载模型"):
        st.session_state["current_model"] = None
        cleanup_memory()

    config_path = st.text_input("配置文件:", value="models/config-misaka.json")
    vocab_path = st.text_input("词表:", value="models/vocab-misaka.txt")
@@ -78,7 +77,15 @@ def init_state(name, default=None):
    batch_size = int(
        st.number_input("批大小(batch size):", min_value=1, max_value=256, value=64)
    )
    if not is_gpu_avaiable:
    cpu_mode = st.checkbox(
        "启用针对 CPU 的优化",
        value=st.session_state["cpu_mode"],
        help="基于缓存的优化,可以提升 CPU 的生成速度,对结果有一定影响。",
    )
    if cpu_mode != st.session_state["cpu_mode"]:
        st.session_state["cpu_mode"] = cpu_mode
        st.session_state["current_model"] = None
    if st.session_state["cpu_mode"]:
        repeat_punish = st.number_input(
            "重复惩罚:",
            min_value=0.0,
@@ -114,10 +121,20 @@ def init_state(name, default=None):
    f'<div style="text-align: right">当前字数: {len(text)}</div>', unsafe_allow_html=True
)

model = st.session_state["current_model"]
with st.spinner("加载模型中..."):
    model = load_model(model_path, model, config_path, vocab_path)
    st.session_state["current_model"] = model
if model_path:
    model = st.session_state["current_model"]
    with st.spinner("加载模型中..."):
        model = load_model(
            model_path,
            model,
            config_path,
            vocab_path,
            cpu_mode=st.session_state["cpu_mode"],
        )
        st.session_state["current_model"] = model
else:
    st.warning("未找到模型,请将模型放在 models 文件夹下。")
    st.stop()

start_generate = left.button("生成")

@@ -175,7 +192,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                sys.stderr.write(f"\r[ nums:{nums} length:{n}]"),
                pbar.update(),
            ),
            cpu_mode=st.session_state["cpu_mode"],
        )
        sys.stderr.write("\n")
        st.session_state["outputs"] = outputs
        st.session_state["time_consumed"] = time_consumed

