Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix status for requests high side after syncing #9339

Merged
merged 13 commits into from
Oct 8, 2024
18 changes: 18 additions & 0 deletions notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,24 @@
"widget._sync_all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Check requests status on the high side"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for job in submitted_jobs_data_should_succeed:\n",
" request = get_request_for_job_info(all_requests, job)\n",
" assert request.status == RequestStatus.APPROVED"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
6 changes: 4 additions & 2 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,13 @@ def handle_sync_batch(

logger.debug(f"Decision: Syncing {len(sync_instructions)} objects")

# Apply empty state to source side to signal that we are done syncing
src_client.apply_state(src_resolved_state)
# Apply sync instructions to target side
for sync_instruction in sync_instructions:
tgt_resolved_state.add_sync_instruction(sync_instruction)
src_resolved_state.add_sync_instruction(sync_instruction)
# Apply empty state to source side to signal that we are done syncing
# We also add permissions for users from the low side to mark L0 request as approved
src_client.apply_state(src_resolved_state)
return tgt_client.apply_state(tgt_resolved_state)


Expand Down
62 changes: 49 additions & 13 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,7 +1558,8 @@ class SyncInstruction(SyftObject):

diff: ObjectDiff
decision: SyncDecision | None
new_permissions_lowside: list[ActionObjectPermission]
new_permissions_lowside: dict[type, list[ActionObjectPermission]]
new_permissions_highside: dict[type, list[ActionObjectPermission]]
new_storage_permissions_lowside: list[StoragePermission]
new_storage_permissions_highside: list[StoragePermission]
unignore: bool = False
Expand All @@ -1575,8 +1576,8 @@ def from_batch_decision(
share_to_user: SyftVerifyKey | None,
) -> Self:
# read widget state
new_permissions_low_side = []

new_permissions_low_side = {}
new_permissions_high_side = {}
# read permissions
if sync_direction == SyncDirection.HIGH_TO_LOW:
# To create read permissions for the object
Expand All @@ -1592,13 +1593,27 @@ def from_batch_decision(
"share_to_user is required to share private data"
)
else:
new_permissions_low_side = [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
new_permissions_low_side = {
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}
if diff.obj_type in [Job, SyftLog, Request] or issubclass(
diff.obj_type, ActionObject
):
new_permissions_high_side = {
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}

# storage permissions
new_storage_permissions = []
Expand All @@ -1620,6 +1635,7 @@ def from_batch_decision(
diff=diff,
decision=decision,
new_permissions_lowside=new_permissions_low_side,
new_permissions_highside=new_permissions_high_side,
new_storage_permissions_lowside=new_storage_permissions,
new_storage_permissions_highside=new_storage_permissions,
mockify=mockify,
Expand All @@ -1634,7 +1650,7 @@ class ResolvedSyncState(SyftObject):
create_objs: list[SyncableSyftObject] = []
update_objs: list[SyncableSyftObject] = []
delete_objs: list[SyftObject] = []
new_permissions: list[ActionObjectPermission] = []
new_permissions: dict[type, list[ActionObjectPermission]] = {}
new_storage_permissions: list[StoragePermission] = []
ignored_batches: dict[UID, int] = {} # batch root uid -> hash of the batch
unignored_batches: set[UID] = set()
Expand Down Expand Up @@ -1666,7 +1682,10 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
if sync_instruction.unignore:
self.unignored_batches.add(sync_instruction.batch_diff.root_id)

if diff.status == "SAME":
if (
diff.status == "SAME"
and len(sync_instruction.new_permissions_highside) == 0
):
return

my_obj = diff.low_obj if self.alias == "low" else diff.high_obj
Expand Down Expand Up @@ -1695,11 +1714,28 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
self.delete_objs.append(my_obj)

if self.alias == "low":
self.new_permissions.extend(sync_instruction.new_permissions_lowside)
for obj_type in sync_instruction.new_permissions_lowside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_lowside[obj_type]
)
else:
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_lowside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_lowside
)
elif self.alias == "high":
for obj_type in sync_instruction.new_permissions_highside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_highside[obj_type]
)
else:
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_highside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_highside
)
Expand Down
33 changes: 29 additions & 4 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ..code.user_code import UserCodeStatusCollection
from ..context import AuthedServiceContext
from ..job.job_stash import Job
from ..log.log import SyftLog
from ..request.request import Request
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import TYPE_TO_SERVICE
Expand Down Expand Up @@ -189,14 +191,38 @@ def sync_items(
self,
context: AuthedServiceContext,
items: list[SyncableSyftObject],
permissions: list[ActionObjectPermission],
permissions: dict[type, list[ActionObjectPermission]],
storage_permissions: list[StoragePermission],
ignored_batches: dict[UID, int],
unignored_batches: set[UID],
) -> SyftSuccess:
permissions_dict = defaultdict(list)
for permission in permissions:
permissions_dict[permission.uid].append(permission)
for permission_list in permissions.values():
for permission in permission_list:
permissions_dict[permission.uid].append(permission)

item_ids = [item.id.id for item in items]

# If we just want to add permissions without having an object
# This should happen only for the high side when we sync results but
# we need to add permissions for the DS to properly show the status of the requests
for obj_type, permission_list in permissions.items():
for permission in permission_list:
if permission.uid in item_ids:
continue
if obj_type not in [Job, SyftLog, Request] and not issubclass(
obj_type, ActionObject
):
raise SyftException(
public_message="Permission for object type not supported!"
)
if issubclass(obj_type, ActionObject):
store = context.server.services.action.stash
else:
service = context.server.get_service(TYPE_TO_SERVICE[obj_type])
store = service.stash # type: ignore[assignment]
if permission.permission == ActionPermission.READ:
store.add_permission(permission)

storage_permissions_dict = defaultdict(list)
for storage_permission in storage_permissions:
Expand All @@ -213,7 +239,6 @@ def sync_items(
else:
item = self.transform_item(context, item) # type: ignore[unreachable]
self.set_object(context, item).unwrap()

self.add_permissions_for_item(context, item, new_permissions)
self.add_storage_permissions_for_item(
context, item, new_storage_permissions
Expand Down
Loading