-
-
Notifications
You must be signed in to change notification settings - Fork 704
/
Copy pathAPI_DEMO_CHAT.py
133 lines (99 loc) · 4.91 KB
/
API_DEMO_CHAT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# Startup banner, printed before the imports below run.
print("RWKV Chat Simple Demo")
import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
# PyTorch backend switches: turn on the cuDNN benchmark flag and allow TF32
# for both cuDNN and CUDA matmul.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# RWKV options are passed through environment variables. They are set here,
# ahead of the `rwkv` imports below (presumably they are read at import time --
# keep this ordering; TODO confirm against the rwkv package).
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
########################################################################################################
# Model configuration: execution strategy and checkpoint path handed to RWKV().
args = types.SimpleNamespace()
args.strategy = "cuda fp16"  # use CUDA, fp16
args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6"
########################################################################################################
# Initial state selection.
# STATE_NAME = None  # uncomment to use the vanilla zero initial state
# A custom state gives much better chat results
# (download from https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main).
# Note: this one is an English single-round QA state (it will forget what you previously said).
# The ".pth" extension is appended when the file is loaded.
STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"
########################################################################################################
# Sampling parameters (defaults, used with the zero initial state).
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5   # flat penalty subtracted from the logit of every token already generated
GEN_alpha_frequency = 0.5  # extra penalty per (decayed) occurrence of that token
GEN_penalty_decay = 0.996  # occurrence counts are multiplied by this after every generated token
if STATE_NAME is not None:
    # Different sampling settings are used when a custom state is loaded.
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
CHUNK_LEN = 256  # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)
########################################################################################################
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

# Running conversation: every token fed to or sampled from the model so far,
# and the recurrent state that model.forward() consumes and returns.
model_tokens = []
model_state = None  # stays None unless a custom state is loaded below

if STATE_NAME is not None:  # load custom state
    # NOTE: this rebinds the module-level `args` to the model's own args object;
    # the configuration namespace built earlier is no longer reachable as `args`.
    args = model.args
    # NOTE(security): torch.load unpickles the file -- only load state files
    # that come from a trusted source.
    state_raw = torch.load(STATE_NAME + '.pth')
    # Three state slots per layer: slots 0 and 2 start as zero vectors of size
    # n_embd, slot 1 is filled from the trained `time_state` of that layer.
    state_init = [None] * (args.n_layer * 3)
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    # Work on a copy so `state_init` keeps the pristine loaded state.
    model_state = copy.deepcopy(state_init)
def run_rnn(ctx):
    """Feed the text ``ctx`` through the model and advance the global state.

    The text is encoded with ``pipeline``, appended to the global
    ``model_tokens`` history, and pushed through ``model.forward`` in slices
    of at most ``CHUNK_LEN`` tokens, updating the global ``model_state``
    after every slice.

    Args:
        ctx: Text to append to the conversation. ``"\\r\\n"`` is normalised to
            ``"\\n"`` before encoding.

    Returns:
        The output of the last ``model.forward`` call (callers use it as the
        next-token logits), or ``None`` when ``ctx`` encodes to no tokens, in
        which case the model state is left untouched.
    """
    global model_tokens, model_state

    ctx = ctx.replace("\r\n", "\n")

    tokens = [int(x) for x in pipeline.encode(ctx)]
    model_tokens += tokens

    # print(f"### model ###\n{model_tokens}\n[{pipeline.decode(model_tokens)}]") # debug

    # Pre-bind the result so an empty token list returns None instead of
    # raising UnboundLocalError at the `return`.
    out = None
    # Step through the tokens chunk by chunk rather than re-slicing the
    # remaining list on every iteration.
    for start in range(0, len(tokens), CHUNK_LEN):
        out, model_state = model.forward(tokens[start:start + CHUNK_LEN], model_state)

    return out
if STATE_NAME is None:  # use initial prompt if we are not loading a state
    # Prime the model with one fixed exchange so it starts in the
    # "User: ... / Assistant: ..." chat format, then show it to the user.
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    run_rnn(init_ctx)
    print(init_ctx, end="")
# Interactive chat loop: read one user message, feed it to the model, then
# sample and stream the assistant's reply until it emits a blank line.
while True:
    try:
        msg = prompt("User: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: leave the chat instead of crashing
        # with a traceback.
        break

    # Trim the message and collapse runs of newlines, since "\n\n" is the
    # turn separator in the prompt format.
    msg = re.sub(r"\n+", "\n", msg.strip())
    if not msg:
        print("!!! Error: please say something !!!")
        continue

    occurrence = {}  # token id -> decayed count of how often it was generated in this reply
    out_tokens = []  # tokens generated for this reply
    out_last = 0     # index into out_tokens of the first token not yet printed

    out = run_rnn("User: " + msg + "\n\nAssistant:")
    print("\nAssistant:", end="")

    for i in range(99999):
        for n in occurrence:
            out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # repetition penalty
        out[0] -= 1e10  # disable END_OF_TEXT

        token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)

        out, model_state = model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        # Decay all existing counts, then count the token just generated.
        for xxx in occurrence:
            occurrence[xxx] *= GEN_penalty_decay
        occurrence[token] = 1 + occurrence.get(token, 0)

        tmp = pipeline.decode(out_tokens[out_last:])
        # Hold text back while it is an incomplete UTF-8 sequence or ends with
        # "\n" (the next token decides whether that newline ends the reply).
        printable = ("\ufffd" not in tmp) and (not tmp.endswith("\n"))
        # A blank line marks the end of the assistant's turn.
        finished = "\n\n" in tmp
        if printable or finished:
            # Print each pending chunk exactly once, even when it is both
            # printable and the final one.
            print(tmp, end="", flush=True)
            out_last = i + 1
        if finished:
            break