Skip to content

Commit

Permalink
Merge pull request #37 from argyle-engineering/improve-alias-resolution
Browse files Browse the repository at this point in the history
Improve alias resolution
  • Loading branch information
povilasb committed Jun 29, 2023
2 parents a514082 + 88d84f1 commit ec5fc9f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
21 changes: 15 additions & 6 deletions pydantic2zod/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,25 @@ def _finish_parsing_class(self, cls_decl: ClassDecl) -> ClassDecl | None:
self._model_graph.add_node(cls.full_path)
self._pydantic_classes[cls.name] = cls

# Try to resolve type aliases and generic type variables.
for f in cls.fields:
# TODO(povilas): recurse into generic types
if isinstance(f.type, UserDefinedType):
if node := self._alias_nodes.get(f.type.name):
assert node.value
f.type = _extract_type(node.value)
f.type = self._resolve_type_aliases(f.type)

return cls

def _resolve_type_aliases(self, tp: PyType) -> PyType:
match tp:
case UserDefinedType(name=name):
if node := self._alias_nodes.get(name):
assert node.value
return _extract_type(node.value)
case GenericType(type_vars=type_vars):
for i, type_var in enumerate(type_vars):
tp.type_vars[i] = self._resolve_type_aliases(type_var)
case _:
...

return tp

def _is_pydantic_model(self, cls: ClassDecl) -> bool:
for b in cls.base_classes:
if self._is_imported(b) in [
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ class LambdaFunc(BaseModel):

class EventBus(BaseModel):
handlers: EventHandler
handlers2: list[EventHandler]
20 changes: 19 additions & 1 deletion tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,25 @@ def test_supports_explicit_type_alias(self):
),
]
),
)
),
ClassField(
name="handlers2",
type=GenericType(
generic="list",
type_vars=[
UnionType(
types=[
UserDefinedType(
name="tests.fixtures.type_alias.Function"
),
UserDefinedType(
name="tests.fixtures.type_alias.LambdaFunc"
),
]
)
],
),
),
],
base_classes=["BaseModel"],
),
Expand Down

0 comments on commit ec5fc9f

Please sign in to comment.