# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Sequence

from torchdata.nodes.base_node import BaseNode, T


class Batcher(BaseNode[List[T]]):
    """Batcher collects items from the source node into batches of size ``batch_size``.

    When the source node is exhausted mid-batch, the behavior depends on ``drop_last``:
    if ``drop_last`` is True, a final batch smaller than ``batch_size`` is dropped and
    StopIteration is raised; if ``drop_last`` is False, the smaller final batch is
    returned, and StopIteration is raised on the following call.

    Args:
        source (BaseNode[T]): The source node to batch the data from.
        batch_size (int): The size of each batch.
        drop_last (bool): Whether to drop the last batch if it is smaller than
            ``batch_size``. Default is True.
    """
    SOURCE_KEY = "source"

    def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True):
        super().__init__()
        self.source = source
        self.batch_size = batch_size
        self.drop_last = drop_last

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is not None:
            self.source.reset(initial_state[self.SOURCE_KEY])
        else:
            self.source.reset()

    def next(self) -> List[T]:
        batch: List[T] = []
        while len(batch) < self.batch_size:
            try:
                item = next(self.source)
            except StopIteration:
                break
            batch.append(item)

        if len(batch) == self.batch_size:
            return batch
        elif len(batch) and not self.drop_last:
            return batch
        else:
            raise StopIteration()

    def get_state(self) -> Dict[str, Any]:
        return {self.SOURCE_KEY: self.source.state_dict()}
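
# A minimal usage sketch for Batcher. IterableWrapper is an assumption here: it
# is the torchdata.nodes adapter that lifts a plain iterable into a BaseNode;
# substitute whichever source node you actually have.
#
#     from torchdata.nodes import IterableWrapper
#
#     node = Batcher(IterableWrapper(range(10)), batch_size=4, drop_last=False)
#     node.reset()
#     print(next(node))  # [0, 1, 2, 3]
#     print(next(node))  # [4, 5, 6, 7]
#     print(next(node))  # [8, 9] -- partial batch kept because drop_last=False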


class Unbatcher(BaseNode[T]):
    """Unbatcher flattens batches pulled from the source node and yields their
    elements one at a time, in order, each time next() is called on it.

    Args:
        source (BaseNode[Sequence[T]]): The source node to pull batches from.
    """

    SOURCE_KEY = "source"
    BATCH_IDX_KEY = "batch_idx"

    def __init__(self, source: BaseNode[Sequence[T]]):
        super().__init__()
        self.source = source

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is not None:
            self.source.reset(initial_state[self.SOURCE_KEY])
            self._cached_state_dict = initial_state[self.SOURCE_KEY]
            try:
                self._batch = next(self.source)
                self._batch_idx = initial_state[self.BATCH_IDX_KEY]
            except StopIteration:
                # next(self.source) will be called upon subsequent self.next() call
                # and raise StopIteration in the correct place.
                self._batch = []
                self._batch_idx = 0
        else:
            self.source.reset()
            self._batch = []
            self._cached_state_dict = None
            self._batch_idx = 0

    def next(self) -> T:
        while self._batch_idx >= len(self._batch):
            # Cache the source's state before advancing it, so that a restored
            # Unbatcher can re-fetch the current batch and resume mid-batch.
            self._cached_state_dict = self.source.state_dict()
            self._batch = next(self.source)
            self._batch_idx = 0
        self._batch_idx += 1
        return self._batch[self._batch_idx - 1]

    def get_state(self) -> Dict[str, Any]:
        if self._cached_state_dict is None:
            self._cached_state_dict = self.source.state_dict()
        return {
            self.SOURCE_KEY: self._cached_state_dict,
            self.BATCH_IDX_KEY: self._batch_idx,
        }
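
# A minimal usage sketch for Unbatcher, including a state round trip (again
# assuming the torchdata.nodes IterableWrapper adapter; any BaseNode that
# yields sequences and supports state_dict()/reset() should behave the same):
#
#     from torchdata.nodes import IterableWrapper
#
#     node = Unbatcher(IterableWrapper([[0, 1], [2, 3, 4]]))
#     node.reset()
#     first_three = [next(node) for _ in range(3)]  # [0, 1, 2]
#     state = node.state_dict()  # captures the source's state plus batch_idx
#     node.reset(state)
#     print([next(node) for _ in range(2)])  # [3, 4] -- resumes mid-batch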