Update examples to avoid lambdas (#527)
Summary:
Pull Request resolved: #524

This allows these examples to work with multiprocessing even without `dill` installed.
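
For context (not part of the diff): the standard `pickle` module cannot serialize lambdas, so DataPipe graphs that capture them break once a `DataLoader` spawns worker processes unless `dill` is installed. Module-level functions, optionally bound with `functools.partial`, pickle with no extra dependency. A minimal sketch of the difference, using a hypothetical root directory:

```python
import os
import pickle
from functools import partial


def _path_fn(root, path):
    # A module-level function is pickled by reference, so a DataPipe holding
    # it can be shipped to DataLoader worker processes with plain pickle.
    return os.path.join(root, os.path.basename(path))


filepath_fn = partial(_path_fn, "/tmp/data")  # "/tmp/data" is an arbitrary example root
pickle.dumps(filepath_fn)  # succeeds with the standard library

try:
    pickle.dumps(lambda x: os.path.join("/tmp/data", os.path.basename(x)))
except Exception as exc:
    # Lambdas have no importable qualified name, so plain pickle rejects them;
    # serializing them is exactly what `dill` was needed for.
    print(f"lambda not picklable: {exc}")
```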

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D37286086

Pulled By: NivekT

fbshipit-source-id: 52c859097f2d3b14d3b78d752e1f01ff5574dca5
NivekT authored Jun 20, 2022
1 parent 1d95989 commit e276513
Showing 7 changed files with 140 additions and 39 deletions.
60 changes: 48 additions & 12 deletions examples/text/CC100.ipynb
@@ -7,17 +7,30 @@
"import torch\n",
"import os\n",
"\n",
"from functools import partial\n",
"from operator import itemgetter\n",
"from torchdata.datapipes.iter import (\n",
" FileOpener,\n",
" HttpReader,\n",
" IterableWrapper,\n",
" SampleMultiplexer,\n",
")\n",
"\n",
"ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set"
"ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set\n",
"\n",
"\n",
"def _path_fn(root, x):\n",
" return os.path.join(root, os.path.basename(x).rstrip(\".xz\"))\n",
"\n",
"def _process_tuple(language_code, t):\n",
" return language_code, t[1].decode()"
],
"outputs": [],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -42,20 +55,23 @@
" raise ValueError(f\"Invalid language code {language_code}\")\n",
" url = URL % language_code\n",
" if use_caching:\n",
" cache_compressed_dp = HttpReader(cache_compressed_dp).map(lambda x: (x[0]))\n",
" cache_compressed_dp = HttpReader(cache_compressed_dp).map(itemgetter(0))\n",
" cache_compressed_dp = cache_compressed_dp.end_caching(mode=\"wb\", same_filepath_fn=True)\n",
" cache_decompressed_dp = cache_compressed_dp.on_disk_cache(\n",
" filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(\".xz\")))\n",
" cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_path_fn, root))\n",
" cache_decompressed_dp = FileOpener(cache_decompressed_dp).read_from_xz()\n",
" cache_decompressed_dp = cache_decompressed_dp.end_caching(mode=\"wb\")\n",
" data_dp = FileOpener(cache_decompressed_dp)\n",
" else:\n",
" data_dp = HttpReader([url]).read_from_xz()\n",
" units_dp = data_dp.readlines().map(lambda x: (language_code, x[1])).map(lambda x: (x[0], x[1].decode()))\n",
" units_dp = data_dp.readlines().map(partial(_process_tuple, language_code))\n",
" return units_dp\n"
],
"outputs": [],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -87,7 +103,11 @@
]
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -107,7 +127,11 @@
"output_type": "execute_result"
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -129,7 +153,11 @@
"execution_count": 5
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -163,7 +191,11 @@
]
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -184,7 +216,11 @@
"execution_count": 8
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
6 changes: 5 additions & 1 deletion examples/text/ag_news.py
@@ -26,6 +26,10 @@
DATASET_NAME = "AG_NEWS"


def _process_tuple(t):
return int(t[0]), " ".join(t[1:])


@_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -36,4 +40,4 @@ def AG_NEWS(root, split):
"""

# Stack CSV Parser directly on top of web-stream
return HttpReader([URL[split]]).parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:])))
return HttpReader([URL[split]]).parse_csv().map(_process_tuple)
29 changes: 21 additions & 8 deletions examples/text/amazonreviewpolarity.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper

@@ -33,6 +34,22 @@
DATASET_NAME = "AmazonReviewPolarity"


def _path_fn(root, _=None):
return os.path.join(root, _PATH)


def _cache_path_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, fname_and_stream):
return _EXTRACTED_FILES[split] in fname_and_stream[0]


def _process_tuple(t):
return int(t[0]), " ".join(t[1:])


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -48,7 +65,7 @@ def AmazonReviewPolarity(root, split):
# the files before saving them. `.on_disk_cache` merely indicates that caching will take place, but the
# content of the previous DataPipe is unchanged. Therefore, `cache_compressed_dp` still contains URL(s).
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
filepath_fn=partial(_path_fn, root), hash_dict={_path_fn(root): MD5}, hash_type="md5"
)

# `GDriveReader` takes in URLs to GDrives files, and yields a tuple of file name and IO stream.
@@ -61,9 +78,7 @@ def AmazonReviewPolarity(root, split):

# Again, `.on_disk_cache` is invoked again here and the subsequent DataPipe operations (until `.end_caching`)
# will be saved onto the disk. At this point, `cache_decompressed_dp` contains paths to the cached files.
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_cache_path_fn, root, split))

# Opens the cache files using `FileOpener`
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
@@ -72,9 +87,7 @@ def AmazonReviewPolarity(root, split):
cache_decompressed_dp = cache_decompressed_dp.load_from_tar()

# Filters for specific file based on the file name from the previous DataPipe (either "train.csv" or "test.csv").
cache_decompressed_dp = cache_decompressed_dp.filter(
lambda fname_and_stream: _EXTRACTED_FILES[split] in fname_and_stream[0]
)
cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split))

# ".end_caching" saves the decompressed file onto disks and yields the path to the file.
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
@@ -83,4 +96,4 @@ def AmazonReviewPolarity(root, split):
data_dp = FileOpener(cache_decompressed_dp, mode="b")

# Finally, this parses content of the decompressed CSV file and returns the result line by line.
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(_process_tuple)
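
A usage sketch (not in this commit) of what the change enables: with only module-level callables in the pipeline, the example can be read through a `DataLoader` with worker processes using the default pickle-based serialization, no `dill` needed. Batch size and worker count below are arbitrary, and the download/caching steps above must succeed first:

```python
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # AmazonReviewPolarity is the example datapipe defined above; its
    # decorators supply a default root directory, so only the split is given.
    train_dp = AmazonReviewPolarity(split="train")

    # num_workers > 0 pickles the pipeline into worker processes, which is
    # exactly the case the lambda removal is meant to support.
    dl = DataLoader(train_dp, batch_size=8, num_workers=2)

    labels, texts = next(iter(dl))  # labels: tensor of ints, texts: list of str
    print(labels)
    print(texts[0][:80])
```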
43 changes: 35 additions & 8 deletions examples/text/examples_usage.ipynb
@@ -4,7 +4,11 @@
"cell_type": "code",
"execution_count": 1,
"id": "ddf60620",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# print first n examples\n",
@@ -23,7 +27,10 @@
"execution_count": 2,
"id": "839c377f",
"metadata": {
"scrolled": false
"scrolled": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
@@ -54,7 +61,11 @@
"cell_type": "code",
"execution_count": 3,
"id": "2bd4a0f8",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -78,17 +89,24 @@
"# print first batch with 2 examples\n",
"print_first_n_items(train_dp)\n",
"\n",
"\n",
"def _process_batch(batch):\n",
" return {'labels': [sample[0] for sample in batch], 'text': [sample[1].split() for sample in batch]}\n",
"\n",
"#Apply tokenization and create labels and text named batch\n",
"train_dp = train_dp.map(lambda batch: {'labels': [sample[0] for sample in batch],\\\n",
" 'text': [sample[1].split() for sample in batch]})\n",
"train_dp = train_dp.map(_process_batch)\n",
"print_first_n_items(train_dp)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b1401a57",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -116,7 +134,11 @@
"cell_type": "code",
"execution_count": 5,
"id": "efe92627",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -133,9 +155,14 @@
"train_dp = IMDB(split='train')\n",
"print_first_n_items(train_dp)\n",
"\n",
"\n",
"#convert label into integer using map\n",
"labels = {'neg':0,'pos':1}\n",
"train_dp = train_dp.map(lambda x: (labels[x[0]],x[1]))\n",
"\n",
"def _process_tuple(x):\n",
" return labels[x[0]], x[1]\n",
"\n",
"train_dp = train_dp.map(_process_tuple)\n",
"print_first_n_items(train_dp)"
]
}
23 changes: 17 additions & 6 deletions examples/text/imdb.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial
from pathlib import Path

from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
@@ -25,6 +26,18 @@
DATASET_NAME = "IMDB"


def _path_fn(root, path):
return os.path.join(root, os.path.basename(path))


def _filter_fn(split, t):
return Path(t[0]).parts[-3] == split and Path(t[0]).parts[-2] in ["pos", "neg"]


def _file_to_sample(t):
return Path(t[0]).parts[-2], t[1].read().decode("utf-8")


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -39,8 +52,8 @@ def IMDB(root, split):
url_dp = IterableWrapper([URL])
# cache data on-disk
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
filepath_fn=partial(_path_fn, root),
hash_dict={_path_fn(root, URL): MD5},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
@@ -51,9 +64,7 @@ def IMDB(root, split):
extracted_files = cache_dp.load_from_tar()

# filter the files as applicable to create dataset for given split (train or test)
filter_files = extracted_files.filter(
lambda x: Path(x[0]).parts[-3] == split and Path(x[0]).parts[-2] in ["pos", "neg"]
)
filter_files = extracted_files.filter(partial(_filter_fn, split))

# map the file to yield proper data samples
return filter_files.map(lambda x: (Path(x[0]).parts[-2], x[1].read().decode("utf-8")))
return filter_files.map(_file_to_sample)
9 changes: 7 additions & 2 deletions examples/text/squad1.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper, IterDataPipe

@@ -29,6 +30,10 @@
DATASET_NAME = "SQuAD1"


def _path_fn(root, path):
return os.path.join(root, os.path.basename(path))


class _ParseSQuADQAData(IterDataPipe):
def __init__(self, source_datapipe) -> None:
self.source_datapipe = source_datapipe
@@ -60,8 +65,8 @@ def SQuAD1(root, split):
url_dp = IterableWrapper([URL[split]])
# cache data on-disk with sanity check
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
filepath_fn=partial(_path_fn, root),
hash_dict={_path_fn(root, URL[split]): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
