diff --git a/docs/changelog.md b/docs/changelog.md index 2457d5ffdd..2e82750aae 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -100,6 +100,8 @@ what we publish. ### Standard library changes +- Add a new `validate` parameter to the `b64decode()` function. + - The free floating functions for constructing different types have been deprecated for actual constructors: diff --git a/stdlib/src/base64/base64.mojo b/stdlib/src/base64/base64.mojo index c126834e25..f0929a0f3e 100644 --- a/stdlib/src/base64/base64.mojo +++ b/stdlib/src/base64/base64.mojo @@ -34,7 +34,7 @@ from ._b64encode import b64encode_with_buffers as _b64encode_with_buffers @always_inline -fn _ascii_to_value(char: StringSlice) -> Int: +fn _ascii_to_value[validate: Bool = False](char: StringSlice) raises -> Int: """Converts an ASCII character to its integer value for base64 decoding. Args: @@ -58,6 +58,12 @@ fn _ascii_to_value(char: StringSlice) -> Int: elif char == "/": return 63 else: + + @parameter + if validate: + raise Error( + 'ValueError: Unexpected character "{}" encountered'.format(char) + ) return -1 @@ -112,9 +118,12 @@ fn b64encode(input_bytes: Span[Byte, _]) -> String: @always_inline -fn b64decode(str: StringSlice) -> String: +fn b64decode[validate: Bool = False](str: StringSlice) raises -> String: """Performs base64 decoding on the input string. + Parameters: + validate: If true, the function will validate the input string. + Args: str: A base64 encoded string. @@ -122,21 +131,22 @@ fn b64decode(str: StringSlice) -> String: The decoded string. """ var n = str.byte_length() - debug_assert(n % 4 == 0, "Input length must be divisible by 4") + + @parameter + if validate: + if n % 4 != 0: + raise Error( + "ValueError: Input length {} must be divisible by 4".format(n) + ) var p = String._buffer_type(capacity=n + 1) # This algorithm is based on https://arxiv.org/abs/1704.00605 for i in range(0, n, 4): - var a = _ascii_to_value(str[i]) - var b = _ascii_to_value(str[i + 1]) - var c = _ascii_to_value(str[i + 2]) - var d = _ascii_to_value(str[i + 3]) - - debug_assert( - a >= 0 and b >= 0 and c >= 0 and d >= 0, - "Unexpected character encountered", - ) + var a = _ascii_to_value[validate](str[i]) + var b = _ascii_to_value[validate](str[i + 1]) + var c = _ascii_to_value[validate](str[i + 2]) + var d = _ascii_to_value[validate](str[i + 3]) p.append((a << 2) | (b >> 4)) if str[i + 2] == "=": diff --git a/stdlib/test/base64/test_base64.mojo b/stdlib/test/base64/test_base64.mojo index 6512844905..f8b7993c00 100644 --- a/stdlib/test/base64/test_base64.mojo +++ b/stdlib/test/base64/test_base64.mojo @@ -14,7 +14,7 @@ from base64 import b16decode, b16encode, b64decode, b64encode -from testing import assert_equal +from testing import assert_equal, assert_raises def test_b64encode(): @@ -60,6 +60,16 @@ def test_b64decode(): assert_equal(b64decode("QUJDREVGYWJjZGVm"), "ABCDEFabcdef") + with assert_raises( + contains="ValueError: Input length 21 must be divisible by 4" + ): + _ = b64decode[validate=True]("invalid base64 string") + + with assert_raises( + contains='ValueError: Unexpected character " " encountered' + ): + _ = b64decode[validate=True]("invalid base64 string!!!") + def test_b16encode(): assert_equal(b16encode("a"), "61")