Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shell out to model handlers to collect byte sizes #28182

Merged
merged 2 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,19 @@ def run_inference(
return predictions

def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
keys, unkeyed_batch = zip(*batch)
batch_bytes = len(pickle.dumps(keys))
if self._single_model:
keys, unkeyed_batch = zip(*batch)
return len(
pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
return len(pickle.dumps(batch))
return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch)

batch_by_key = defaultdict(list)
for key, examples in batch:
batch_by_key[key].append(examples)

for key, examples in batch_by_key.items():
mh_id = self._key_to_id_map[key]
batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples)
return batch_bytes

def get_metrics_namespace(self) -> str:
if self._single_model:
Expand Down
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def __init__(
max_batch_size=9999,
multi_process_shared=False,
state=None,
num_bytes_per_element=None,
**kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._env_vars = kwargs.get('env_vars', {})
self._multi_process_shared = multi_process_shared
self._state = state
self._num_bytes_per_element = num_bytes_per_element

def load_model(self):
if self._fake_clock:
Expand Down Expand Up @@ -113,6 +115,11 @@ def batch_elements_kwargs(self):
def share_model_across_processes(self):
return self._multi_process_shared

def get_num_bytes(self, batch: Sequence[int]) -> int:
  """Report the byte size of *batch* for batching metrics.

  If the fake handler was configured with a fixed per-element size,
  the batch size is simply that size times the batch length; otherwise
  defer to the base handler's generic (pickle-based) measurement.
  """
  per_element = self._num_bytes_per_element
  if not per_element:
    return super().get_num_bytes(batch)
  return per_element * len(batch)


class FakeModelHandlerReturnsPredictionResult(
base.ModelHandler[int, base.PredictionResult, FakeModel]):
Expand Down Expand Up @@ -319,6 +326,24 @@ def mult_two(example: str) -> int:
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)

def test_keyed_model_handler_get_num_bytes(self):
  """A keyed handler's byte count is pickled-keys size plus the wrapped
  handler's per-element count (10 bytes/element here, 3 elements)."""
  handler = base.KeyedModelHandler(FakeModelHandler(num_bytes_per_element=10))
  keyed_batch = [('key1', 1), ('key2', 2), ('key1', 3)]
  key_overhead = len(pickle.dumps(('key1', 'key2', 'key1')))
  self.assertEqual(key_overhead + 30, handler.get_num_bytes(keyed_batch))

def test_keyed_model_handler_multiple_models_get_num_bytes(self):
mhs = [
base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you still plan to change the name of KeyMhMapping?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm waiting until I don't have in flight PRs around this to change it to avoid conflicts (right now #28026 uses KeyMhMapping)

base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20))
]
mh = base.KeyedModelHandler(mhs)
batch = [('key1', 1), ('key2', 2), ('key1', 3)]
expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 40
actual = mh.get_num_bytes(batch)
self.assertEqual(expected, actual)

def test_run_inference_impl_with_maybe_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
Expand Down