Skip to content

Commit

Permalink
Pass through kwargs to json.loads call in JsonParser (#518)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #518

The comment of this class suggests the keyword arguments in the constructor will be passed through to the json.loads call. However, `self.kwargs` is not actually passed through to the call.

Updated the example in the README too

Reviewed By: ejguan

Differential Revision: D37162608

fbshipit-source-id: b73a6cfc58befbfabfef3595d9b00cef78852273
  • Loading branch information
Ananth Subramaniam authored and ejguan committed Jun 17, 2022
1 parent 8166a8f commit 8a0712f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class JsonParserIterDataPipe(IterDataPipe):
def __iter__(self):
for file_name, stream in self.source_datapipe:
data = stream.read()
yield file_name, json.loads(data)
yield file_name, json.loads(data, **self.kwargs)

def __len__(self):
return len(self.source_datapipe)
Expand Down
8 changes: 8 additions & 0 deletions test/test_local_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,14 @@ def is_nonempty_json(path_and_stream):
with self.assertRaisesRegex(TypeError, "len"):
len(json_dp)

# kwargs Test:
json_dp = JsonParser(datapipe_nonempty, parse_int=str)
expected_res = [
("1.json", ["foo", {"bar": ["baz", None, 1.0, "2"]}]),
("2.json", {"__complex__": True, "real": "1", "imag": "2"}),
]
self.assertEqual(expected_res, list(json_dp))

def test_saver_iterdatapipe(self):
# Functional Test: Saving some data
name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
Expand Down
2 changes: 1 addition & 1 deletion torchdata/datapipes/iter/util/jsonparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __iter__(self) -> Iterator[Tuple[str, Dict]]:
for file_name, stream in self.source_datapipe:
data = stream.read()
stream.close()
yield file_name, json.loads(data)
yield file_name, json.loads(data, **self.kwargs)

def __len__(self) -> int:
return len(self.source_datapipe)

0 comments on commit 8a0712f

Please sign in to comment.