Skip to content

Commit

Permalink
Merge pull request #6 from Checho3388/3-allow-multipliers-in-directiv…
Browse files Browse the repository at this point in the history
…esestimator

3 allow multipliers in directivesestimator
  • Loading branch information
Checho3388 committed Mar 13, 2024
2 parents 4d14566 + 6fb43f4 commit 52c7507
Show file tree
Hide file tree
Showing 19 changed files with 425 additions and 266 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ query = """
complexity = get_complexity(
query=query,
schema=schema,
estimator=SimpleEstimator(complexity=1, multiplier=10)
estimator=SimpleEstimator(complexity=10)
)
if complexity > 10:
raise Exception("Query is too complex")
Expand Down Expand Up @@ -73,7 +73,7 @@ it by another **constant** which is propagated along the depth of the query.
from graphql_complexity import SimpleEstimator


estimator = SimpleEstimator(complexity=2, multiplier=1)
estimator = SimpleEstimator(complexity=2)
```

Given the following GraphQL query:
Expand Down Expand Up @@ -150,13 +150,10 @@ from graphql_complexity import ComplexityEstimator


class CustomEstimator(ComplexityEstimator):
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
if node.name.value == "specificField":
return 100
return 1

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
return 1
```


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphql_complexity"
version = "0.2.0"
version = "0.3.0"
description = "A python library that provides complexity calculation helpers for GraphQL"
authors = ["Checho3388 <[email protected]>"]
packages = [
Expand Down
7 changes: 7 additions & 0 deletions src/graphql_complexity/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
count_arg_name: str | None = "first" # ToDo: Improve Unset
count_missing_arg_value: int = 1
6 changes: 1 addition & 5 deletions src/graphql_complexity/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,5 @@

class ComplexityEstimator(abc.ABC):
@abc.abstractmethod
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
"""Return the complexity of the field."""

@abc.abstractmethod
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
"""Return the multiplier that will be applied to the children of the given node."""
6 changes: 1 addition & 5 deletions src/graphql_complexity/estimators/directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,5 @@ def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
visit(ast, visitor)
return collector

def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, node, type_info, path) -> int:
return self.__complexity_map.get(node.name.value, self.__missing_complexity)

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
# ToDo: Implement this method
return 1
33 changes: 4 additions & 29 deletions src/graphql_complexity/estimators/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,16 @@


class SimpleEstimator(ComplexityEstimator):
"""Simple complexity estimator that returns a constant complexity and multiplier for all fields.
Constants can be set in the constructor.
"""Simple complexity estimator that returns a constant complexity for all fields.
Constant can be set in the constructor."""

Example:
Given the following query:
```qgl
query {
user {
name
email
}
}
```
As the complexity and multiplier are constant, the complexity of the fields will be:
- user: 1 * 1 = 1
- name: 1 * 1 = 1
- email: 1 * 1 = 1
And the total complexity will be 3.
"""

def __init__(self, complexity: int = 1, multiplier: int = 1):
def __init__(self, complexity: int = 1):
if complexity < 0:
raise ValueError(
"'complexity' must be a positive integer (greater or equal than 0)"
)
if multiplier < 0:
raise ValueError(
"'multiplier' must be a positive integer (greater or equal than 0)"
)
self.__complexity_constant = complexity
self.__multiplier_constant = multiplier
super().__init__()

def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
def get_field_complexity(self, *_, **__) -> int:
return self.__complexity_constant

def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
return self.__multiplier_constant
6 changes: 1 addition & 5 deletions src/graphql_complexity/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from .complexity import (
get_ast_complexity,
get_complexity,
)
from .complexity import get_complexity

__all__ = [
'get_complexity',
'get_ast_complexity'
]
23 changes: 16 additions & 7 deletions src/graphql_complexity/evaluator/complexity.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
from graphql import parse, visit, TypeInfo, TypeInfoVisitor, GraphQLSchema
from graphql import GraphQLSchema, TypeInfo, TypeInfoVisitor, parse, visit

from . import nodes
from .visitor import ComplexityVisitor
from ..config import Config
from ..estimators import ComplexityEstimator


def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator, config: Config = None) -> int:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
return get_ast_complexity(ast, schema=schema, estimator=estimator)
tree = build_complexity_tree(query, schema, estimator, config)

return tree.evaluate()


def get_ast_complexity(ast, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
def build_complexity_tree(
query: str,
schema: GraphQLSchema,
estimator: ComplexityEstimator,
config: Config | None = None,
) -> nodes.ComplexityNode:
"""Calculate the complexity of a query using the provided estimator."""
ast = parse(query)
type_info = TypeInfo(schema)

visitor = ComplexityVisitor(estimator=estimator, type_info=type_info)
visitor = ComplexityVisitor(estimator=estimator, type_info=type_info, config=config)
visit(ast, TypeInfoVisitor(type_info, visitor))

return visitor.evaluate()
return visitor.complexity_tree
137 changes: 137 additions & 0 deletions src/graphql_complexity/evaluator/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import dataclasses
import logging
from typing import Any

from graphql import (
GraphQLList,
TypeInfo,
get_named_type,
is_introspection_type, FieldNode
)

from graphql_complexity.config import Config
from graphql_complexity.evaluator.utils import get_node_argument_value

logger = logging.getLogger(__name__)


@dataclasses.dataclass(slots=True, kw_only=True)
class ComplexityNode:
name: str
parent: 'ComplexityNode' = None
children: list['ComplexityNode'] = dataclasses.field(default_factory=list)

def evaluate(self) -> int:
raise NotImplementedError

def describe(self, depth=0) -> str:
"""Return a friendly representation of the node and its children complexity."""
return (
f"{chr(9) * depth}{self.name} ({self.__class__.__name__}) = {self.evaluate()}" +
f"{chr(10) if self.children else ''}" +
'\n'.join(c.describe(depth=depth+1) for c in self.children)
)

def add_child(self, node: 'ComplexityNode') -> None:
"""Add a child to the current node."""
self.children.append(node)
node.parent = self


@dataclasses.dataclass(slots=True, kw_only=True)
class RootNode(ComplexityNode):
def evaluate(self) -> int:
return sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class FragmentSpreadNode(ComplexityNode):
fragments_definition: dict

def evaluate(self):
fragment = self.fragments_definition.get(self.name)
if not fragment:
return 0
return fragment.evaluate()


@dataclasses.dataclass(slots=True, kw_only=True)
class Field(ComplexityNode):
complexity: int

def evaluate(self) -> int:
return self.complexity + sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class ListField(Field):
count: int = 1

def evaluate(self) -> int:
return self.complexity + self.count * sum(child.evaluate() for child in self.children)


@dataclasses.dataclass(slots=True, kw_only=True)
class SkippedField(ComplexityNode):
wraps: ComplexityNode

@classmethod
def wrap(cls, node: ComplexityNode):
wrapper = cls(
name=node.name,
parent=node.parent,
children=node.children,
wraps=node,
)
node.parent.children.remove(node)
node.parent.add_child(wrapper)
return wrapper

def evaluate(self) -> int:
return 0


@dataclasses.dataclass(slots=True, kw_only=True)
class MetaField(ComplexityNode):

def evaluate(self) -> int:
return 0


def build_node(
node: FieldNode,
type_info: TypeInfo,
complexity: int,
variables: dict[str, Any],
config: Config,
) -> ComplexityNode:
"""Build a complexity node from a field node."""
type_ = type_info.get_type()
unwrapped_type = get_named_type(type_)
if unwrapped_type is not None and is_introspection_type(unwrapped_type):
return MetaField(name=node.name.value)
if isinstance(type_, GraphQLList):
return build_list_node(node, complexity, variables, config)
return Field(
name=node.name.value,
complexity=complexity,
)


def build_list_node(node: FieldNode, complexity: int, variables: dict[str, Any], config: Config) -> ListField:
"""Build a list complexity node from a field node."""
if config.count_arg_name:
try:
count = int(
get_node_argument_value(node=node, arg_name=config.count_arg_name, variables=variables)
)
except ValueError:
logger.debug("Missing or invalid value for argument '%s' in node '%s'", config.count_arg_name, node)
count = config.count_missing_arg_value
else:
count = 1
return ListField(
name=node.name.value,
complexity=complexity,
count=count,
)
18 changes: 18 additions & 0 deletions src/graphql_complexity/evaluator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any

from graphql import DirectiveNode, FieldNode, VariableNode


def get_node_argument_value(node: FieldNode | DirectiveNode, arg_name: str, variables: dict[str, Any]) -> Any:
"""Returns the value of the argument given by parameter."""
arg = next(
(arg for arg in node.arguments if arg.name.value == arg_name),
None
)
if not arg:
raise ValueError(f"Value for {arg_name!r} not found in {node.name.value!r} arguments")

if isinstance(arg.value, VariableNode):
return variables.get(arg.value.name.value)

return arg.value.value
Loading

0 comments on commit 52c7507

Please sign in to comment.