This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Patch on world logging #3674

Merged
merged 2 commits on Jun 4, 2021
8 changes: 8 additions & 0 deletions parlai/utils/world_logging.py
@@ -16,6 +16,7 @@
 from parlai.utils.conversations import Conversations
 from parlai.utils.io import PathManager
 import parlai.utils.logging as logging
+from parlai.core.message import Message
 
 import copy
 from tqdm import tqdm
@@ -74,12 +75,19 @@ def _add_msgs(self, acts, idx=0):
         """
         msgs = []
         for act in acts:
+            # padding examples in the episode[0]
+            if not isinstance(act, Message):
+                act = Message(act)
+            if act.is_padding():
+                break
             if not self.keep_all:
                 msg = {f: act[f] for f in self.keep_fields if f in act}
             else:
                 msg = act
             msgs.append(msg)
 
+        if len(msgs) == 0:
+            return
         self._current_episodes.setdefault(idx, [])
         self._current_episodes[idx].append(msgs)
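Note on the change above: the new guard wraps raw dict acts in Message and stops collecting as soon as a padding act is seen, so filler examples appended to round out a batch are never written to the world log, and the final length check keeps an all-padding slot from adding an empty episode. Below is a minimal standalone sketch of the same filtering pattern; the FakeMessage class, the batch_padding flag, and collect_msgs are illustrative stand-ins, not ParlAI's actual Message internals.

from typing import Any, Dict, List


class FakeMessage(dict):
    """Illustrative stand-in for parlai.core.message.Message."""

    def is_padding(self) -> bool:
        # Assumption: padding acts are marked with a 'batch_padding' flag.
        return bool(self.get('batch_padding'))


def collect_msgs(acts: List[Dict[str, Any]], keep_fields=('id', 'text')) -> List[Dict[str, Any]]:
    """Mirror the patched loop: wrap each act, stop at padding, keep selected fields."""
    msgs = []
    for act in acts:
        if not isinstance(act, FakeMessage):
            act = FakeMessage(act)
        if act.is_padding():
            break  # everything after the first padding act is batch filler
        msgs.append({f: act[f] for f in keep_fields if f in act})
    return msgs


if __name__ == '__main__':
    acts = [
        {'id': 'teacher', 'text': 'hello'},
        {'id': 'teacher', 'text': 'world'},
        {'batch_padding': True},  # filler act used to pad the batch
    ]
    print(collect_msgs(acts))  # only the two real acts are kept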
10 changes: 8 additions & 2 deletions tests/test_eval_model.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+from parlai.utils.io import PathManager
 import pytest
 import unittest
 import parlai.utils.testing as testing_utils
@@ -209,18 +210,23 @@ def test_save_report(self):
         Test that we can save report from eval model.
         """
         with testing_utils.tempdir() as tmpdir:
+            log_report = os.path.join(tmpdir, 'world_logs.jsonl')
             save_report = os.path.join(tmpdir, 'report')
             opt = dict(
                 task='integration_tests',
                 model='repeat_label',
                 datatype='valid',
-                num_examples=5,
+                batchsize=97,
                 display_examples=False,
-                world_logs=save_report,
+                world_logs=log_report,
                 report_filename=save_report,
             )
             valid, test = testing_utils.eval_model(opt)
 
+            with PathManager.open(log_report) as f:
+                json_lines = f.readlines()
+            assert len(json_lines) == 100
+
 
 if __name__ == '__main__':
     unittest.main()
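For reference, the new assertion only counts lines in the world-log file: each line of the .jsonl file holds one logged conversation as JSON, so with this patch only real (non-padding) episodes should be written. A minimal sketch for inspecting such a file outside the test follows; the path is illustrative and the exact per-line schema depends on ParlAI's Conversations format.

import json

# Illustrative path; in the test above the file is created inside a tempdir.
log_path = 'world_logs.jsonl'

with open(log_path) as f:
    episodes = [json.loads(line) for line in f if line.strip()]

print(f'{len(episodes)} logged episodes')
if episodes:
    # Each entry is a dict; inspect the keys rather than assuming a fixed schema.
    print(sorted(episodes[0].keys()))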