Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Break up task edge registration #238

Merged
merged 10 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions changes/pr238.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# An example changelog entry
#
# 1. Choose one (or more if a PR encompasses multiple changes) of the following headers:
# - feature
# - enhancement
# - fix
# - deprecation
# - breaking (for breaking changes)
# - migration (for database migrations)
#
# 2. Fill in one (or more) bullet points under the heading, describing the change.
# Markdown syntax may be used.
#
# 3. If you would like to be credited as helping with this release, add a
# contributor section with your name and github username.
#
# Here's an example of a PR that adds an enhancement

enhancement:
- "Batch task and edge insertion during flow creation - [#238](https://github.com/PrefectHQ/server/pull/238)"
84 changes: 55 additions & 29 deletions src/prefect_server/api/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from prefect_server import config
from prefect_server.utilities import logging
from prefect_server.utilities.collections import chunked_iterable
from prefect_server.utilities.exceptions import APIError

logger = logging.get_logger("api.flows")
schedule_schema = ScheduleSchema()
Expand Down Expand Up @@ -244,6 +246,10 @@ async def create_flow(
set_schedule_active = False

# precompute task ids to make edges easy to add to database

# create the flow without tasks or edges initially
# then insert tasks and edges in batches, so we don't exceed Postgres limits
# https://doxygen.postgresql.org/fe-exec_8c_source.html line 1409
flow_id = await models.Flow(
tenant_id=tenant_id,
project_id=project_id,
Expand All @@ -261,37 +267,57 @@ async def create_flow(
description=description,
schedule=serialized_flow.get("schedule"),
is_schedule_active=False,
tasks=[
models.Task(
id=t.id,
tenant_id=tenant_id,
name=t.name,
slug=t.slug,
type=t.type,
max_retries=t.max_retries,
tags=t.tags,
retry_delay=t.retry_delay,
trigger=t.trigger,
mapped=t.mapped,
auto_generated=t.auto_generated,
cache_key=t.cache_key,
is_reference_task=t.is_reference_task,
is_root_task=t.is_root_task,
is_terminal_task=t.is_terminal_task,
tasks=[],
edges=[],
).insert()

try:
batch_insertion_size = 2500

for tasks_chunk in chunked_iterable(flow.tasks, batch_insertion_size):
await models.Task.insert_many(
[
models.Task(
id=t.id,
flow_id=flow_id,
tenant_id=tenant_id,
name=t.name,
slug=t.slug,
type=t.type,
max_retries=t.max_retries,
tags=t.tags,
retry_delay=t.retry_delay,
trigger=t.trigger,
mapped=t.mapped,
auto_generated=t.auto_generated,
cache_key=t.cache_key,
is_reference_task=t.is_reference_task,
is_root_task=t.is_root_task,
is_terminal_task=t.is_terminal_task,
)
for t in tasks_chunk
]
)
for t in flow.tasks
],
edges=[
models.Edge(
tenant_id=tenant_id,
upstream_task_id=task_lookup[e.upstream_task].id,
downstream_task_id=task_lookup[e.downstream_task].id,
key=e.key,
mapped=e.mapped,

for edges_chunk in chunked_iterable(flow.edges, batch_insertion_size):
await models.Edge.insert_many(
[
models.Edge(
tenant_id=tenant_id,
flow_id=flow_id,
upstream_task_id=task_lookup[e.upstream_task].id,
downstream_task_id=task_lookup[e.downstream_task].id,
key=e.key,
mapped=e.mapped,
)
for e in edges_chunk
]
)
for e in flow.edges
],
).insert()

except Exception as exc:
logger.error("`create_flow` failed during insertion", exc_info=True)
await api.flows.delete_flow(flow_id=flow_id)
raise APIError() from exc

# schedule runs
if set_schedule_active:
Expand Down
21 changes: 21 additions & 0 deletions src/prefect_server/utilities/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import itertools
from typing import Iterable


def chunked_iterable(iterable: Iterable, size: int):
"""
Yield chunks of a certain size from an iterable

Args:
- iterable (Iterable): An iterable
- size (int): The size of chunks to return

Yields:
tuple: A chunk of the iterable
"""
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, size))
if not chunk:
break
yield chunk
18 changes: 18 additions & 0 deletions tests/api/test_flows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import MagicMock, patch

import datetime
import uuid

Expand Down Expand Up @@ -514,6 +516,22 @@ async def test_create_flow_persists_serialized_flow(self, project_id, flow):
# confirm the keys in the serialized flow match the form we'd expect
assert persisted_flow.serialized_flow == flow.serialize()

@pytest.mark.parametrize("model_name", ["Task", "Edge"])
async def test_create_flow_cleans_up_if_task_or_edge_creation_fails(
self, project_id, flow, monkeypatch, model_name
):
patched_insert_raises_error = MagicMock(side_effect=Exception())
monkeypatch.setattr(
f"prefect.models.{model_name}.insert_many", patched_insert_raises_error
)
flow.name = "my special flow"
zangell44 marked this conversation as resolved.
Show resolved Hide resolved
assert await models.Flow.where({"name": {"_eq": flow.name}}).count() == 0
with pytest.raises(Exception):
flow_id = await api.flows.create_flow(
project_id=project_id, serialized_flow=flow.serialize()
)
assert await models.Flow.where({"name": {"_eq": flow.name}}).count() == 0


class TestCreateFlowVersions:
async def test_create_flow_assigns_random_version_group_id(self, project_id, flow):
Expand Down
12 changes: 12 additions & 0 deletions tests/utilities/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from prefect_server.utilities.collections import chunked_iterable


def test_chunked_iterable_of_list():
chunks = [chunk for chunk in chunked_iterable(list(range(10)), 4)]
expected_chunks = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9)]
assert chunks == expected_chunks


def test_chunked_iterable_of_empty_iterable():
chunks = [chunk for chunk in chunked_iterable([], 4)]
assert len(chunks) == 0