diff --git a/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/DemoOktaRepository.java b/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/DemoOktaRepository.java index 7587d61207..af5b785aea 100644 --- a/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/DemoOktaRepository.java +++ b/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/DemoOktaRepository.java @@ -415,7 +415,7 @@ public void reset() { } @Override - public Integer getUsersInSingleFacility(Facility facility) { + public Integer getUsersCountInSingleFacility(Facility facility) { Integer accessCount = 0; for (OrganizationRoleClaims existingClaims : usernameOrgRolesMap.values()) { @@ -434,7 +434,7 @@ public Integer getUsersInSingleFacility(Facility facility) { } @Override - public Integer getUsersInOrganization(Organization org) { + public Integer getUsersCountInOrganization(Organization org) { return orgUsernamesMap.get(org.getExternalId()).size(); } diff --git a/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/OktaRepository.java b/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/OktaRepository.java index 2a516291ac..4390e6b796 100644 --- a/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/OktaRepository.java +++ b/backend/src/main/java/gov/cdc/usds/simplereport/idp/repository/OktaRepository.java @@ -92,9 +92,9 @@ Map getPagedUsersWithStatusForOrganization( Optional getOrganizationRoleClaimsForUser(String username); - Integer getUsersInSingleFacility(Facility facility); + Integer getUsersCountInSingleFacility(Facility facility); - Integer getUsersInOrganization(Organization org); + Integer getUsersCountInOrganization(Organization org); PartialOktaUser findUser(String username); diff --git a/backend/src/main/java/gov/cdc/usds/simplereport/service/ApiUserService.java b/backend/src/main/java/gov/cdc/usds/simplereport/service/ApiUserService.java index 71fdb81f32..87eb6d1f8f 100644 --- a/backend/src/main/java/gov/cdc/usds/simplereport/service/ApiUserService.java +++ b/backend/src/main/java/gov/cdc/usds/simplereport/service/ApiUserService.java @@ -632,7 +632,7 @@ public Page getPagedUsersAndStatusInCurrentOrg(int pageNumber .map(u -> new ApiUserWithStatus(u, emailsToStatus.get(u.getLoginEmail()))) .toList(); - Integer userCountInOrg = _oktaRepo.getUsersInOrganization(org); + Integer userCountInOrg = _oktaRepo.getUsersCountInOrganization(org); PageRequest pageRequest = PageRequest.of(pageNumber, pageSize); return new PageImpl<>(userWithStatusList, pageRequest, userCountInOrg); diff --git a/backend/src/main/java/gov/cdc/usds/simplereport/service/OrganizationService.java b/backend/src/main/java/gov/cdc/usds/simplereport/service/OrganizationService.java index 87408a26f7..c64fdb119b 100644 --- a/backend/src/main/java/gov/cdc/usds/simplereport/service/OrganizationService.java +++ b/backend/src/main/java/gov/cdc/usds/simplereport/service/OrganizationService.java @@ -496,7 +496,7 @@ public FacilityStats getFacilityStats(@Argument UUID facilityId) { usersWithSingleFacilityAccess = dbAuthorizationService.getUsersWithSingleFacilityAccessCount(facility); } else { - usersWithSingleFacilityAccess = this.oktaRepository.getUsersInSingleFacility(facility); + usersWithSingleFacilityAccess = this.oktaRepository.getUsersCountInSingleFacility(facility); } return FacilityStats.builder() .usersSingleAccessCount(usersWithSingleFacilityAccess) diff --git a/backend/src/test/java/gov/cdc/usds/simplereport/idp/repository/LiveOktaRepositoryTest.java b/backend/src/test/java/gov/cdc/usds/simplereport/idp/repository/LiveOktaRepositoryTest.java index d7547aac07..0e31311198 100644 --- a/backend/src/test/java/gov/cdc/usds/simplereport/idp/repository/LiveOktaRepositoryTest.java +++ b/backend/src/test/java/gov/cdc/usds/simplereport/idp/repository/LiveOktaRepositoryTest.java @@ -33,6 +33,7 @@ import com.okta.sdk.resource.model.UserStatus; import com.okta.sdk.resource.user.UserBuilder; import gov.cdc.usds.simplereport.api.CurrentTenantDataAccessContextHolder; +import gov.cdc.usds.simplereport.api.model.errors.BadRequestException; import gov.cdc.usds.simplereport.api.model.errors.ConflictingUserException; import gov.cdc.usds.simplereport.api.model.errors.IllegalGraphqlArgumentException; import gov.cdc.usds.simplereport.config.AuthorizationProperties; @@ -47,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -692,6 +694,109 @@ void getPagedUsersWithStatusForOrganization() { assertEquals(Map.of("email@example.com", UserStatus.ACTIVE), actual); } + @Test + void getUsersCountInOrganization() { + String orgExternalId = "ce8782ca-18ba-4384-95fc-fbb7be0ef577"; + + var groupProfilePrefix = "SR-UNITTEST-TENANT:" + orgExternalId + ":NO_ACCESS"; + + var mockOrg = mock(Organization.class); + when(mockOrg.getExternalId()).thenReturn(orgExternalId); + + var mockGroup = mock(Group.class); + var mockGroupList = List.of(mockGroup); + when(groupApi.listGroups( + eq(groupProfilePrefix), + isNull(), + isNull(), + eq(1), + eq("stats"), + isNull(), + isNull(), + isNull())) + .thenReturn(mockGroupList); + + var mockEmbeddedMap = mock(Map.class); + when(mockGroup.getEmbedded()).thenReturn(mockEmbeddedMap); + var mockStatsLinkedHashMap = mock(LinkedHashMap.class); + when(mockEmbeddedMap.get(eq("stats"))).thenReturn(mockStatsLinkedHashMap); + when(mockStatsLinkedHashMap.get(eq("usersCount"))).thenReturn(10); + + var actual = _repo.getUsersCountInOrganization(mockOrg); + assertEquals(10, actual); + } + + @Test + void getUsersCountInSingleFacility() { + String facilityInternalId = "05bc0080-ad53-4b7a-a4b5-d1f86059a304"; + String orgExternalId = "ce8782ca-18ba-4384-95fc-fbb7be0ef577"; + + var facilitySuffix = ":FACILITY_ACCESS:" + facilityInternalId; + var groupProfilePrefix = "SR-UNITTEST-TENANT:" + orgExternalId + facilitySuffix; + + var mockOrg = mock(Organization.class); + var mockFacility = mock(Facility.class); + when(mockFacility.getOrganization()).thenReturn(mockOrg); + when(mockOrg.getExternalId()).thenReturn(orgExternalId); + when(mockFacility.getInternalId()).thenReturn(UUID.fromString(facilityInternalId)); + + var mockGroup = mock(Group.class); + var mockGroupList = List.of(mockGroup); + when(groupApi.listGroups( + eq(groupProfilePrefix), + isNull(), + isNull(), + eq(1), + eq("stats"), + isNull(), + isNull(), + isNull())) + .thenReturn(mockGroupList); + + var mockEmbeddedMap = mock(Map.class); + when(mockGroup.getEmbedded()).thenReturn(mockEmbeddedMap); + var mockStatsLinkedHashMap = mock(LinkedHashMap.class); + when(mockEmbeddedMap.get(eq("stats"))).thenReturn(mockStatsLinkedHashMap); + when(mockStatsLinkedHashMap.get(eq("usersCount"))).thenReturn(10); + + var actual = _repo.getUsersCountInSingleFacility(mockFacility); + assertEquals(10, actual); + } + + @Test + void + getUsersCountInSingleFacility_throwsBadRequestException_whenUnableToRetrieveOktaGroupStats() { + String facilityInternalId = "05bc0080-ad53-4b7a-a4b5-d1f86059a304"; + String orgExternalId = "ce8782ca-18ba-4384-95fc-fbb7be0ef577"; + + var facilitySuffix = ":FACILITY_ACCESS:" + facilityInternalId; + var groupProfilePrefix = "SR-UNITTEST-TENANT:" + orgExternalId + facilitySuffix; + + var mockOrg = mock(Organization.class); + var mockFacility = mock(Facility.class); + when(mockFacility.getOrganization()).thenReturn(mockOrg); + when(mockOrg.getExternalId()).thenReturn(orgExternalId); + when(mockFacility.getInternalId()).thenReturn(UUID.fromString(facilityInternalId)); + + var mockGroup = mock(Group.class); + var mockGroupList = List.of(mockGroup); + when(groupApi.listGroups( + eq(groupProfilePrefix), + isNull(), + isNull(), + eq(1), + eq("stats"), + isNull(), + isNull(), + isNull())) + .thenReturn(mockGroupList); + + Throwable caught = + assertThrows( + BadRequestException.class, () -> _repo.getUsersCountInSingleFacility(mockFacility)); + assertEquals("Unable to retrieve okta group stats", caught.getMessage()); + } + @Test void updateUserPrivileges() { var username = "fraud@example.com"; diff --git a/backend/src/test/java/gov/cdc/usds/simplereport/service/OrganizationServiceTest.java b/backend/src/test/java/gov/cdc/usds/simplereport/service/OrganizationServiceTest.java index 7640f8f031..a3edc639b8 100644 --- a/backend/src/test/java/gov/cdc/usds/simplereport/service/OrganizationServiceTest.java +++ b/backend/src/test/java/gov/cdc/usds/simplereport/service/OrganizationServiceTest.java @@ -459,7 +459,7 @@ void getFacilityStats_withOktaMigrationDisabled_success() { UUID facilityId = UUID.randomUUID(); Facility mockFacility = mock(Facility.class); doReturn(Optional.of(mockFacility)).when(this.facilityRepository).findById(facilityId); - doReturn(2).when(oktaRepository).getUsersInSingleFacility(mockFacility); + doReturn(2).when(oktaRepository).getUsersCountInSingleFacility(mockFacility); doReturn(1).when(personRepository).countByFacilityAndIsDeleted(mockFacility, false); FacilityStats stats = _service.getFacilityStats(facilityId); @@ -479,7 +479,7 @@ void getFacilityStats_withOktaMigrationEnabled_success() { doReturn(2).when(personRepository).countByFacilityAndIsDeleted(mockFacility, false); FacilityStats stats = _service.getFacilityStats(facilityId); - verify(oktaRepository, times(0)).getUsersInSingleFacility(mockFacility); + verify(oktaRepository, times(0)).getUsersCountInSingleFacility(mockFacility); assertEquals(4, stats.getUsersSingleAccessCount()); assertEquals(2, stats.getPatientsSingleAccessCount()); }