diff --git a/sign/mask_test.go b/sign/mask_test.go index 06a00bb4..5d36cacb 100644 --- a/sign/mask_test.go +++ b/sign/mask_test.go @@ -45,20 +45,56 @@ func TestMask_CreateMask(t *testing.T) { require.Error(t, err) } -func TestMask_SetBit(t *testing.T) { +func TestMask_SetGetBit(t *testing.T) { mask, err := NewMask(suite, publics, publics[2]) require.NoError(t, err) + // Make sure the mask is initially as we'd expect. + + bit, err := mask.GetBit(1) + require.NoError(t, err) + require.False(t, bit) + + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Unset bit 2 + err = mask.SetBit(2, false) require.NoError(t, err) require.Equal(t, uint8(0x2), mask.Mask()[0]) require.Equal(t, 1, len(mask.Participants())) + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.False(t, bit) + + // Unset bit 10 (using byte 2 now) + + err = mask.SetBit(10, false) + require.NoError(t, err) + require.Equal(t, uint8(0x2), mask.Mask()[0]) + require.Equal(t, uint8(0x4), mask.Mask()[1]) + require.Equal(t, 2, len(mask.Participants())) + + bit, err = mask.GetBit(10) + require.NoError(t, err) + require.True(t, bit) + + // And make sure the range limit works. + err = mask.SetBit(-1, true) require.Error(t, err) err = mask.SetBit(len(publics), true)