Skip to content

Commit

Permalink
add function to extract MLS external group
Browse files Browse the repository at this point in the history
  • Loading branch information
erikolsson committed Jan 13, 2025
1 parent fe84066 commit 55e56a1
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 2 deletions.
2 changes: 2 additions & 0 deletions core/node/events/stream_viewstate_mls.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func (r *streamViewImpl) GetMlsGroupState() (*mls_tools.MlsGroupState, error) {
mlsGroupState.ExternalGroupSnapshot = content.Mls.GetInitializeGroup().ExternalGroupSnapshot
case *protocol.MemberPayload_Mls_ExternalJoin_:
mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetExternalJoin().Commit)
case *protocol.MemberPayload_Mls_WelcomeMessage_:
mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetWelcomeMessage().Commit)
default:
break
}
Expand Down
14 changes: 14 additions & 0 deletions packages/sdk/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ import { SignerContext } from './signerContext'
import { decryptAESGCM, deriveKeyAndIV, encryptAESGCM, uint8ArrayToBase64 } from './crypto_utils'
import { makeTags, makeTipTags } from './tags'
import { TipEventObject } from '@river-build/generated/dev/typings/ITipping'
import { extractMlsExternalGroup } from './mls/utils/mlsutils'

export type ClientEvents = StreamEvents & DecryptionEvents

Expand Down Expand Up @@ -2525,6 +2526,19 @@ export class Client
method: 'mls',
})
}

public async getMlsExternalGroupInfo(streamId: string): Promise<{
externalGroupSnapshot: Uint8Array
groupInfoMessage: Uint8Array
commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[]
}> {
let streamView = this.stream(streamId)?.view
if (!streamView || !streamView.isInitialized) {
streamView = await this.getStream(streamId)
}
check(isDefined(streamView), `stream not found: ${streamId}`)
return extractMlsExternalGroup(streamView)
}
}

function ensureNoHexPrefix(value: string): string {
Expand Down
57 changes: 57 additions & 0 deletions packages/sdk/src/mls/utils/mlsutils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { check } from '@river-build/dlog'
import { IStreamStateView } from '../../streamStateView'

export function extractMlsExternalGroup(streamView: IStreamStateView): {
externalGroupSnapshot: Uint8Array
groupInfoMessage: Uint8Array
commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[]
} {
const indexOfLastSnapshot = streamView.timeline.findLastIndex((event) => {
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'miniblockHeader') {
return false
}
return payload.value.snapshot !== undefined
})

const payload = streamView.timeline[indexOfLastSnapshot].remoteEvent?.event.payload
check(payload?.case === 'miniblockHeader', 'no snapshot found')
const snapshot = payload.value.snapshot
check(snapshot !== undefined, 'no snapshot found')
const externalGroupSnapshot = snapshot.members?.mls?.externalGroupSnapshot
check(externalGroupSnapshot !== undefined, 'no externalGroupSnapshot found')
const groupInfoMessage = snapshot.members?.mls?.groupInfoMessage
check(groupInfoMessage !== undefined, 'no groupInfoMessage found')
const commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] = []
for (let i = indexOfLastSnapshot; i < streamView.timeline.length; i++) {
const event = streamView.timeline[i]
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'memberPayload') {
continue
}
if (payload?.value?.content.case !== 'mls') {
continue
}

const mlsPayload = payload.value.content.value
switch (mlsPayload.content.case) {
case 'externalJoin':
commits.push({
commit: mlsPayload.content.value.commit,
groupInfoMessage: mlsPayload.content.value.groupInfoMessage,
})
break
case 'welcomeMessage':
commits.push({
commit: mlsPayload.content.value.commit,
groupInfoMessage: mlsPayload.content.value.groupInfoMessage,
})
break
case undefined:
break
default:
break
}
}
return { externalGroupSnapshot, groupInfoMessage, commits: commits }
}
53 changes: 51 additions & 2 deletions packages/sdk/src/tests/multi_ne/mls.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ describe('mlsTests', () => {
client: MlsClient,
groupInfoMessage: Uint8Array,
externalGroupSnapshot: Uint8Array,
): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array }> {
): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array; group: MlsGroup }> {
const externalClient = new ExternalClient()
const externalSnapshot = ExternalSnapshot.fromBytes(externalGroupSnapshot)
const externalGroup = await externalClient.loadGroup(externalSnapshot)
Expand All @@ -201,6 +201,7 @@ describe('mlsTests', () => {
return {
commit: commitOutput.commit.toBytes(),
groupInfoMessage: updatedGroupInfoMessage.toBytes(),
group: commitOutput.group,
}
}

Expand Down Expand Up @@ -566,6 +567,7 @@ describe('mlsTests', () => {
await expect(
bobMlsGroup.processIncomingMessage(commitOutput.commitMessage),
).resolves.not.toThrow()
latestGroupInfoMessage = groupInfoMessage!.toBytes()
commits.push(commit)
})

Expand All @@ -587,6 +589,53 @@ describe('mlsTests', () => {
})
})

test('correct external group info is returned', async () => {
const externalGroupInfo = await bobClient.getMlsExternalGroupInfo(streamId)
const externalClient = new ExternalClient()
const externalGroupSnapshot = ExternalSnapshot.fromBytes(
externalGroupInfo.externalGroupSnapshot,
)
expect(externalGroupInfo.commits.length).toBe(1)

let latestValidGroupInfoMessage = externalGroupInfo.groupInfoMessage
const externalGroup = await externalClient.loadGroup(externalGroupSnapshot)
for (const commit of externalGroupInfo.commits) {
try {
const mlsMessage = MlsMessage.fromBytes(commit.commit)
await externalGroup.processIncomingMessage(mlsMessage)
latestValidGroupInfoMessage = commit.groupInfoMessage
} catch {
// catch, in case this is an invalid commit
}
}

expect(bin_equal(latestValidGroupInfoMessage, latestGroupInfoMessage)).toBe(true)

const aliceThrowawayClient = await MlsClient.create(new Uint8Array(randomBytes(32)))
const {
commit: aliceCommit,
groupInfoMessage: aliceGroupInfoMessage,
group: aliceGroup,
} = await commitExternal(
aliceThrowawayClient,
latestValidGroupInfoMessage,
externalGroup.snapshot().toBytes(),
)
const aliceMlsPayload = makeMlsPayloadExternalJoin(
aliceThrowawayClient.signaturePublicKey(),
aliceCommit,
aliceGroupInfoMessage,
)

await expect(
bobMlsGroup.processIncomingMessage(MlsMessage.fromBytes(aliceCommit)),
).resolves.not.toThrow()

expect(bobMlsGroup.currentEpoch).toBe(aliceGroup.currentEpoch)
await expect(aliceClient._debugSendMls(streamId, aliceMlsPayload)).resolves.not.toThrow()
commits.push(aliceCommit)
})

test('devices added from key packages are snapshotted', async () => {
// force snapshot
await expect(
Expand All @@ -596,6 +645,6 @@ describe('mlsTests', () => {
// verify that the key package is picked up in the snapshot
const streamAfterSnapshot = await bobClient.getStream(streamId)
const mls = streamAfterSnapshot.membershipContent.mls
expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(2)
expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(3)
})
})

0 comments on commit 55e56a1

Please sign in to comment.