Skip to content

Commit

Permalink
Backport PR jupyter-server#1180: Only load enabled extension packages
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk authored and gogasca committed Feb 6, 2023
1 parent ef9663e commit aa7420f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
76 changes: 37 additions & 39 deletions jupyter_server/extension/manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib

from tornado.gen import multi
from traitlets import Any, Bool, Dict, HasTraits, Instance, Unicode, default, observe
from traitlets import Any, Bool, Dict, HasTraits, Instance, List, Unicode, default, observe
from traitlets import validate as validate_trait
from traitlets.config import LoggingConfigurable

Expand Down Expand Up @@ -156,54 +156,52 @@ class ExtensionPackage(HasTraits):
ext_name = "my_extensions"
extpkg = ExtensionPackage(name=ext_name)
"""

name = Unicode(help="Name of the an importable Python package.")
enabled = Bool(False).tag(config=True)

def __init__(self, *args, **kwargs):
# Store extension points that have been linked.
self._linked_points = {}
super().__init__(*args, **kwargs)
enabled = Bool(False, help="Whether the extension package is enabled.")

_linked_points: dict = {}
_linked_points = Dict()
extension_points = Dict()
module = Any(allow_none=True, help="The module for this extension package. None if not enabled")
metadata = List(Dict(), help="Extension metadata loaded from the extension package.")
version = Unicode(
help="""
The version of this extension package, if it can be found.
Otherwise, an empty string.
""",
)

@validate_trait("name")
def _validate_name(self, proposed):
name = proposed["value"]
self._extension_points = {}
@default("version")
def _load_version(self):
if not self.enabled:
return ""
return getattr(self.module, "__version__", "")

def __init__(self, **kwargs):
"""Initialize an extension package."""
super().__init__(**kwargs)
if self.enabled:
self._load_metadata()

def _load_metadata(self):
"""Import package and load metadata
Only used if extension package is enabled
"""
name = self.name
try:
self._module, self._metadata = get_metadata(name)
self.module, self.metadata = get_metadata(name, logger=self.log)
except ImportError as e:
raise ExtensionModuleNotFound(
"The module '{name}' could not be found ({e}). Are you "
"sure the extension is installed?".format(name=name, e=e)
msg = (
f"The module '{name}' could not be found ({e}). Are you "
"sure the extension is installed?"
)
raise ExtensionModuleNotFound(msg) from None
# Create extension point interfaces for each extension path.
for m in self._metadata:
for m in self.metadata:
point = ExtensionPoint(metadata=m)
self._extension_points[point.name] = point
self.extension_points[point.name] = point
return name

@property
def module(self):
"""Extension metadata loaded from the extension package."""
return self._module

@property
def version(self):
"""Get the version of this package, if it's given. Otherwise, return an empty string"""
return getattr(self._module, "__version__", "")

@property
def metadata(self):
"""Extension metadata loaded from the extension package."""
return self._metadata

@property
def extension_points(self):
"""A dictionary of extension points."""
return self._extension_points

def validate(self):
"""Validate all extension points in this package."""
for extension in self.extension_points.values():
Expand Down
27 changes: 25 additions & 2 deletions tests/extension/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import unittest.mock as mock

import pytest
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_extension_package_api():
path1 = metadata_list[0]
app = path1["app"]

e = ExtensionPackage(name="tests.extension.mockextensions")
e = ExtensionPackage(name="tests.extension.mockextensions", enabled=True)
e.extension_points
assert hasattr(e, "extension_points")
assert len(e.extension_points) == len(metadata_list)
Expand All @@ -70,7 +71,9 @@ def test_extension_package_api():

def test_extension_package_notfound_error():
with pytest.raises(ExtensionModuleNotFound):
ExtensionPackage(name="nonexistent")
ExtensionPackage(name="nonexistent", enabled=True)
# no raise if not enabled
ExtensionPackage(name="nonexistent", enabled=False)


def _normalize_path(path_list):
Expand Down Expand Up @@ -132,3 +135,23 @@ def test_extension_manager_fail_load(jp_serverapp):
jp_serverapp.reraise_server_extension_failures = True
with pytest.raises(RuntimeError):
manager.load_extension(name)


@pytest.mark.parametrize("has_app", [True, False])
def test_disable_no_import(jp_serverapp, has_app):
# de-import modules so we can detect if they are re-imported
disabled_ext = "tests.extension.mockextensions.mock1"
enabled_ext = "tests.extension.mockextensions.mock2"
sys.modules.pop(disabled_ext, None)
sys.modules.pop(enabled_ext, None)

manager = ExtensionManager(serverapp=jp_serverapp if has_app else None)
manager.add_extension(disabled_ext, enabled=False)
manager.add_extension(enabled_ext, enabled=True)
assert disabled_ext not in sys.modules
assert enabled_ext in sys.modules

ext_pkg = manager.extensions[disabled_ext]
assert ext_pkg.extension_points == {}
assert ext_pkg.version == ""
assert ext_pkg.metadata == []

0 comments on commit aa7420f

Please sign in to comment.