This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #347 from rstudio/344-async-coroutine-bug
344 async coroutine bug
- Loading branch information
Showing
4 changed files
with
61 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from inspect import isawaitable | ||
from typing import Union, Awaitable | ||
|
||
from jupyter_server.services.contents.manager import ContentsManager | ||
|
||
|
||
async def get_model(manager: ContentsManager, path: str) -> dict: | ||
""" | ||
Gets the model via the ContentsManager. | ||
If the ContentsManager is async (e.g., AsyncContentsManager), then an await is issued. Otherwise, | ||
the model is returned under synchronous expectations. | ||
:param manager: A Jupyter ContentsManager | ||
:param path: The model path | ||
:return: The model | ||
""" | ||
model: Union[dict, Awaitable[dict]] = manager.get(path) | ||
if isawaitable(model): | ||
model = await model | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from unittest import TestCase | ||
from unittest.mock import Mock, MagicMock, AsyncMock | ||
|
||
from rsconnect_jupyter.managers import ContentsManager, get_model, isawaitable | ||
|
||
|
||
class GetModelTestCase(TestCase): | ||
async def test_synchronous(self): | ||
model = AsyncMock() | ||
manager = MagicMock(spec=ContentsManager) | ||
manager.get = Mock(return_value=model) | ||
path = "path" | ||
spy = Mock(wraps=isawaitable, return_value=False) | ||
res = await get_model(manager, path) | ||
self.assertEqual(res, model) | ||
model.assert_not_awaited() | ||
manager.get.assert_called_once_with(path) | ||
spy.assert_called_once_with(model) | ||
|
||
async def test_asynchronous(self): | ||
model = AsyncMock() | ||
manager = MagicMock(spec=ContentsManager) | ||
manager.get = Mock(return_value=model) | ||
path = "path" | ||
spy = Mock(wraps=isawaitable, return_value=True) | ||
res = await get_model(manager, path) | ||
self.assertEqual(res, model) | ||
model.assert_awaited() | ||
manager.get.assert_called_once_with(path) | ||
spy.assert_called_once_with(model) |