Skip to content

Commit

Permalink
Merge pull request #3 from strollby/main
Browse files Browse the repository at this point in the history
Federation v2 Support
  • Loading branch information
arunsureshkumar authored Oct 31, 2022
2 parents e865642 + 053e0e7 commit 14de45b
Show file tree
Hide file tree
Showing 16 changed files with 191 additions and 109 deletions.
6 changes: 3 additions & 3 deletions examples/inaccessible.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ class Position(graphene.ObjectType):
x = graphene.Int(required=True)
y = external(graphene.Int(required=True))
z = inaccessible(graphene.Int(required=True))
a = provides(graphene.Int(required=True), fields="x")
b = override(graphene.Int(required=True), _from="h")
a = provides(graphene.Int(), fields="x")
b = override(graphene.Int(required=True), from_="h")


class Query(graphene.ObjectType):
position = graphene.Field(Position)


schema = build_schema(Query)
schema = build_schema(Query, enable_federation_2=True)

query = '''
query getSDL {
Expand Down
2 changes: 1 addition & 1 deletion examples/override.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Query(graphene.ObjectType):
position = graphene.Field(Product)


schema = build_schema(Query)
schema = build_schema(Query, enable_federation_2=True)

query = '''
query getSDL {
Expand Down
2 changes: 1 addition & 1 deletion examples/shareable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Query(graphene.ObjectType):
position = graphene.Field(Position)


schema = build_schema(Query)
schema = build_schema(Query, enable_federation_2=True)

query = '''
query getSDL {
Expand Down
2 changes: 1 addition & 1 deletion examples/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Query(graphene.ObjectType):
position = graphene.Field(Product)


schema = build_schema(Query)
schema = build_schema(Query, enable_federation_2=True)

query = '''
query getSDL {
Expand Down
3 changes: 2 additions & 1 deletion graphene_federation/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def resolve_entities(self, info, representations):
return EntityQuery


def key(fields: str) -> Callable:
def key(fields: str, resolvable: bool = True) -> Callable:
"""
Take as input a field that should be used as key for that entity.
See specification: https://www.apollographql.com/docs/federation/federation-spec/#key
Expand All @@ -99,6 +99,7 @@ def decorator(Type):
keys = getattr(Type, "_keys", [])
keys.append(fields)
setattr(Type, "_keys", keys)
setattr(Type, "_resolvable", resolvable)

return Type

Expand Down
26 changes: 9 additions & 17 deletions graphene_federation/inaccessible.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

def get_inaccessible_types(schema: Schema) -> dict[str, Any]:
"""
Find all the extended types from the schema.
Find all the inaccessible types from the schema.
They can be easily distinguished from the other type as
the `@extend` decorator adds a `_extended` attribute to them.
the `@inaccessible` decorator adds a `_inaccessible` attribute to them.
"""
inaccessible_types = {}
for type_name, type_ in schema.graphql_schema.type_map.items():
Expand All @@ -26,16 +26,8 @@ def inaccessible(field: Optional[Any] = None) -> Any:

# noinspection PyProtectedMember,PyPep8Naming
def decorator(Type):
assert not hasattr(
Type, "_keys"
), "Can't extend type which is already extended or has @key"
# Check the provided fields actually exist on the Type.
assert getattr(Type._meta, "description", None) is None, (
f'Type "{Type.__name__}" has a non empty description and it is also marked with extend.'
"\nThey are mutually exclusive."
"\nSee https://github.com/graphql/graphql-js/issues/2385#issuecomment-577997521"
)
# Set a `_extended` attribute to be able to distinguish it from the other entities
# TODO Check the provided fields actually exist on the Type.
# Set a `_inaccessible` attribute to be able to distinguish it from the other entities
setattr(Type, "_inaccessible", True)
return Type

Expand All @@ -47,16 +39,16 @@ def decorator(Type):

def get_inaccessible_fields(schema: Schema) -> dict:
"""
Find all the extended types from the schema.
Find all the inacessible types from the schema.
They can be easily distinguished from the other type as
the `@_tag` decorator adds a `_tag` attribute to them.
the `@inaccessible` decorator adds a `_inaccessible` attribute to them.
"""
shareable_types = {}
inaccessible_types = {}
for type_name, type_ in schema.graphql_schema.type_map.items():
if not hasattr(type_, "graphene_type"):
continue
for field in list(type_.graphene_type.__dict__):
if getattr(getattr(type_.graphene_type, field), "_inaccessible", False):
shareable_types[type_name] = type_.graphene_type
inaccessible_types[type_name] = type_.graphene_type
continue
return shareable_types
return inaccessible_types
6 changes: 5 additions & 1 deletion graphene_federation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ def _get_query(schema: Schema, query_cls: Optional[ObjectType] = None) -> Object


def build_schema(
query: Optional[ObjectType] = None, mutation: Optional[ObjectType] = None, **kwargs
query: Optional[ObjectType] = None,
mutation: Optional[ObjectType] = None,
enable_federation_2=False,
**kwargs
) -> Schema:
schema = Schema(query=query, mutation=mutation, **kwargs)
schema.auto_camelcase = kwargs.get("auto_camelcase", True)
schema.federation_version = 2 if enable_federation_2 else 1
federation_query = _get_query(schema, query)
return Schema(query=federation_query, mutation=mutation, **kwargs)
8 changes: 4 additions & 4 deletions graphene_federation/override.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from graphene import Schema


def override(field, _from: str):
def override(field, from_: str):
"""
Decorator to use to override a given type.
"""
field._override = _from
field._override = from_
return field


def get_override_fields(schema: Schema) -> dict:
"""
Find all the extended types from the schema.
Find all the overridden types from the schema.
They can be easily distinguished from the other type as
the `@tag` decorator adds a `_tag` attribute to them.
the `@override` decorator adds a `_override` attribute to them.
"""
override_fields = {}
for type_name, type_ in schema.graphql_schema.type_map.items():
Expand Down
2 changes: 1 addition & 1 deletion graphene_federation/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def requires(field, fields: Union[str, list[str]]):

def get_required_fields(schema: Schema) -> dict:
"""
Find all the extended types from the schema.
Find all the extended types with required fields from the schema.
They can be easily distinguished from the other type as
the `@requires` decorator adds a `_requires` attribute to them.
"""
Expand Down
145 changes: 86 additions & 59 deletions graphene_federation/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def convert_fields(schema: Schema, fields: list[str]) -> str:


DECORATORS = {
"_external": lambda _schema, _fields: "@external",
"_external": lambda schema, fields: "@external",
"_requires": lambda schema, fields: f'@requires(fields: "{convert_fields(schema, fields)}")',
"_provides": lambda schema, fields: f'@provides(fields: "{convert_fields(schema, fields)}")',
"_shareable": lambda _schema, _fields: "@shareable",
"_inaccessible": lambda _schema, _fields: "@inaccessible",
"_override": lambda schema, _from: f'@override(from: "{_from}")',
"_shareable": lambda schema, fields: "@shareable",
"_inaccessible": lambda schema, fields: "@inaccessible",
"_override": lambda schema, from_: f'@override(from: "{from_}")',
"_tag": lambda schema, name: f'@tag(name: "{name}")',
}

Expand Down Expand Up @@ -101,87 +101,114 @@ def get_sdl(schema: Schema) -> str:

# Get various objects that need to be amended
extended_types = get_extended_types(schema)
shareable_types = get_shareable_types(schema)
inaccessible_types = get_inaccessible_types(schema)
provides_parent_types = get_provides_parent_types(schema)
provides_fields = get_provides_fields(schema)
entities = get_entities(schema)
shareable_fields = get_shareable_fields(schema)
tagged_fields = get_tagged_fields(schema)
inaccessible_fields = get_inaccessible_fields(schema)
required_fields = get_required_fields(schema)
external_fields = get_external_fields(schema)
override_fields = get_override_fields(schema)

_schema_import = []

if extended_types:
_schema_import.append('"@extends"')
if external_fields:
_schema_import.append('"@external"')
if entities:
_schema_import.append('"@key"')
if inaccessible_types or inaccessible_fields:
_schema_import.append('"@inaccessible"')
if override_fields:
_schema_import.append('"@override"')
if provides_parent_types or provides_fields:
_schema_import.append('"@provides"')
if required_fields:
_schema_import.append('"@requires"')
if shareable_types or shareable_fields:
_schema_import.append('"@shareable"')
if tagged_fields:
_schema_import.append('"@tag"')
schema_import = ", ".join(_schema_import)
_schema = f'extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: [{schema_import}])\n'
_schema = ""

if schema.federation_version == 2:
shareable_types = get_shareable_types(schema)
inaccessible_types = get_inaccessible_types(schema)
shareable_fields = get_shareable_fields(schema)
tagged_fields = get_tagged_fields(schema)
inaccessible_fields = get_inaccessible_fields(schema)

_schema_import = []

if extended_types:
_schema_import.append('"@extends"')
if external_fields:
_schema_import.append('"@external"')
if entities:
_schema_import.append('"@key"')
if override_fields:
_schema_import.append('"@override"')
if provides_parent_types or provides_fields:
_schema_import.append('"@provides"')
if required_fields:
_schema_import.append('"@requires"')
if inaccessible_types or inaccessible_fields:
_schema_import.append('"@inaccessible"')
if shareable_types or shareable_fields:
_schema_import.append('"@shareable"')
if tagged_fields:
_schema_import.append('"@tag"')
schema_import = ", ".join(_schema_import)
_schema = f'extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: [{schema_import}])\n'

# Add fields directives (@external, @provides, @requires, @shareable, @inaccessible)
for entity in (
entities_ = (
set(provides_parent_types.values())
| set(extended_types.values())
| set(shareable_types.values())
| set(inaccessible_types.values())
| set(entities.values())
| set(inaccessible_fields.values())
| set(shareable_fields.values())
| set(tagged_fields.values())
| set(required_fields.values())
| set(provides_fields.values())
):
)

if schema.federation_version == 2:
entities_ = (
entities_
| set(shareable_types.values())
| set(inaccessible_types.values())
| set(inaccessible_fields.values())
| set(shareable_fields.values())
| set(tagged_fields.values())
)
for entity in entities_:
string_schema = add_entity_fields_decorators(entity, schema, string_schema)

# Prepend `extend` keyword to the type definition of extended types
# noinspection DuplicatedCode
for entity_name, entity in extended_types.items():
type_def = re.compile(r"type %s ([^\{]*)" % entity_name)
repl_str = r"extend type %s \1" % entity_name
type_def = re.compile(rf"type {entity_name} ([^{{]*)")
repl_str = rf"extend type {entity_name} \1"
string_schema = type_def.sub(repl_str, string_schema)

# Add entity keys declarations
get_field_name = type_attribute_to_field_name(schema)
for entity_name, entity in entities.items():
type_def_re = r"(type %s [^\{]*)" % entity_name + " "
type_annotation = (
" ".join([f'@key(fields: "{get_field_name(key)}")' for key in entity._keys])
+ " "
)
repl_str = r"\1%s" % type_annotation
type_def_re = rf"(type {entity_name} [^\{{]*)" + " "
if hasattr(entity, "_resolvable") and not entity._resolvable:
type_annotation = (
(
" ".join(
[
f'@key(fields: "{get_field_name(key)}"'
for key in entity._keys
]
)
)
+ f", resolvable: {str(entity._resolvable).lower()})"
+ " "
)
else:
type_annotation = (
" ".join(
[f'@key(fields: "{get_field_name(key)}")' for key in entity._keys]
)
) + " "
repl_str = rf"\1{type_annotation}"
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

for type_name, type in shareable_types.items():
type_def_re = r"(type %s [^\{]*)" % type_name + " "
type_annotation = f" @shareable"
repl_str = r"\1%s " % type_annotation
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

for type_name, type in inaccessible_types.items():
type_def_re = r"(type %s [^\{]*)" % type_name + " "
type_annotation = f" @inaccessible"
repl_str = r"\1%s " % type_annotation
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)
if schema.federation_version == 2:
for type_name, type in shareable_types.items():
type_def_re = rf"(type {type_name} [^\{{]*)" + " "
type_annotation = " @shareable"
repl_str = rf"\1{type_annotation} "
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

for type_name, type in inaccessible_types.items():
type_def_re = rf"(type {type_name} [^\{{]*)" + " "
type_annotation = " @inaccessible"
repl_str = rf"\1{type_annotation} "
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

return _schema + string_schema

Expand Down
2 changes: 1 addition & 1 deletion graphene_federation/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def tag(field, name: str):

def get_tagged_fields(schema: Schema) -> dict:
"""
Find all the extended types from the schema.
Find all the tagged types from the schema.
They can be easily distinguished from the other type as
the `@tag` decorator adds a `_tag` attribute to them.
"""
Expand Down
Loading

0 comments on commit 14de45b

Please sign in to comment.