Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow introspection #5

Merged
merged 5 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# graphql-complexity
Python library to compute the complexity of a GraphQL operation

![Unit Tests](https://github.com/Checho3388/graphql-complexity/actions/workflows/python-package.yml/badge.svg)
![Build](https://github.com/Checho3388/graphql-complexity/actions/workflows/python-buildlu.yml/badge.svg)
[![PyPI](https://img.shields.io/pypi/v/graphql-complexity?label=pypi%20package)](https://pypi.org/project/graphql-complexity/)
[![codecov](https://codecov.io/gh/Checho3388/graphql-complexity/graph/badge.svg?token=4LH7AVN119)](https://codecov.io/gh/Checho3388/graphql-complexity)

## Installation (Quick Start)
The library can be installed using pip:
Expand All @@ -18,6 +19,18 @@ pip install graphql-complexity[strawberry-graphql]
Create a file named `complexity.py` with the following content:
```python
from graphql_complexity import (get_complexity, SimpleEstimator)
from graphql import build_schema


schema = build_schema("""
type User {
id: ID!
name: String!
}
type Query {
user: User
}
""")

query = """
query SomeQuery {
Expand All @@ -30,6 +43,7 @@ query = """

complexity = get_complexity(
query=query,
schema=schema,
estimator=SimpleEstimator(complexity=1, multiplier=10)
)
if complexity > 10:
Expand Down Expand Up @@ -163,6 +177,7 @@ extension that can be added to the schema.

```python
import strawberry
from graphql_complexity.estimators import SimpleEstimator
from graphql_complexity.extensions.strawberry_graphql import build_complexity_extension


Expand All @@ -172,7 +187,7 @@ class Query:
def hello_world(self) -> str:
return "Hello world!"

extension = build_complexity_extension()
extension = build_complexity_extension(estimator=SimpleEstimator())
schema = strawberry.Schema(query=Query, extensions=[extension])

schema.execute_sync("query { helloWorld }")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphql_complexity"
version = "0.1.0"
version = "0.2.0"
description = "A python library that provides complexity calculation helpers for GraphQL"
authors = ["Checho3388 <[email protected]>"]
packages = [
Expand Down
4 changes: 3 additions & 1 deletion src/graphql_complexity/estimators/directive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from graphql import Visitor, parse, visit

from graphql_complexity.estimators.base import ComplexityEstimator
Expand Down Expand Up @@ -63,7 +65,7 @@ def __init__(

@staticmethod
def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
collector = {}
collector: dict[str, Any] = {}
ast = parse(schema)
visitor = DirectivesVisitor(collector=collector, directive_name=directive_name)
visit(ast, visitor)
Expand Down
6 changes: 5 additions & 1 deletion src/graphql_complexity/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .complexity import get_complexity
from .complexity import (
get_ast_complexity,
get_complexity,
)

__all__ = [
'get_complexity',
'get_ast_complexity'
]
16 changes: 9 additions & 7 deletions src/graphql_complexity/evaluator/complexity.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from graphql import parse, visit
from graphql import parse, visit, TypeInfo, TypeInfoVisitor, GraphQLSchema

from ..estimators import ComplexityEstimator
from .visitor import ComplexityVisitor
from ..estimators import ComplexityEstimator


def get_complexity(query: str, estimator: ComplexityEstimator) -> int:
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
return get_ast_complexity(ast, estimator=estimator)
return get_ast_complexity(ast, schema=schema, estimator=estimator)


def get_ast_complexity(ast, estimator: ComplexityEstimator) -> int:
def get_ast_complexity(ast, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
"""Calculate the complexity of a query using the provided estimator."""
visitor = ComplexityVisitor(estimator=estimator)
visit(ast, visitor)
type_info = TypeInfo(schema)

visitor = ComplexityVisitor(estimator=estimator, type_info=type_info)
visit(ast, TypeInfoVisitor(type_info, visitor))

return visitor.evaluate()
35 changes: 17 additions & 18 deletions src/graphql_complexity/evaluator/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
GraphQLSkipDirective,
OperationDefinitionNode,
VariableNode,
Visitor
Visitor, is_introspection_type, get_named_type, TypeInfo
)

from graphql_complexity.estimators.base import ComplexityEstimator
Expand Down Expand Up @@ -61,25 +61,14 @@

The complexity of the operations is calculated by summing the complexity of
the fields in the operation.

Examples:
>>> from graphql import parse, visit
>>> from graphql_complexity.estimators import SimpleEstimator
>>> from graphql_complexity.visitor import ComplexityVisitor
>>> query = "query { user { name email } }"
>>> ast = parse(query)
>>> visitor = ComplexityVisitor(estimator=SimpleEstimator())
>>> exc = visit(ast, visitor)
>>> visitor.evaluate()
3

"""

def __init__(self, estimator: ComplexityEstimator, variables=None):
def __init__(self, estimator: ComplexityEstimator, type_info: TypeInfo, variables: dict[str, Any] | None = None):
if not isinstance(estimator, ComplexityEstimator):
raise ValueError("Estimator must be of type 'ComplexityEstimator'")
self.estimator: ComplexityEstimator = estimator
self.variables = variables or {}
self.type_info = type_info
self._operations: dict[str, list[ComplexityEvaluationNode]] = {}
self._fragments: dict[str, list[ComplexityEvaluationNode]] = {}
self._current_complexity_stack: list[ComplexityEvaluationNode] = []
Expand Down Expand Up @@ -130,7 +119,13 @@

def enter_field(self, node, key, parent, path, ancestors):
"""Add the complexity of the current field to the current complexity list."""
type_ = get_named_type(self.type_info.get_type())
if type_ is not None and is_introspection_type(type_):
# Skip introspection fields
return SKIP

if self._ignore_until_leave is not None:
# Skip fields until the parent field is left
return SKIP

complexity = self.estimator.get_field_complexity(
Expand All @@ -149,6 +144,7 @@
self._ignore_until_leave is not None
and node.name.value == self._ignore_until_leave.name
):
# If we are leaving the ignored node, reset the flag
self._ignore_until_leave = None

def enter_fragment_definition(self, *args, **kwargs):
Expand All @@ -168,7 +164,7 @@
self._current_complexity_stack = []


def should_include_field(node: DirectiveNode, variables: dict[str, Any] = None) -> bool:
def should_include_field(node: DirectiveNode, variables: dict[str, Any]) -> bool:
"""Check if a field should be ignored based on the 'skip' and 'include' directives."""
if node.name.value == GraphQLIncludeDirective.name:
return get_directive_if_value(node, variables)
Expand All @@ -180,9 +176,12 @@


def get_directive_if_value(directive: DirectiveNode, variables: dict[str, Any]) -> bool:
"""Returns the value of the `if` argument from the Directive"""
"""Returns the value of the `if` argument from the Directive. Used to get the boolean
value for skip/include directives."""
if_arg = next(arg for arg in directive.arguments if arg.name.value == "if")
if isinstance(if_arg.value, VariableNode):
return variables.get(if_arg.value.name.value)
return bool(variables.get(if_arg.value.name.value))
elif isinstance(if_arg.value, BooleanValueNode):
return if_arg.value.value
return bool(if_arg.value.value)

raise ValueError("Value for `if` argument not found")

Check warning on line 187 in src/graphql_complexity/evaluator/visitor.py

View check run for this annotation

Codecov / codecov/patch

src/graphql_complexity/evaluator/visitor.py#L187

Added line #L187 was not covered by tests
12 changes: 5 additions & 7 deletions src/graphql_complexity/extensions/strawberry_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,23 @@
from graphql import GraphQLError
from strawberry.extensions import SchemaExtension

from graphql_complexity.estimators import ComplexityEstimator, SimpleEstimator
from graphql_complexity.evaluator.complexity import get_ast_complexity
from graphql_complexity.estimators import ComplexityEstimator
from graphql_complexity.evaluator import get_ast_complexity


def build_complexity_extension(
estimator: ComplexityEstimator | None = None,
estimator: ComplexityEstimator,
max_complexity: int | None = None,
) -> Type[SchemaExtension]:
estimator = estimator or SimpleEstimator(1, 1)

class ComplexityExtension(SchemaExtension):
visitor = None
estimated_complexity: int = None
estimated_complexity: int | None = None

def on_validate(
self,
):
self.estimated_complexity = get_ast_complexity(
ast=self.execution_context.graphql_document,
schema=self.execution_context.schema._schema,
estimator=estimator
)

Expand Down
Loading
Loading