Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3 allow multipliers in directivesestimator #6

Merged
merged 15 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ query = """
complexity = get_complexity(
query=query,
schema=schema,
estimator=SimpleEstimator(complexity=1, multiplier=10)
estimator=SimpleEstimator(complexity=10)
)
if complexity > 10:
raise Exception("Query is too complex")
Expand Down Expand Up @@ -73,7 +73,7 @@ it by another **constant** which is propagated along the depth of the query.
from graphql_complexity import SimpleEstimator


estimator = SimpleEstimator(complexity=2, multiplier=1)
estimator = SimpleEstimator(complexity=2)
```

Given the following GraphQL query:
Expand Down Expand Up @@ -150,13 +150,10 @@ from graphql_complexity import ComplexityEstimator


class CustomEstimator(ComplexityEstimator):
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
if node.name.value == "specificField":
return 100
return 1

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
return 1
```


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphql_complexity"
version = "0.2.0"
version = "0.3.0"
description = "A python library that provides complexity calculation helpers for GraphQL"
authors = ["Checho3388 <[email protected]>"]
packages = [
Expand Down
7 changes: 7 additions & 0 deletions src/graphql_complexity/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
count_arg_name: str | None = "first" # ToDo: Improve Unset
count_missing_arg_value: int = 1
6 changes: 1 addition & 5 deletions src/graphql_complexity/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,5 @@

class ComplexityEstimator(abc.ABC):
@abc.abstractmethod
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
"""Return the complexity of the field."""

@abc.abstractmethod
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
"""Return the multiplier that will be applied to the children of the given node."""
6 changes: 1 addition & 5 deletions src/graphql_complexity/estimators/directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,5 @@ def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
visit(ast, visitor)
return collector

def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
return self.__complexity_map.get(node.name.value, self.__missing_complexity)

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
# ToDo: Implement this method
return 1
33 changes: 4 additions & 29 deletions src/graphql_complexity/estimators/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,16 @@


class SimpleEstimator(ComplexityEstimator):
"""Simple complexity estimator that returns a constant complexity and multiplier for all fields.
Constants can be set in the constructor.
"""Simple complexity estimator that returns a constant complexity for all fields.
Constant can be set in the constructor."""

Example:
Given the following query:
```qgl
query {
user {
name
email
}
}
```
As the complexity and multiplier are constant, the complexity of the fields will be:
- user: 1 * 1 = 1
- name: 1 * 1 = 1
- email: 1 * 1 = 1
And the total complexity will be 3.
"""

def __init__(self, complexity: int = 1, multiplier: int = 1):
def __init__(self, complexity: int = 1):
if complexity < 0:
raise ValueError(
"'complexity' must be a positive integer (greater or equal than 0)"
)
if multiplier < 0:
raise ValueError(
"'multiplier' must be a positive integer (greater or equal than 0)"
)
self.__complexity_constant = complexity
self.__multiplier_constant = multiplier
super().__init__()

def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, *_, **__) -> int:
return self.__complexity_constant

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
return self.__multiplier_constant
6 changes: 1 addition & 5 deletions src/graphql_complexity/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from .complexity import (
get_ast_complexity,
get_complexity,
)
from .complexity import get_complexity

__all__ = [
'get_complexity',
'get_ast_complexity'
]
23 changes: 16 additions & 7 deletions src/graphql_complexity/evaluator/complexity.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
from graphql import parse, visit, TypeInfo, TypeInfoVisitor, GraphQLSchema
from graphql import GraphQLSchema, TypeInfo, TypeInfoVisitor, parse, visit

from . import nodes
from .visitor import ComplexityVisitor
from ..config import Config
from ..estimators import ComplexityEstimator


def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator, config: Config = None) -> int:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
return get_ast_complexity(ast, schema=schema, estimator=estimator)
tree = build_complexity_tree(query, schema, estimator, config)

return tree.evaluate()


def get_ast_complexity(ast, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
def build_complexity_tree(
query: str,
schema: GraphQLSchema,
estimator: ComplexityEstimator,
config: Config | None = None,
) -> nodes.ComplexityNode:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
type_info = TypeInfo(schema)

visitor = ComplexityVisitor(estimator=estimator, type_info=type_info)
visitor = ComplexityVisitor(estimator=estimator, type_info=type_info, config=config)
visit(ast, TypeInfoVisitor(type_info, visitor))

return visitor.evaluate()
return visitor.complexity_tree
137 changes: 137 additions & 0 deletions src/graphql_complexity/evaluator/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import dataclasses
import logging
from typing import Any

from graphql import (
GraphQLList,
TypeInfo,
get_named_type,
is_introspection_type, FieldNode
)

from graphql_complexity.config import Config
from graphql_complexity.evaluator.utils import get_node_argument_value

logger = logging.getLogger(__name__)


@dataclasses.dataclass(slots=True, kw_only=True)
class ComplexityNode:
name: str
parent: 'ComplexityNode' = None
children: list['ComplexityNode'] = dataclasses.field(default_factory=list)

def evaluate(self) -> int:
raise NotImplementedError

def describe(self, depth=0) -> str:
"""Return a friendly representation of the node and its children complexity."""
return (
f"{chr(9) * depth}{self.name} ({self.__class__.__name__}) = {self.evaluate()}" +
f"{chr(10) if self.children else ''}" +
'\n'.join(c.describe(depth=depth+1) for c in self.children)
)

def add_child(self, node: 'ComplexityNode') -> None:
"""Add a child to the current node."""
self.children.append(node)
node.parent = self


@dataclasses.dataclass(slots=True, kw_only=True)
class RootNode(ComplexityNode):
def evaluate(self) -> int:
return sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class FragmentSpreadNode(ComplexityNode):
fragments_definition: dict

def evaluate(self):
fragment = self.fragments_definition.get(self.name)
if not fragment:
return 0
return fragment.evaluate()


@dataclasses.dataclass(slots=True, kw_only=True)
class Field(ComplexityNode):
complexity: int

def evaluate(self) -> int:
return self.complexity + sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class ListField(Field):
count: int = 1

def evaluate(self) -> int:
return self.complexity + self.count * sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class SkippedField(ComplexityNode):
wraps: ComplexityNode

@classmethod
def wrap(cls, node: ComplexityNode):
wrapper = cls(
name=node.name,
parent=node.parent,
children=node.children,
wraps=node,
)
node.parent.children.remove(node)
node.parent.add_child(wrapper)
return wrapper

def evaluate(self) -> int:
return 0


@dataclasses.dataclass(slots=True, kw_only=True)
class MetaField(ComplexityNode):

def evaluate(self) -> int:
return 0


def build_node(
node: FieldNode,
type_info: TypeInfo,
complexity: int,
variables: dict[str, Any],
config: Config,
) -> ComplexityNode:
"""Build a complexity node from a field node."""
type_ = type_info.get_type()
unwrapped_type = get_named_type(type_)
if unwrapped_type is not None and is_introspection_type(unwrapped_type):
return MetaField(name=node.name.value)
if isinstance(type_, GraphQLList):
return build_list_node(node, complexity, variables, config)
return Field(
name=node.name.value,
complexity=complexity,
)


def build_list_node(node: FieldNode, complexity: int, variables: dict[str, Any], config: Config) -> ListField:
"""Build a list complexity node from a field node."""
if config.count_arg_name:
try:
count = int(
get_node_argument_value(node=node, arg_name=config.count_arg_name, variables=variables)
)
except ValueError:
logger.debug("Missing or invalid value for argument '%s' in node '%s'", config.count_arg_name, node)
count = config.count_missing_arg_value
else:
count = 1
return ListField(
name=node.name.value,
complexity=complexity,
count=count,
)
18 changes: 18 additions & 0 deletions src/graphql_complexity/evaluator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any

from graphql import DirectiveNode, FieldNode, VariableNode


def get_node_argument_value(node: FieldNode | DirectiveNode, arg_name: str, variables: dict[str, Any]) -> Any:
"""Returns the value of the argument given by parameter."""
arg = next(
(arg for arg in node.arguments if arg.name.value == arg_name),
None
)
if not arg:
raise ValueError(f"Value for {arg_name!r} not found in {node.name.value!r} arguments")

if isinstance(arg.value, VariableNode):
return variables.get(arg.value.name.value)

return arg.value.value
Loading
Loading