Commit 1e3c07b "feat: webui"

Wybxc committed Oct 24, 2022, 1 parent 8b76351
Showing 9 changed files with 571 additions and 249 deletions.
75 changes: 48 additions & 27 deletions README.md

## AI Novel Continuation

A novel-continuation model based on an encoder-decoder architecture; for an introduction, see the [V1 README](https://github.com/pass-lin/misaka-writer/blob/main/README.md).

Compared with V1, V2 upgrades the model as follows:

| | V1 -> V2 |
|--|--|
| Parameter count | 80M -> 200M |
| Model depth | 8+8 -> 14+14 |
| Block type | MHA -> GAU |
| Languages | Chinese + English -> Chinese only |
| Environment | tf.keras -> keras (tf.keras is simply too slow, so we use the faster keras) |

Finally, since many users only have a CPU, a simple cache optimization was written for CPU inference, so CPU users can also get a result within a minute.

## Dependencies

The project's dependencies are: `tensorflow`, `numpy`, `pandas`, `sklearn`.

If you use a GPU, install CUDA and cuDNN.

The recommended configuration is tensorflow 2.2.0 / tensorflow 1.15, cuda 10.1, cudnn 7.6, keras 2.3.1.

For 30-series GPUs, which do not support CUDA 10, nvidia-tensorflow or tensorflow-directml is recommended; alternatively, set the environment variable `TF_KERAS` to 1 to run on newer versions of tensorflow.
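
For example, a minimal sketch (the variable must be set before the model code is imported; note that `backend.py` in this commit also sets this variable itself based on the detected tensorflow version):

```sh
export TF_KERAS=1
python generate.py
```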

### Setup with conda

For tensorflow 2.2.0:

```sh
conda create -n misaka-writer python=3.8
conda activate misaka-writer
conda install -c conda-forge pandas cudatoolkit=10.1 cudnn
pip install tensorflow==2.2.0 keras==2.3.1 sklearn
```

For tensorflow 1.15 (Linux only):

```sh
conda create -n misaka-writer python=3.8
conda activate misaka-writer
conda install -c conda-forge pandas cudatoolkit=10.1 cudnn
pip install tensorflow-gpu==1.15.0 keras==2.3.1 sklearn
```

For tensorflow-directml (Windows 10/11 or WSL only; this version does not require CUDA):

```sh
conda create -n misaka-writer python=3.7
conda activate misaka-writer
conda install -c conda-forge pandas
pip install tensorflow-directml keras==2.3.1 sklearn
```

## Usage

`generate.py` basically uses the V1 optimizer, with a simple cache optimization added for the CPU.

`model_path` is the path to the model weights; a relative path is recommended.

`num` is the number of continuations to generate. `text` is the input prompt; between 20 and 250 characters is recommended.
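
As a sketch, these knobs map directly onto `backend.py` from this commit (the paths below are examples; point them at your own weights, config, and vocabulary):

```python
from backend import generate, load_model

# Example paths; adjust to your checkout
model = load_model(
    "models/古言.h5", None, "models/config-misaka.json", "models/vocab-misaka.txt"
)
outputs, seconds = generate(model, "你的开头文本", nums=2, max_len=512, topp=0.8)
for piece in outputs:
    print(piece, "\n")
print(f"Done in {seconds:.2f}s")
```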
## WebUI

The WebUI is launched with `streamlit`; see `webui.py`:

```sh
pip install "streamlit>=1.10.0"  # quoted so the shell does not treat >= as a redirect
streamlit run webui.py
```
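
The actual `webui.py` ships with this commit; purely as a hypothetical illustration of the pattern, a streamlit frontend over `backend.py` can be as small as:

```python
# Hypothetical minimal UI, not the webui.py shipped in this commit.
import streamlit as st

from backend import generate, load_model

st.title("misaka-writer")
prompt = st.text_area("开头 (prompt)", height=200)
if st.button("生成 (generate)"):
    # Example paths; backend.refresh_models() can list *.h5 files instead
    model = load_model(
        "models/古言.h5", None, "models/config-misaka.json", "models/vocab-misaka.txt"
    )
    outputs, seconds = generate(model, prompt, nums=1, max_len=512)
    st.write(outputs[0])
    st.caption(f"Generated in {seconds:.2f}s")
```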

## Training Corpus

The training corpus is 100 GB of Chinese text:

> Link: https://pan.baidu.com/s/1WCiPA_tplI0AhdpDEuQ5ig <br/>
> Extraction code: rlse

## Pretrained Weights

The weights are currently distributed via the QQ groups; join a group to download them.

## Community

For questions, join QQ group 143626394 (the big group, which also hosts the https://github.com/BlinkDL/AI-Writer project) or 905398734 (the smaller group dedicated to this project). My own QQ is 935499957.

---

As always, a Misaka picture anchors the post:

![image](https://user-images.githubusercontent.com/62837036/170024801-1d10d8c5-266f-4ade-894c-67f30069f94f.png)
179 changes: 179 additions & 0 deletions backend.py
import collections
import gc
import os
import time
from pathlib import Path
from typing import Any, cast

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

import tensorflow as tf

# Tell sznlp's bert4keras whether to use tf.keras (TF 2.4+) or standalone keras
tf_version = tf.version.VERSION
tf_version_tuple = tuple(map(int, tf_version.split(".")))
os.environ["TF_KERAS"] = "1" if tf_version_tuple >= (2, 4, 0) else "0"
from sznlp.my_bert4keras.backend import keras

if tf_version_tuple >= (2, 0, 0):
    tf.compat.v1.disable_eager_execution()
    tf = tf.compat.v1

    def get_session():
        return tf.Session()

else:

    def get_session():
        return keras.backend.get_session()


# Detect the GPU type
device_type = tf.test.gpu_device_name()
is_gpu_avaiable = bool(device_type)
device_type = device_type.split(":")[1] if is_gpu_avaiable else "CPU"

# Allow memory growth on every GPU so TensorFlow does not grab all memory up front
if is_gpu_avaiable:
    gpus = tf.config.experimental.list_physical_devices(device_type="GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

# Everything needed to run generation with one set of weights
Model = collections.namedtuple(
    "Model", ["path", "config", "vocab", "model", "graph", "session"]
)
model_paths = []


def refresh_models():
    global model_paths
    model_paths = [file.as_posix() for file in Path.cwd().rglob("*.h5")]


refresh_models()


def load_model(model_path, model, config_path, vocab_path):
    # Reuse the already-loaded model if all three paths are unchanged
    if (
        model
        and model.path == model_path
        and model.config == config_path
        and model.vocab == vocab_path
    ):
        return model

    from sznlp.my_bert4keras.models import build_transformer_model
    from sznlp.my_bert4keras.tokenizers import Tokenizer

    print(f"Loading model from {model_path}")
    print(f"GPU available: {is_gpu_avaiable}")

    graph = tf.get_default_graph()
    sess = get_session()
    with sess.as_default():
        with graph.as_default():
            tokenizer = Tokenizer(vocab_path, do_lower_case=True)
            if is_gpu_avaiable:
                # GPU path: full encoder-decoder, decoding whole sequences
                from sznlp.misaka_models import Misaka
                from sznlp.tools import seq2seq_Generate

                misaka = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka),
                    with_lm=True,
                    return_keras_model=False,
                )

                misaka.model.load_weights(model_path, by_name=True)
                encoder = misaka.encoder
                decoder = misaka.decoder
                # Keep only the final step of each decoder output
                outputs = [
                    keras.layers.Lambda(lambda x: x[:, -1:])(output)
                    for output in decoder.outputs
                ]
                decoder = keras.models.Model(decoder.inputs, outputs)

                seq2seq = seq2seq_Generate(encoder, decoder, tokenizer)
            else:
                # CPU path: cache-optimized incremental decoder
                from sznlp.cache_predict import (
                    Misaka_decoder_cache,
                    Misaka_encoder,
                    Seq2SeqGenerate_Cache,
                )

                decoder = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka_decoder_cache),
                    with_lm=True,
                    return_keras_model=True,
                )

                encoder = build_transformer_model(
                    config_path=config_path,
                    model=cast(Any, Misaka_encoder),
                    with_lm=True,
                    return_keras_model=True,
                )

                decoder.load_weights(model_path, by_name=True)
                encoder.load_weights(model_path, by_name=True)

                seq2seq = Seq2SeqGenerate_Cache(
                    encoder, decoder, tokenizer, skip_token="氼"
                )

    model = Model(model_path, config_path, vocab_path, seq2seq, graph, sess)
    print("Model loaded.")
    return model


def generate(
    model,
    text,
    nums,
    max_len,
    topp=0.8,
    batch_size=32,
    repeat_punish=0.99,
    step_callback=None,
):
    if not model:
        return ["模型加载中,请稍候..."], 0.0  # "Model is loading, please wait..."

    if (2, 0, 0) <= tf_version_tuple < (2, 4, 0):
        keras.backend.tensorflow_backend._SYMBOLIC_SCOPE.value = True  # type: ignore

    start_time = time.time()

    with model.session.as_default():
        with model.graph.as_default():
            if is_gpu_avaiable:
                result = model.model.writer(
                    [text.replace("\n", "氼")],  # input text, newlines encoded as the skip token
                    nums=nums,  # how many continuations per prompt
                    k=topp,  # top-p sampling threshold
                    batch_size=batch_size,
                    max_len=max_len,  # maximum generated length
                    iter_data_num=400,  # how many prompts to process at once
                    mode="topp",  # sampling mode
                    iter_max_num=0,  # re-decode passes for repeated sentences (larger = slower, fewer repeats)
                    step_callback=step_callback,
                )  # decode while checking for repetition
            else:
                result = model.model.writer(
                    [text.replace("\n", "氼")],  # input text, newlines encoded as the skip token
                    nums=nums,  # how many continuations to generate
                    k=topp,
                    batch_size=batch_size,
                    max_len=max_len,
                    repeat_punish=repeat_punish,
                    step_callback=step_callback,
                )  # decode while checking for repetition
    generated = ["\n\n".join(result[i].split("氼")) for i in range(nums)]
    time_consumed = time.time() - start_time
    return generated, time_consumed


def cleanup_memory():
    keras.backend.clear_session()
    gc.collect()
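
A usage note: the path check at the top of `load_model` together with `cleanup_memory` suggests the intended pattern for listing and switching weights at runtime. A minimal sketch, assuming at least one `*.h5` file is found and example config/vocab paths:

```python
import backend

# Rescan the working tree for *.h5 weight files
backend.refresh_models()
print(backend.model_paths)

# Passing the current model back in makes load_model a no-op
# as long as all three paths are unchanged
model = backend.load_model(
    backend.model_paths[0], None, "models/config-misaka.json", "models/vocab-misaka.txt"
)
model = backend.load_model(
    backend.model_paths[0], model, "models/config-misaka.json", "models/vocab-misaka.txt"
)

# Release the Keras session before loading a different set of weights
backend.cleanup_memory()
```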
Binary file added favicon.png
75 changes: 75 additions & 0 deletions generate.py
# -*- coding: utf-8 -*-
import sys
import textwrap

from backend import generate, load_model

try:
    from alive_progress import alive_bar
except ImportError:
    alive_bar = None

if __name__ == "__main__":
    model_path = "models/古言.h5"  # model weights path
    nums = 1  # number of continuations to generate for the prompt
    max_len = 512  # maximum generated length
    topp = 0.8  # top-p sampling probability
    batch_size = 32  # batch size
    # Sample prompt (kept in Chinese; the model is Chinese-only);
    # 50 to 200 characters is recommended
    text = """
    白月耸了耸肩膀,无语。黎傲然不再说话,继续闭上眼睛养起神来。
    马车慢慢的驶出了城,在城外宽阔的大道上前行着。
    “咦?”凌言看着窗外越来越僻静的小路,觉察出了不对劲。
    “怎么了?”白月不解的看着凌言。
    “似乎方向不对啊。”凌言将头探出窗外,大声冲车夫道,“师傅,你是不是走错路了?”
    车夫却丝毫不理会凌言的话,反而扬起鞭子抽了马一鞭子,将马车赶的更快了。
    """
    output = "out.txt"  # output file name

    # Load the model
    model = load_model(
        model_path, None, "models/config-misaka.json", "models/vocab-misaka.txt"
    )
    # Generate
    text = textwrap.dedent(text)

    if alive_bar is not None:
        elapsed_total = max_len * (len(text) // 400 + 1) + 10
        with alive_bar(total=elapsed_total, title="Generating", dual_line=True) as bar:
            outputs, time_consumed = generate(
                model,
                text,
                nums,
                max_len,
                topp=topp,
                batch_size=batch_size,
                step_callback=lambda nums, _: (
                    bar.text(f"Remaining nums: {nums}"),
                    bar(),
                ),
            )
            while bar.current() < elapsed_total:
                bar()

    else:
        outputs, time_consumed = generate(
            model,
            text,
            nums,
            max_len,
            topp=topp,
            batch_size=batch_size,
            step_callback=lambda nums, n: sys.stderr.write(
                f"\r[ nums:{nums} length:{n}]"
            ),
        )

    sys.stderr.write(f"Finished in {time_consumed:.2f}s.\n")

    # Write the output
    with open(output, "w", encoding="utf-8") as f:
        for _ in range(nums):
            f.write(textwrap.indent(text, "\t") + "\n")
        for output in outputs:
            f.write(textwrap.indent(output, "\t") + "\n")
            f.write("\n" + "*" * 80 + "\n")