Skip to content

Commit

Permalink
Merge pull request #5 from Checho3388/allow-introspection
Browse files Browse the repository at this point in the history
Allow introspection
  • Loading branch information
Checho3388 committed Feb 26, 2024
2 parents 76f981f + 1ee69d8 commit 372a3d7
Show file tree
Hide file tree
Showing 13 changed files with 315 additions and 192 deletions.
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# graphql-complexity
Python library to compute the complexity of a GraphQL operation

![Unit Tests](https://github.com/Checho3388/graphql-complexity/actions/workflows/python-package.yml/badge.svg)
![Build](https://github.com/Checho3388/graphql-complexity/actions/workflows/python-buildlu.yml/badge.svg)
[![PyPI](https://img.shields.io/pypi/v/graphql-complexity?label=pypi%20package)](https://pypi.org/project/graphql-complexity/)
[![codecov](https://codecov.io/gh/Checho3388/graphql-complexity/graph/badge.svg?token=4LH7AVN119)](https://codecov.io/gh/Checho3388/graphql-complexity)

## Installation (Quick Start)
The library can be installed using pip:
Expand All @@ -18,6 +19,18 @@ pip install graphql-complexity[strawberry-graphql]
Create a file named `complexity.py` with the following content:
```python
from graphql_complexity import (get_complexity, SimpleEstimator)
from graphql import build_schema


schema = build_schema("""
type User {
id: ID!
name: String!
}
type Query {
user: User
}
""")

query = """
query SomeQuery {
Expand All @@ -30,6 +43,7 @@ query = """

complexity = get_complexity(
query=query,
schema=schema,
estimator=SimpleEstimator(complexity=1, multiplier=10)
)
if complexity > 10:
Expand Down Expand Up @@ -163,6 +177,7 @@ extension that can be added to the schema.

```python
import strawberry
from graphql_complexity.estimators import SimpleEstimator
from graphql_complexity.extensions.strawberry_graphql import build_complexity_extension


Expand All @@ -172,7 +187,7 @@ class Query:
def hello_world(self) -> str:
return "Hello world!"

extension = build_complexity_extension()
extension = build_complexity_extension(estimator=SimpleEstimator())
schema = strawberry.Schema(query=Query, extensions=[extension])

schema.execute_sync("query { helloWorld }")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphql_complexity"
version = "0.1.0"
version = "0.2.0"
description = "A python library that provides complexity calculation helpers for GraphQL"
authors = ["Checho3388 <[email protected]>"]
packages = [
Expand Down
4 changes: 3 additions & 1 deletion src/graphql_complexity/estimators/directive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from graphql import Visitor, parse, visit

from graphql_complexity.estimators.base import ComplexityEstimator
Expand Down Expand Up @@ -63,7 +65,7 @@ def __init__(

@staticmethod
def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
collector = {}
collector: dict[str, Any] = {}
ast = parse(schema)
visitor = DirectivesVisitor(collector=collector, directive_name=directive_name)
visit(ast, visitor)
Expand Down
6 changes: 5 additions & 1 deletion src/graphql_complexity/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .complexity import get_complexity
from .complexity import (
get_ast_complexity,
get_complexity,
)

__all__ = [
'get_complexity',
'get_ast_complexity'
]
16 changes: 9 additions & 7 deletions src/graphql_complexity/evaluator/complexity.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from graphql import parse, visit
from graphql import parse, visit, TypeInfo, TypeInfoVisitor, GraphQLSchema

from ..estimators import ComplexityEstimator
from .visitor import ComplexityVisitor
from ..estimators import ComplexityEstimator


def get_complexity(query: str, estimator: ComplexityEstimator) -> int:
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
return get_ast_complexity(ast, estimator=estimator)
return get_ast_complexity(ast, schema=schema, estimator=estimator)


def get_ast_complexity(ast, estimator: ComplexityEstimator) -> int:
def get_ast_complexity(ast, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
"""Calculate the complexity of a query using the provided estimator."""
visitor = ComplexityVisitor(estimator=estimator)
visit(ast, visitor)
type_info = TypeInfo(schema)

visitor = ComplexityVisitor(estimator=estimator, type_info=type_info)
visit(ast, TypeInfoVisitor(type_info, visitor))

return visitor.evaluate()
35 changes: 17 additions & 18 deletions src/graphql_complexity/evaluator/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
GraphQLSkipDirective,
OperationDefinitionNode,
VariableNode,
Visitor
Visitor, is_introspection_type, get_named_type, TypeInfo
)

from graphql_complexity.estimators.base import ComplexityEstimator
Expand Down Expand Up @@ -61,25 +61,14 @@ class ComplexityVisitor(Visitor):
The complexity of the operations is calculated by summing the complexity of
the fields in the operation.
Examples:
>>> from graphql import parse, visit
>>> from graphql_complexity.estimators import SimpleEstimator
>>> from graphql_complexity.visitor import ComplexityVisitor
>>> query = "query { user { name email } }"
>>> ast = parse(query)
>>> visitor = ComplexityVisitor(estimator=SimpleEstimator())
>>> exc = visit(ast, visitor)
>>> visitor.evaluate()
3
"""

def __init__(self, estimator: ComplexityEstimator, variables=None):
def __init__(self, estimator: ComplexityEstimator, type_info: TypeInfo, variables: dict[str, Any] | None = None):
if not isinstance(estimator, ComplexityEstimator):
raise ValueError("Estimator must be of type 'ComplexityEstimator'")
self.estimator: ComplexityEstimator = estimator
self.variables = variables or {}
self.type_info = type_info
self._operations: dict[str, list[ComplexityEvaluationNode]] = {}
self._fragments: dict[str, list[ComplexityEvaluationNode]] = {}
self._current_complexity_stack: list[ComplexityEvaluationNode] = []
Expand Down Expand Up @@ -130,7 +119,13 @@ def leave_operation_definition(self, node, *args, **kwargs):

def enter_field(self, node, key, parent, path, ancestors):
"""Add the complexity of the current field to the current complexity list."""
type_ = get_named_type(self.type_info.get_type())
if type_ is not None and is_introspection_type(type_):
# Skip introspection fields
return SKIP

if self._ignore_until_leave is not None:
# Skip fields until the parent field is left
return SKIP

complexity = self.estimator.get_field_complexity(
Expand All @@ -149,6 +144,7 @@ def leave_field(self, node, key, parent, path, ancestors):
self._ignore_until_leave is not None
and node.name.value == self._ignore_until_leave.name
):
# If we are leaving the ignored node, reset the flag
self._ignore_until_leave = None

def enter_fragment_definition(self, *args, **kwargs):
Expand All @@ -168,7 +164,7 @@ def enter_inline_fragment(self, *args, **kwargs):
self._current_complexity_stack = []


def should_include_field(node: DirectiveNode, variables: dict[str, Any] = None) -> bool:
def should_include_field(node: DirectiveNode, variables: dict[str, Any]) -> bool:
"""Check if a field should be ignored based on the 'skip' and 'include' directives."""
if node.name.value == GraphQLIncludeDirective.name:
return get_directive_if_value(node, variables)
Expand All @@ -180,9 +176,12 @@ def should_include_field(node: DirectiveNode, variables: dict[str, Any] = None)


def get_directive_if_value(directive: DirectiveNode, variables: dict[str, Any]) -> bool:
"""Returns the value of the `if` argument from the Directive"""
"""Returns the value of the `if` argument from the Directive. Used to get the boolean
value for skip/include directives."""
if_arg = next(arg for arg in directive.arguments if arg.name.value == "if")
if isinstance(if_arg.value, VariableNode):
return variables.get(if_arg.value.name.value)
return bool(variables.get(if_arg.value.name.value))
elif isinstance(if_arg.value, BooleanValueNode):
return if_arg.value.value
return bool(if_arg.value.value)

raise ValueError("Value for `if` argument not found")
12 changes: 5 additions & 7 deletions src/graphql_complexity/extensions/strawberry_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,23 @@
from graphql import GraphQLError
from strawberry.extensions import SchemaExtension

from graphql_complexity.estimators import ComplexityEstimator, SimpleEstimator
from graphql_complexity.evaluator.complexity import get_ast_complexity
from graphql_complexity.estimators import ComplexityEstimator
from graphql_complexity.evaluator import get_ast_complexity


def build_complexity_extension(
estimator: ComplexityEstimator | None = None,
estimator: ComplexityEstimator,
max_complexity: int | None = None,
) -> Type[SchemaExtension]:
estimator = estimator or SimpleEstimator(1, 1)

class ComplexityExtension(SchemaExtension):
visitor = None
estimated_complexity: int = None
estimated_complexity: int | None = None

def on_validate(
self,
):
self.estimated_complexity = get_ast_complexity(
ast=self.execution_context.graphql_document,
schema=self.execution_context.schema._schema,
estimator=estimator
)

Expand Down
Loading

0 comments on commit 372a3d7

Please sign in to comment.