diff --git a/core/node/events/stream_viewstate_mls.go b/core/node/events/stream_viewstate_mls.go index 2495d6639..185f5d174 100644 --- a/core/node/events/stream_viewstate_mls.go +++ b/core/node/events/stream_viewstate_mls.go @@ -77,6 +77,8 @@ func (r *streamViewImpl) GetMlsGroupState() (*mls_tools.MlsGroupState, error) { mlsGroupState.ExternalGroupSnapshot = content.Mls.GetInitializeGroup().ExternalGroupSnapshot case *protocol.MemberPayload_Mls_ExternalJoin_: mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetExternalJoin().Commit) + case *protocol.MemberPayload_Mls_WelcomeMessage_: + mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetWelcomeMessage().Commit) default: break } diff --git a/packages/sdk/src/client.ts b/packages/sdk/src/client.ts index 2f7f477c2..9158ff171 100644 --- a/packages/sdk/src/client.ts +++ b/packages/sdk/src/client.ts @@ -152,6 +152,7 @@ import { SignerContext } from './signerContext' import { decryptAESGCM, deriveKeyAndIV, encryptAESGCM, uint8ArrayToBase64 } from './crypto_utils' import { makeTags, makeTipTags } from './tags' import { TipEventObject } from '@river-build/generated/dev/typings/ITipping' +import { extractMlsExternalGroup } from './mls/utils/mlsutils' export type ClientEvents = StreamEvents & DecryptionEvents @@ -2525,6 +2526,19 @@ export class Client method: 'mls', }) } + + public async getMlsExternalGroupInfo(streamId: string): Promise<{ + externalGroupSnapshot: Uint8Array + groupInfoMessage: Uint8Array + commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] + }> { + let streamView = this.stream(streamId)?.view + if (!streamView || !streamView.isInitialized) { + streamView = await this.getStream(streamId) + } + check(isDefined(streamView), `stream not found: ${streamId}`) + return extractMlsExternalGroup(streamView) + } } function ensureNoHexPrefix(value: string): string { diff --git a/packages/sdk/src/mls/utils/mlsutils.ts b/packages/sdk/src/mls/utils/mlsutils.ts new file mode 100644 index 000000000..f5b5ed7b2 --- /dev/null +++ b/packages/sdk/src/mls/utils/mlsutils.ts @@ -0,0 +1,53 @@ +import { check } from '@river-build/dlog' +import { IStreamStateView } from '../../streamStateView' + +export function extractMlsExternalGroup(streamView: IStreamStateView): { + externalGroupSnapshot: Uint8Array + groupInfoMessage: Uint8Array + commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] +} { + const indexOfLastSnapshot = streamView.timeline.findLastIndex((event) => { + const payload = event.remoteEvent?.event.payload + if (payload?.case !== 'miniblockHeader') { + return false + } + return payload.value.snapshot !== undefined + }) + + const payload = streamView.timeline[indexOfLastSnapshot].remoteEvent?.event.payload + check(payload?.case === 'miniblockHeader', 'no snapshot found') + const snapshot = payload.value.snapshot + check(snapshot !== undefined, 'no snapshot found') + const externalGroupSnapshot = snapshot.members?.mls?.externalGroupSnapshot + check(externalGroupSnapshot !== undefined, 'no externalGroupSnapshot found') + const groupInfoMessage = snapshot.members?.mls?.groupInfoMessage + check(groupInfoMessage !== undefined, 'no groupInfoMessage found') + const commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] = [] + for (let i = indexOfLastSnapshot; i < streamView.timeline.length; i++) { + const event = streamView.timeline[i] + const payload = event.remoteEvent?.event.payload + if (payload?.case !== 'memberPayload') { + continue + } + if (payload?.value?.content.case !== 'mls') { + continue + } + + const mlsPayload = payload.value.content.value + switch (mlsPayload.content.case) { + case 'externalJoin': + case 'welcomeMessage': + commits.push({ + commit: mlsPayload.content.value.commit, + groupInfoMessage: mlsPayload.content.value.groupInfoMessage, + }) + break + + case undefined: + break + default: + break + } + } + return { externalGroupSnapshot, groupInfoMessage, commits: commits } +} diff --git a/packages/sdk/src/tests/multi_ne/mls.test.ts b/packages/sdk/src/tests/multi_ne/mls.test.ts index bc499bff5..54bee5716 100644 --- a/packages/sdk/src/tests/multi_ne/mls.test.ts +++ b/packages/sdk/src/tests/multi_ne/mls.test.ts @@ -187,7 +187,7 @@ describe('mlsTests', () => { client: MlsClient, groupInfoMessage: Uint8Array, externalGroupSnapshot: Uint8Array, - ): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array }> { + ): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array; group: MlsGroup }> { const externalClient = new ExternalClient() const externalSnapshot = ExternalSnapshot.fromBytes(externalGroupSnapshot) const externalGroup = await externalClient.loadGroup(externalSnapshot) @@ -201,6 +201,7 @@ describe('mlsTests', () => { return { commit: commitOutput.commit.toBytes(), groupInfoMessage: updatedGroupInfoMessage.toBytes(), + group: commitOutput.group, } } @@ -566,6 +567,7 @@ describe('mlsTests', () => { await expect( bobMlsGroup.processIncomingMessage(commitOutput.commitMessage), ).resolves.not.toThrow() + latestGroupInfoMessage = groupInfoMessage!.toBytes() commits.push(commit) }) @@ -587,6 +589,53 @@ describe('mlsTests', () => { }) }) + test('correct external group info is returned', async () => { + const externalGroupInfo = await bobClient.getMlsExternalGroupInfo(streamId) + const externalClient = new ExternalClient() + const externalGroupSnapshot = ExternalSnapshot.fromBytes( + externalGroupInfo.externalGroupSnapshot, + ) + expect(externalGroupInfo.commits.length).toBe(1) + + let latestValidGroupInfoMessage = externalGroupInfo.groupInfoMessage + const externalGroup = await externalClient.loadGroup(externalGroupSnapshot) + for (const commit of externalGroupInfo.commits) { + try { + const mlsMessage = MlsMessage.fromBytes(commit.commit) + await externalGroup.processIncomingMessage(mlsMessage) + latestValidGroupInfoMessage = commit.groupInfoMessage + } catch { + // catch, in case this is an invalid commit + } + } + + expect(bin_equal(latestValidGroupInfoMessage, latestGroupInfoMessage)).toBe(true) + + const aliceThrowawayClient = await MlsClient.create(new Uint8Array(randomBytes(32))) + const { + commit: aliceCommit, + groupInfoMessage: aliceGroupInfoMessage, + group: aliceGroup, + } = await commitExternal( + aliceThrowawayClient, + latestValidGroupInfoMessage, + externalGroup.snapshot().toBytes(), + ) + const aliceMlsPayload = makeMlsPayloadExternalJoin( + aliceThrowawayClient.signaturePublicKey(), + aliceCommit, + aliceGroupInfoMessage, + ) + + await expect( + bobMlsGroup.processIncomingMessage(MlsMessage.fromBytes(aliceCommit)), + ).resolves.not.toThrow() + + expect(bobMlsGroup.currentEpoch).toBe(aliceGroup.currentEpoch) + await expect(aliceClient._debugSendMls(streamId, aliceMlsPayload)).resolves.not.toThrow() + commits.push(aliceCommit) + }) + test('devices added from key packages are snapshotted', async () => { // force snapshot await expect( @@ -596,6 +645,6 @@ describe('mlsTests', () => { // verify that the key package is picked up in the snapshot const streamAfterSnapshot = await bobClient.getStream(streamId) const mls = streamAfterSnapshot.membershipContent.mls - expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(2) + expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(3) }) })