Skip to content

Commit

Permalink
Update run_classifier.py
Browse files Browse the repository at this point in the history
fixed code
  • Loading branch information
linwhitehat authored Mar 29, 2022
1 parent 30c29aa commit c7ddbbc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions fine-tuning/run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def forward(self, src, tgt, seg, soft_tgt=None):
"""
# Embedding.
emb = self.embedding(src, seg)
- # Encoder. output输出是[bz,seq_len,768],[cls]的embedding应取[bz,1,768]
+ # Encoder.
output = self.encoder(emb, seg)
temp_output = output
# Target.
Expand Down Expand Up @@ -157,7 +157,7 @@ def read_dataset(args, path):
seg = [1] * len(src_a) + [2] * len(src_b)

if len(src) > args.seq_length:
- src = src[: args.seq_length] # 截取文本的前512
+ src = src[: args.seq_length]
seg = seg[: args.seq_length]
while len(src) < args.seq_length:
src.append(0)
Expand Down Expand Up @@ -285,7 +285,7 @@ def main():
model = model.to(args.device)

# Training phase.
- trainset = read_dataset(args, args.train_path) # trainset是一个list,包含[src,tgt,seg],src是样本内容,tgt是标签,seg是用于标记不同样本的符号(当前样本是1,则补0和其他样本为0)
+ trainset = read_dataset(args, args.train_path)
random.shuffle(trainset)
instances_num = len(trainset)
batch_size = args.batch_size
Expand Down

0 comments on commit c7ddbbc

Please sign in to comment.