Skip to content

Commit

Permalink
Refactor recursive functions to iterative to prevent stack overflow
Browse files Browse the repository at this point in the history
Convert `_match_prefix_helper`, `_insert_helper`, `_print_helper`, and `_total_size_helper` from recursive to iterative implementations using explicit stacks or loops. This prevents potential stack overflow errors with deep trees while maintaining the same functionality. Also update the test cases to use tensor values instead of strings for better testing of the radix cache's core functionality.
  • Loading branch information
luzengxiangcn committed Jan 26, 2025
1 parent 6c856b4 commit d02f0d0
Showing 1 changed file with 70 additions and 53 deletions.
123 changes: 70 additions & 53 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
if self.disable:
return [], self.root_node

value = []
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
return value, last_node[0]
return value, last_node

def insert(self, key: List, value=None):
if self.disable:
Expand Down Expand Up @@ -170,7 +168,7 @@ def pretty_print(self):
print(f"#tokens: {self.total_size()}")

def total_size(self):
return self._total_size_helper(self.root_node)
return self._total_size_helper()

def evict(self, num_tokens: int, evict_callback: Callable):
if self.disable:
Expand Down Expand Up @@ -226,24 +224,23 @@ def evictable_size(self):

##### Internal Helper Functions #####

def _match_prefix_helper(
self, node: TreeNode, key: List, value, last_node: TreeNode
):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
if len(key) == 0:
return

if key[0] in node.children.keys():
child = node.children[key[0]]
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
last_node[0] = new_node
else:
value.append(child.value)
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
value = []
while len(key) > 0:
if key[0] in node.children.keys():
child = node.children[key[0]]
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value)
node = child
key = key[prefix_len:]
return value, node

def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
Expand All @@ -264,22 +261,25 @@ def _insert_helper(self, node: TreeNode, key: List, value):
if len(key) == 0:
return 0

if key[0] in node.children.keys():
total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)

total_prefix_length += prefix_len
if prefix_len == len(child.key):
if prefix_len == len(key):
return prefix_len
break
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)

new_node = self._split_node(child.key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
node = child
else:
new_node = self._split_node(child.key, child, prefix_len)
key = key[prefix_len:]
value = value[prefix_len:]
node = new_node
break

if len(key):
new_node = TreeNode()
Expand All @@ -288,12 +288,21 @@ def _insert_helper(self, node: TreeNode, key: List, value):
new_node.value = value
node.children[key[0]] = new_node
self.evictable_size_ += len(value)
return 0
return total_prefix_length

def _print_helper(self, node: TreeNode, indent: int):
for _, child in node.children.items():
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
self._print_helper(child, indent=indent + 2)
"""Prints the radix tree in a human-readable format."""
stack = [(node, indent)]
while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
len(current_node.key),
current_node.key[:10],
f"r={current_node.lock_ref}",
)
for _, child in current_node.children.items():
stack.append((child, current_indent + 2))

def _delete_leaf(self, node):
for k, v in node.parent.children.items():
Expand All @@ -302,11 +311,14 @@ def _delete_leaf(self, node):
del node.parent.children[k]
self.evictable_size_ -= len(node.key)

def _total_size_helper(self, node: TreeNode):
x = len(node.value)
for child in node.children.values():
x += self._total_size_helper(child)
return x
def _total_size_helper(self):
total_size = 0
stack = [self.root_node]
while stack:
current_node = stack.pop()
total_size += len(current_node.value)
stack.extend(current_node.children.values())
return total_size

def _collect_leaves(self):
ret_list = []
Expand All @@ -324,20 +336,25 @@ def _collect_leaves(self):

if __name__ == "__main__":
tree = RadixCache(None, None, False)
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([1, 2, 4])
c = torch.Tensor([1, 3, 5])
tree.insert([], torch.Tensor([]))
tree.insert(a.tolist(), a)
tree.insert(b.tolist(), b)
val, node = tree.match_prefix(c.tolist())
tree.insert([1, 1, 3, 5], torch.Tensor([1, 1, 3, 5]))
tree.insert([1, 1, 4, 5], torch.Tensor([1, 1, 4, 5]))

val, node = tree.match_prefix([])

tree.insert("Hello")
tree.insert("Hello")
tree.insert("Hello_L.A.!")
# tree.insert("Hello_world! Happy")
# tree.insert("I love you!")
tree.pretty_print()

# print(tree.match_prefix("I love you! aha"))

# def evict_callback(x):
# print("evict", x)
# return len(x)
def evict_callback(x):
print("evict", x)
return len(x)

# tree.evict(5, evict_callback)
# tree.evict(10, evict_callback)
# tree.pretty_print()
tree.evict(1, evict_callback)
tree.evict(1, evict_callback)
tree.evict(1, evict_callback)
tree.pretty_print()

0 comments on commit d02f0d0

Please sign in to comment.