diff --git a/core/common/src/ByteStrings.kt b/core/common/src/ByteStrings.kt index 8dd269c2..a1468956 100644 --- a/core/common/src/ByteStrings.kt +++ b/core/common/src/ByteStrings.kt @@ -10,6 +10,7 @@ import kotlinx.io.bytestring.isEmpty import kotlinx.io.bytestring.unsafe.UnsafeByteStringApi import kotlinx.io.bytestring.unsafe.UnsafeByteStringOperations import kotlinx.io.unsafe.UnsafeBufferOperations +import kotlin.math.max import kotlin.math.min /** @@ -85,10 +86,14 @@ public fun Source.readByteString(byteCount: Int): ByteString { * expands the source's buffer as necessary until [byteString] is found. This reads an unbounded number of * bytes into the buffer. Returns `-1` if the stream is exhausted before the requested bytes are found. * + * For empty byte strings this function returns [startIndex] if it lays within underlying buffer's bounds, + * `0` if [startIndex] was negative and the size of the underlying buffer if [startIndex] exceeds its size. + * If the [startIndex] value was greater than the underlying buffer's size, the data will be fetched and buffered + * despite the [byteString] is empty. + * * @param byteString the sequence of bytes to find within the source. * @param startIndex the index into the source to start searching from. * - * @throws IllegalArgumentException if [startIndex] is negative. * @throws IllegalStateException if the source is closed. * @throws IOException when some I/O error occurs. * @@ -96,10 +101,11 @@ public fun Source.readByteString(byteCount: Int): ByteString { */ @OptIn(InternalIoApi::class, UnsafeByteStringApi::class) public fun Source.indexOf(byteString: ByteString, startIndex: Long = 0): Long { - require(startIndex >= 0) { "startIndex: $startIndex" } + val startIndex = max(0, startIndex) if (byteString.isEmpty()) { - return 0 + request(startIndex) + return min(startIndex, buffer.size) } var offset = startIndex @@ -117,12 +123,22 @@ public fun Source.indexOf(byteString: ByteString, startIndex: Long = 0): Long { return -1 } +/** + * Returns the index of the first match for [byteString] in the buffer at or after [startIndex]. + * + * For empty byte strings this function returns [startIndex] if it lays within buffer's bounds, + * `0` if [startIndex] was negative and [Buffer.size] if it was greater or equal to [Buffer.size]. + * + * @param byteString the sequence of bytes to find within the buffer. + * @param startIndex the index into the buffer to start searching from. + * + * @sample kotlinx.io.samples.ByteStringSamples.indexOfByteString + */ @OptIn(UnsafeByteStringApi::class) public fun Buffer.indexOf(byteString: ByteString, startIndex: Long = 0): Long { - require(startIndex <= size) { - "startIndex ($startIndex) should not exceed size ($size)" - } - if (byteString.isEmpty()) return 0 + val startIndex = max(0, min(startIndex, size)) + + if (byteString.isEmpty()) return startIndex if (startIndex > size - byteString.size) return -1L UnsafeByteStringOperations.withByteArrayUnsafe(byteString) { byteStringData -> diff --git a/core/common/test/AbstractSourceTest.kt b/core/common/test/AbstractSourceTest.kt index 71347afc..867d6cd3 100644 --- a/core/common/test/AbstractSourceTest.kt +++ b/core/common/test/AbstractSourceTest.kt @@ -1723,12 +1723,15 @@ abstract class AbstractBufferedSourceTest internal constructor( sink.writeString("flop flip flop") sink.emit() assertEquals(10, source.indexOf("flop".encodeToByteString(), 1)) + assertEquals(0, source.indexOf("flop".encodeToByteString(), -1)) source.readString() // Clear stream - // Make sure we backtrack and resume searching after partial match. + // Make sure we backtrack and resume searching after the partial match. sink.writeString("hi hi hi hi hey") sink.emit() assertEquals(6, source.indexOf("hi hi hey".encodeToByteString(), 1)) + + assertEquals(-1, source.indexOf("ho ho ho".encodeToByteString(), 9001)) } @Test @@ -1738,13 +1741,8 @@ abstract class AbstractBufferedSourceTest internal constructor( sink.writeString("blablabla") sink.emit() assertEquals(0, source.indexOf(ByteString())) - } - - @Test - fun indexOfByteStringInvalidArgumentsThrows() { - assertFailsWith { - source.indexOf("hi".encodeToByteString(), -1) - } + assertEquals(0, source.indexOf(ByteString(), -1)) + assertEquals(9, source.indexOf(ByteString(), 100000)) } /** diff --git a/core/common/test/CommonBufferTest.kt b/core/common/test/CommonBufferTest.kt index d009c5f6..d2b00cbd 100644 --- a/core/common/test/CommonBufferTest.kt +++ b/core/common/test/CommonBufferTest.kt @@ -24,6 +24,7 @@ import kotlinx.io.bytestring.ByteString import kotlinx.io.bytestring.encodeToByteString import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFails import kotlin.test.assertFailsWith import kotlin.test.assertTrue @@ -618,4 +619,24 @@ class CommonBufferTest { assertEquals(null, dst.head?.prev) assertEquals(null, dst.tail?.next) } + + @Test + fun indexOfByteString() { + val buffer = Buffer() + buffer.writeString("hello") + + assertEquals(-1, buffer.indexOf(ByteString(1, 2, 3), -1)) + assertEquals(-1, buffer.indexOf(ByteString(1, 2, 3), 10)) + + assertEquals(2, buffer.indexOf("ll".encodeToByteString())) + assertEquals(2, buffer.indexOf("ll".encodeToByteString(), 2)) + assertEquals(2, buffer.indexOf("ll".encodeToByteString(), -2)) + assertEquals(-1, buffer.indexOf("ll".encodeToByteString(), 3)) + assertEquals(-1, buffer.indexOf("hello world".encodeToByteString())) + + assertEquals(0, buffer.indexOf(ByteString())) + assertEquals(buffer.size, buffer.indexOf(ByteString(), 1000)) + assertEquals(1, buffer.indexOf(ByteString(), 1)) + assertEquals(0, buffer.indexOf(ByteString(), -1)) + } }