Skip to content

Commit

Permalink
add map_samples() implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Feb 9, 2025
1 parent 21468df commit cd2da41
Show file tree
Hide file tree
Showing 6 changed files with 453 additions and 2 deletions.
6 changes: 6 additions & 0 deletions docs/source/user_guide/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ FiftyOne supports the configuration options described below:
| `logging_level` | `FIFTYONE_LOGGING_LEVEL` | `INFO` | Controls FiftyOne's package-wide logging level. Can be any valid ``logging`` level as |
| | | | a string: ``DEBUG, INFO, WARNING, ERROR, CRITICAL``. |
+-------------------------------+---------------------------------------+-------------------------------+----------------------------------------------------------------------------------------+
| `default_map_workers` | `FIFTYONE_DEFAULT_MAP_WORKERS` | `None` | The default number of worker processes to use when |
| | | | :meth:`map_samples() <fiftyone.core.collections.SampleCollection.map_samples>` is |
| | | | called. |
+-------------------------------+---------------------------------------+-------------------------------+----------------------------------------------------------------------------------------+
| `max_thread_pool_workers` | `FIFTYONE_MAX_THREAD_POOL_WORKERS` | `None` | An optional maximum number of workers to use when creating thread pools |
+-------------------------------+---------------------------------------+-------------------------------+----------------------------------------------------------------------------------------+
| `max_process_pool_workers` | `FIFTYONE_MAX_PROCESS_POOL_WORKERS` | `None` | An optional maximum number of workers to use when creating process pools |
Expand Down Expand Up @@ -169,6 +173,7 @@ and the CLI:
"default_batcher": "latency",
"default_dataset_dir": "~/fiftyone",
"default_image_ext": ".jpg",
"default_map_workers": null,
"default_ml_backend": "torch",
"default_sequence_idx": "%06d",
"default_video_ext": ".mp4",
Expand Down Expand Up @@ -219,6 +224,7 @@ and the CLI:
"default_batcher": "latency",
"default_dataset_dir": "~/fiftyone",
"default_image_ext": ".jpg",
"default_map_workers": null,
"default_ml_backend": "torch",
"default_sequence_idx": "%06d",
"default_video_ext": ".mp4",
Expand Down
137 changes: 137 additions & 0 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import eta.core.serial as etas
import eta.core.utils as etau

import fiftyone as fo
import fiftyone.core.aggregations as foa
import fiftyone.core.annotation as foan
import fiftyone.core.brain as fob
Expand All @@ -49,6 +50,7 @@
foua = fou.lazy_import("fiftyone.utils.annotations")
foud = fou.lazy_import("fiftyone.utils.data")
foue = fou.lazy_import("fiftyone.utils.eval")
foum = fou.lazy_import("fiftyone.utils.multiprocessing")
foos = fou.lazy_import("fiftyone.operators.store")


Expand Down Expand Up @@ -3212,6 +3214,141 @@ def _set_labels(self, field_name, sample_ids, label_docs, progress=False):
def _delete_labels(self, ids, fields=None):
self._dataset.delete_labels(ids=ids, fields=fields)

def map_samples(
    self,
    map_fcn,
    reduce_fcn=None,
    save=None,
    num_workers=None,
    shard_size=None,
    shard_method="id",
    progress=None,
):
    """Applies the given function to each sample in the collection.

    By default, a multiprocessing pool is used to parallelize the work.

    When only a ``map_fcn`` is provided, this method effectively performs
    the following map operation with the outer loop in parallel::

        for batch_view in fou.iter_batches(sample_collection, shard_size):
            for sample in batch_view.iter_samples(autosave=True):
                map_fcn(sample)

    When a ``reduce_fcn`` is also provided, this method effectively
    performs the following map-reduce operation with the outer loop in
    parallel::

        values = {}
        for batch_view in fou.iter_batches(sample_collection, shard_size):
            for sample in batch_view.iter_samples(autosave=save):
                values[sample.id] = map_fcn(sample)

        output = reduce_fcn(sample_collection, values)

    Args:
        map_fcn: a function to apply to each sample in the collection
        reduce_fcn (None): an optional function to reduce the map outputs.
            See the docstring above for usage information
        save (None): whether to save any sample edits applied by
            ``map_fcn``. By default this is True when no ``reduce_fcn``
            is provided and False when a ``reduce_fcn`` is provided
        num_workers (None): the number of workers to use. By default,
            ``fiftyone.config.default_map_workers`` workers are used if
            this value is set, else
            :meth:`fiftyone.core.utils.recommend_process_pool_workers`
            workers are used. If this value is <= 1, all work is done in
            the main process
        shard_size (None): an optional number of samples to distribute to
            each worker at a time. By default, samples are evenly
            distributed to workers with one shard per worker
        shard_method ("id"): whether to use IDs (``"id"``) or slices
            (``"slice"``) to assign samples to workers
        progress (None): whether to render a progress bar for each worker
            (True/False), use the default value
            ``fiftyone.config.show_progress_bars`` (None), "global" to
            render a single global progress bar, or a progress callback
            function to invoke instead

    Returns:
        the output of ``reduce_fcn``, if provided, else None
    """
    workers = num_workers
    if workers is None:
        # fall back to the configured default, which may itself be None
        workers = fo.config.default_map_workers

    # An explicit worker count of <= 1 means "don't spawn a pool"; run
    # everything serially in the main process instead
    if workers is not None and workers <= 1:
        return self._map_samples_single(
            map_fcn,
            reduce_fcn=reduce_fcn,
            save=save,
            progress=progress,
        )

    return foum.map_samples(
        self,
        map_fcn,
        reduce_fcn=reduce_fcn,
        save=save,
        num_workers=workers,
        shard_size=shard_size,
        shard_method=shard_method,
        progress=progress,
    )

def _map_samples_single(
self,
map_fcn,
reduce_fcn=None,
save=None,
progress=False,
):
if save is None:
save = reduce_fcn is None

if progress == "global":
progress = True

outputs = {}
for sample in self.iter_samples(autosave=save, progress=progress):
output = map_fcn(sample)
if reduce_fcn is not None:
outputs[sample.id] = output

if reduce_fcn is not None:
return reduce_fcn(self, outputs)

def compute_metadata(
self,
overwrite=False,
Expand Down
6 changes: 6 additions & 0 deletions fiftyone/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def __init__(self, d=None):
self.timezone = self.parse_string(
d, "timezone", env_var="FIFTYONE_TIMEZONE", default=None
)
self.default_map_workers = self.parse_int(
d,
"default_map_workers",
env_var="FIFTYONE_DEFAULT_MAP_WORKERS",
default=None,
)
self.max_thread_pool_workers = self.parse_int(
d,
"max_thread_pool_workers",
Expand Down
11 changes: 9 additions & 2 deletions fiftyone/core/odm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
from bson import json_util, ObjectId
from bson.codec_options import CodecOptions
from mongoengine import connect
import mongoengine
import motor.motor_asyncio as mtr

from packaging.version import Version
Expand Down Expand Up @@ -219,7 +219,7 @@ def establish_db_conn(config):
# Register cleanup method
atexit.register(_delete_non_persistent_datasets_if_allowed)

connect(config.database_name, **_connection_kwargs)
mongoengine.connect(config.database_name, **_connection_kwargs)

db_config = get_db_config()
if db_config.type != foc.CLIENT_TYPE:
Expand All @@ -244,6 +244,13 @@ def _connect():
establish_db_conn(fo.config)


def _disconnect():
    # Drops the cached sync/async client handles and closes every registered
    # mongoengine connection, so the next call to _connect()/establish_db_conn()
    # builds fresh clients.
    # NOTE(review): presumably needed so forked worker processes don't reuse
    # inherited connections — confirm against the map_samples() callers
    global _client, _async_client
    _client = None
    _async_client = None
    mongoengine.disconnect_all()


def _async_connect(use_global=False):
# Regular connect here first, to ensure connection kwargs are established
# for below.
Expand Down
Loading

0 comments on commit cd2da41

Please sign in to comment.