Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add function to extract MLS external group #2029

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/node/events/stream_viewstate_mls.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func (r *streamViewImpl) GetMlsGroupState() (*mls_tools.MlsGroupState, error) {
mlsGroupState.ExternalGroupSnapshot = content.Mls.GetInitializeGroup().ExternalGroupSnapshot
case *protocol.MemberPayload_Mls_ExternalJoin_:
mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetExternalJoin().Commit)
case *protocol.MemberPayload_Mls_WelcomeMessage_:
mlsGroupState.Commits = append(mlsGroupState.Commits, content.Mls.GetWelcomeMessage().Commit)
default:
break
}
Expand Down
14 changes: 14 additions & 0 deletions packages/sdk/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ import { SignerContext } from './signerContext'
import { decryptAESGCM, deriveKeyAndIV, encryptAESGCM, uint8ArrayToBase64 } from './crypto_utils'
import { makeTags, makeTipTags } from './tags'
import { TipEventObject } from '@river-build/generated/dev/typings/ITipping'
import { extractMlsExternalGroup } from './mls/utils/mlsutils'

export type ClientEvents = StreamEvents & DecryptionEvents

Expand Down Expand Up @@ -2525,6 +2526,19 @@ export class Client
method: 'mls',
})
}

public async getMlsExternalGroupInfo(streamId: string): Promise<{
externalGroupSnapshot: Uint8Array
groupInfoMessage: Uint8Array
commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[]
}> {
let streamView = this.stream(streamId)?.view
if (!streamView || !streamView.isInitialized) {
streamView = await this.getStream(streamId)
}
check(isDefined(streamView), `stream not found: ${streamId}`)
return extractMlsExternalGroup(streamView)
}
}

function ensureNoHexPrefix(value: string): string {
Expand Down
53 changes: 53 additions & 0 deletions packages/sdk/src/mls/utils/mlsutils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { check } from '@river-build/dlog'
import { IStreamStateView } from '../../streamStateView'

export function extractMlsExternalGroup(streamView: IStreamStateView): {
externalGroupSnapshot: Uint8Array
groupInfoMessage: Uint8Array
commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[]
} {
const indexOfLastSnapshot = streamView.timeline.findLastIndex((event) => {
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'miniblockHeader') {
return false
}
return payload.value.snapshot !== undefined
})

const payload = streamView.timeline[indexOfLastSnapshot].remoteEvent?.event.payload
check(payload?.case === 'miniblockHeader', 'no snapshot found')
const snapshot = payload.value.snapshot
check(snapshot !== undefined, 'no snapshot found')
const externalGroupSnapshot = snapshot.members?.mls?.externalGroupSnapshot
check(externalGroupSnapshot !== undefined, 'no externalGroupSnapshot found')
const groupInfoMessage = snapshot.members?.mls?.groupInfoMessage
check(groupInfoMessage !== undefined, 'no groupInfoMessage found')
const commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] = []
for (let i = indexOfLastSnapshot; i < streamView.timeline.length; i++) {
const event = streamView.timeline[i]
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'memberPayload') {
continue
}
if (payload?.value?.content.case !== 'mls') {
continue
}

const mlsPayload = payload.value.content.value
switch (mlsPayload.content.case) {
case 'externalJoin':
case 'welcomeMessage':
commits.push({
commit: mlsPayload.content.value.commit,
groupInfoMessage: mlsPayload.content.value.groupInfoMessage,
})
break

case undefined:
break
default:
break
}
}
return { externalGroupSnapshot, groupInfoMessage, commits: commits }
}
53 changes: 51 additions & 2 deletions packages/sdk/src/tests/multi_ne/mls.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ describe('mlsTests', () => {
client: MlsClient,
groupInfoMessage: Uint8Array,
externalGroupSnapshot: Uint8Array,
): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array }> {
): Promise<{ commit: Uint8Array; groupInfoMessage: Uint8Array; group: MlsGroup }> {
const externalClient = new ExternalClient()
const externalSnapshot = ExternalSnapshot.fromBytes(externalGroupSnapshot)
const externalGroup = await externalClient.loadGroup(externalSnapshot)
Expand All @@ -201,6 +201,7 @@ describe('mlsTests', () => {
return {
commit: commitOutput.commit.toBytes(),
groupInfoMessage: updatedGroupInfoMessage.toBytes(),
group: commitOutput.group,
}
}

Expand Down Expand Up @@ -566,6 +567,7 @@ describe('mlsTests', () => {
await expect(
bobMlsGroup.processIncomingMessage(commitOutput.commitMessage),
).resolves.not.toThrow()
latestGroupInfoMessage = groupInfoMessage!.toBytes()
commits.push(commit)
})

Expand All @@ -587,6 +589,53 @@ describe('mlsTests', () => {
})
})

test('correct external group info is returned', async () => {
const externalGroupInfo = await bobClient.getMlsExternalGroupInfo(streamId)
const externalClient = new ExternalClient()
const externalGroupSnapshot = ExternalSnapshot.fromBytes(
externalGroupInfo.externalGroupSnapshot,
)
expect(externalGroupInfo.commits.length).toBe(1)

let latestValidGroupInfoMessage = externalGroupInfo.groupInfoMessage
const externalGroup = await externalClient.loadGroup(externalGroupSnapshot)
for (const commit of externalGroupInfo.commits) {
try {
const mlsMessage = MlsMessage.fromBytes(commit.commit)
await externalGroup.processIncomingMessage(mlsMessage)
latestValidGroupInfoMessage = commit.groupInfoMessage
} catch {
// catch, in case this is an invalid commit
}
}

expect(bin_equal(latestValidGroupInfoMessage, latestGroupInfoMessage)).toBe(true)

const aliceThrowawayClient = await MlsClient.create(new Uint8Array(randomBytes(32)))
const {
commit: aliceCommit,
groupInfoMessage: aliceGroupInfoMessage,
group: aliceGroup,
} = await commitExternal(
aliceThrowawayClient,
latestValidGroupInfoMessage,
externalGroup.snapshot().toBytes(),
)
const aliceMlsPayload = makeMlsPayloadExternalJoin(
aliceThrowawayClient.signaturePublicKey(),
aliceCommit,
aliceGroupInfoMessage,
)

await expect(
bobMlsGroup.processIncomingMessage(MlsMessage.fromBytes(aliceCommit)),
).resolves.not.toThrow()

expect(bobMlsGroup.currentEpoch).toBe(aliceGroup.currentEpoch)
await expect(aliceClient._debugSendMls(streamId, aliceMlsPayload)).resolves.not.toThrow()
commits.push(aliceCommit)
})

test('devices added from key packages are snapshotted', async () => {
// force snapshot
await expect(
Expand All @@ -596,6 +645,6 @@ describe('mlsTests', () => {
// verify that the key package is picked up in the snapshot
const streamAfterSnapshot = await bobClient.getStream(streamId)
const mls = streamAfterSnapshot.membershipContent.mls
expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(2)
expect(mls.members[aliceClient.userId].signaturePublicKeys.length).toBe(3)
})
})
Loading