Skip to content

Commit

Permalink
[RLlib; Offline RL] 2. Multiple optimizations for streaming data. (#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Dec 13, 2024
1 parent 31b4302 commit 797dc77
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 149 deletions.
4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/tests/test_marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def test_marwil_loss_function(self):
batch = algo.offline_data.data.take_batch(2000)

# Create the prelearner and compute advantages and values.
offline_prelearner = OfflinePreLearner(config, algo.learner_group._learner)
offline_prelearner = OfflinePreLearner(
config=config, learner=algo.learner_group._learner
)
# Note, for `ray.data`'s pipeline, everything has to be a dictionary;
# therefore, the batch is embedded into another dictionary.
batch = offline_prelearner(batch)["batch"][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
self,
config: "AlgorithmConfig",
learner: Union[Learner, List[ActorHandle]],
locality_hints: Optional[List[str]] = None,
spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
module_spec: Optional[MultiRLModuleSpec] = None,
module_state: Optional[Dict[ModuleID, Any]] = None,
Expand Down
32 changes: 12 additions & 20 deletions rllib/offline/offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def __init__(self, config: AlgorithmConfig):
# Use `read_parquet` as default data read method.
self.data_read_method = self.config.input_read_method
# Override default arguments for the data read method.
self.data_read_method_kwargs = (
self.default_read_method_kwargs | self.config.input_read_method_kwargs
)
self.data_read_method_kwargs = self.config.input_read_method_kwargs
# In case `EpisodeType` or `BatchType` batches are read the size
# could differ from the final `train_batch_size_per_learner`.
self.data_read_batch_size = self.config.input_read_batch_size
Expand Down Expand Up @@ -75,11 +73,12 @@ def __init__(self, config: AlgorithmConfig):
"'gcs' for GCS, 's3' for S3, or 'abs'"
)
# Add the filesystem object to the read method kwargs.
self.data_read_method_kwargs.update(
{
"filesystem": self.filesystem_object,
}
)
if self.filesystem_object:
self.data_read_method_kwargs.update(
{
"filesystem": self.filesystem_object,
}
)

try:
# Load the dataset.
Expand All @@ -90,9 +89,11 @@ def __init__(self, config: AlgorithmConfig):
if self.materialize_data:
self.data = self.data.materialize()
stop_time = time.perf_counter()
logger.debug(f"Time for loading dataset: {stop_time - start_time}s.")
logger.debug(
"===> [OfflineData] - Time for loading dataset: "
f"{stop_time - start_time}s."
)
logger.info("Reading data from {}".format(self.path))
logger.info(self.data.schema())
except Exception as e:
logger.error(e)
# Avoids reinstantiating the batch iterator each time we sample.
Expand Down Expand Up @@ -147,8 +148,7 @@ def sample(
# Add constructor `kwargs` when using remote learners.
fn_constructor_kwargs.update(
{
"learner": self.learner_handles,
"locality_hints": self.locality_hints,
"learner": None,
"module_spec": self.module_spec,
"module_state": module_state,
}
Expand Down Expand Up @@ -222,12 +222,6 @@ def sample(
num_shards=num_shards,
)

@property
def default_read_method_kwargs(self):
return {
"override_num_blocks": max(self.config.num_learners * 2, 2),
}

@property
def default_map_batches_kwargs(self):
return {
Expand All @@ -239,6 +233,4 @@ def default_map_batches_kwargs(self):
def default_iter_batches_kwargs(self):
return {
"prefetch_batches": 2,
"local_shuffle_buffer_size": self.config.train_batch_size_per_learner
or (self.config.train_batch_size // max(1, self.config.num_learners)) * 4,
}
2 changes: 1 addition & 1 deletion rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ class OfflinePreLearner:
@OverrideToImplementCustomLogic_CallToSuperRecommended
def __init__(
self,
*,
config: "AlgorithmConfig",
learner: Union[Learner, list[ActorHandle]],
locality_hints: Optional[List[str]] = None,
spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
module_spec: Optional[MultiRLModuleSpec] = None,
module_state: Optional[Dict[ModuleID, Any]] = None,
Expand Down
5 changes: 2 additions & 3 deletions rllib/offline/tests/test_offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_offline_prelearner_buffer_class(self):
algo = self.config.build()
# Build the `OfflinePreLearner` and add the learner.
oplr = OfflinePreLearner(
self.config,
config=self.config,
learner=algo.offline_data.learner_handles[0],
)

Expand Down Expand Up @@ -164,7 +164,6 @@ def test_offline_prelearner_in_map_batches(self):
).iter_batches(
batch_size=10,
prefetch_batches=1,
local_shuffle_buffer_size=100,
)

# Now sample a single batch.
Expand Down Expand Up @@ -193,7 +192,7 @@ def test_offline_prelearner_sample_from_old_sample_batch_data(self):
algo = self.config.build()
# Build the `OfflinePreLearner` and add the learner.
oplr = OfflinePreLearner(
self.config,
config=self.config,
learner=algo.offline_data.learner_handles[0],
)
# Now, pull a batch of defined size from the dataset.
Expand Down
Loading

0 comments on commit 797dc77

Please sign in to comment.