Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Patch on world logging (#3674)
Browse files Browse the repository at this point in the history
* patch on world logging

* comments
  • Loading branch information
Jing authored Jun 4, 2021
1 parent c9acadd commit a162ee6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
8 changes: 8 additions & 0 deletions parlai/utils/world_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parlai.utils.conversations import Conversations
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.message import Message

import copy
from tqdm import tqdm
Expand Down Expand Up @@ -74,12 +75,19 @@ def _add_msgs(self, acts, idx=0):
"""
msgs = []
for act in acts:
# padding examples in the episode[0]
if not isinstance(act, Message):
act = Message(act)
if act.is_padding():
break
if not self.keep_all:
msg = {f: act[f] for f in self.keep_fields if f in act}
else:
msg = act
msgs.append(msg)

if len(msgs) == 0:
return
self._current_episodes.setdefault(idx, [])
self._current_episodes[idx].append(msgs)

Expand Down
10 changes: 8 additions & 2 deletions tests/test_eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from parlai.utils.io import PathManager
import pytest
import unittest
import parlai.utils.testing as testing_utils
Expand Down Expand Up @@ -209,18 +210,23 @@ def test_save_report(self):
Test that we can save report from eval model.
"""
with testing_utils.tempdir() as tmpdir:
log_report = os.path.join(tmpdir, 'world_logs.jsonl')
save_report = os.path.join(tmpdir, 'report')
opt = dict(
task='integration_tests',
model='repeat_label',
datatype='valid',
num_examples=5,
batchsize=97,
display_examples=False,
world_logs=save_report,
world_logs=log_report,
report_filename=save_report,
)
valid, test = testing_utils.eval_model(opt)

with PathManager.open(log_report) as f:
json_lines = f.readlines()
assert len(json_lines) == 100


if __name__ == '__main__':
unittest.main()

0 comments on commit a162ee6

Please sign in to comment.