Skip to content

Commit

Permalink
Fix sharing of model versions (zenml-io#2380)
Browse files Browse the repository at this point in the history
* Fix sharing of model versions

* Auto-update of E2E template

* Fix mypy issues

* Auto-update of Starter template

* Auto-update of E2E template

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
2 people authored and adtygan committed Mar 20, 2024
1 parent 930e8c4 commit e94cb82
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
21 changes: 7 additions & 14 deletions src/zenml/zen_server/routers/users_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
get_allowed_resource_ids,
get_schema_for_resource_type,
update_resource_membership,
verify_permission,
verify_permission_for_model,
)
from zenml.zen_server.utils import (
Expand Down Expand Up @@ -514,27 +513,21 @@ def update_user_resource_membership(
resource = Resource(type=resource_type, id=resource_id)

schema_class = get_schema_for_resource_type(resource_type)
if not zen_store().object_exists(
object_id=resource_id, schema_class=schema_class
):
model = zen_store().get_entity_by_id(
entity_id=resource_id, schema_class=schema_class
)

if not model:
raise KeyError(
f"Resource of type {resource_type} with ID {resource_id} does "
"not exist."
)

verify_permission(
resource_type=resource_type,
action=Action.SHARE,
resource_id=resource_id,
)
verify_permission_for_model(model=model, action=Action.SHARE)
for action in actions:
# Make sure users aren't able to share permissions they don't have
# themselves
verify_permission(
resource_type=resource_type,
action=Action(action),
resource_id=resource_id,
)
verify_permission_for_model(model=model, action=Action(action))

update_resource_membership(
user=user,
Expand Down
41 changes: 35 additions & 6 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7307,25 +7307,54 @@ def _count_entity(

return int(entity_count)

def object_exists(
self, object_id: UUID, schema_class: Type[AnySchema]
def entity_exists(
self, entity_id: UUID, schema_class: Type[AnySchema]
) -> bool:
"""Check whether an object exists in the database.
"""Check whether an entity exists in the database.
Args:
object_id: The ID of the object to check.
entity_id: The ID of the entity to check.
schema_class: The schema class.
Returns:
If the object exists.
If the entity exists.
"""
with Session(self.engine) as session:
schema = session.exec(
select(schema_class.id).where(schema_class.id == object_id)
select(schema_class.id).where(schema_class.id == entity_id)
).first()

return False if schema is None else True

def get_entity_by_id(
self, entity_id: UUID, schema_class: Type[AnySchema]
) -> Optional[B]:
"""Get an entity by ID.
Args:
entity_id: The ID of the entity to get.
schema_class: The schema class.
Raises:
RuntimeError: If the schema to model conversion failed.
Returns:
The entity if it exists, None otherwise
"""
with Session(self.engine) as session:
schema = session.exec(
select(schema_class).where(schema_class.id == entity_id)
).first()

if not schema:
return None

to_model = getattr(schema, "to_model", None)
if callable(to_model):
return cast(B, to_model(hydrate=True))
else:
raise RuntimeError("Unable to convert schema to model.")

@staticmethod
def _get_schema_by_name_or_id(
object_name_or_id: Union[str, UUID],
Expand Down

0 comments on commit e94cb82

Please sign in to comment.