-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
text classification bug fix & support ernie m #3184
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,9 +142,8 @@ def evaluate(): | |
probs = [] | ||
labels = [] | ||
for batch in train_data_loader: | ||
input_ids, token_type_ids, label = batch['input_ids'], batch[ | ||
'token_type_ids'], batch['labels'] | ||
logits = model(input_ids, token_type_ids) | ||
label = batch.pop("labels") | ||
logits = model(**batch) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不展开的目的是什么了? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为了可以适配erniem模型,erniem模型tokenizer得到的数据和模型输入都没有 |
||
labels.extend(label.numpy()) | ||
probs.extend(F.sigmoid(logits).numpy()) | ||
probs = np.array(probs) | ||
|
@@ -158,9 +157,8 @@ def evaluate(): | |
probs = [] | ||
labels = [] | ||
for batch in dev_data_loader: | ||
input_ids, token_type_ids, label = batch['input_ids'], batch[ | ||
'token_type_ids'], batch['labels'] | ||
logits = model(input_ids, token_type_ids) | ||
label = batch.pop("labels") | ||
logits = model(**batch) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
labels.extend(label.numpy()) | ||
probs.extend(F.sigmoid(logits).numpy()) | ||
probs = np.array(probs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的logging steps对于大多数cpu用户来说,是不是太大了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是因为trainer里默认的参数就是100,但我在训练时命令行的参数还是设置
--logging_steps 5