
[RLlib; Offline RL] 2. Multiple optimizations for streaming data. #49195

Merged

16 commits
2f526cc  Optimized RLUnplugged benchmark example and changed name. Furthermore… (simonsays1980, Nov 8, 2024)
df8afbe  Adapted all tuned examples to the optimizations made in this branch. (simonsays1980, Nov 8, 2024)
5a179ae  Removed call for schema completely to not hurt performance anyhow. F… (simonsays1980, Nov 8, 2024)
a407064  Added a scheduling strategy for multi-node and multi-learner setups. … (simonsays1980, Nov 8, 2024)
af757e8  Changed dataset iterations per learner to a higher value to improve r… (simonsays1980, Nov 11, 2024)
716fa69  Removed local shuffle buffer from 'OfflineData''s defaults b/c it slo… (simonsays1980, Nov 20, 2024)
db07517  Merged master and resolved conflicts. (simonsays1980, Dec 10, 2024)
c296b56  Removed 'locality_hints' from 'OfflinePreLearner' b/c it is not neede… (simonsays1980, Dec 10, 2024)
1d9e02c  Removed unused customization in 'RLUnplugged' example. (simonsays1980, Dec 10, 2024)
55d11a4  Merge branch 'master' into offline-rl-streaming-optimizations (simonsays1980, Dec 11, 2024)
929f6f7  Changed the signature of the 'OfflinePreLearner' to force keyword arg… (simonsays1980, Dec 11, 2024)
68dfd70  Removed ALE syntax in RLUnplugged example b/c gymnasium 1.x does not … (simonsays1980, Dec 11, 2024)
6904042  Merge branch 'master' into offline-rl-streaming-optimizations (simonsays1980, Dec 12, 2024)
45924b5  Fixed small bug in MARWIL test file that was due to forced keyword ar… (simonsays1980, Dec 12, 2024)
48df937  Merge branch 'master' into offline-rl-streaming-optimizations (simonsays1980, Dec 13, 2024)
5827253  Added keyword to first argument in 'OfflinePreLearner' initialisation… (simonsays1980, Dec 13, 2024)
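Commit a407064 above mentions a scheduling strategy for multi-node and multi-learner setups; that change is not visible in the diff below. As a rough sketch of the general technique (the "SPREAD" strategy, the example path, and the batch size are assumptions for illustration, not taken from this PR), Ray Data forwards extra keyword arguments to `map_batches` as Ray remote args, which is where a scheduling strategy can be set:

import ray

# Hypothetical dataset; the path is a placeholder, not from this PR.
ds = ray.data.read_parquet("example://iris.parquet")

# Extra kwargs to `map_batches` are forwarded as Ray remote args, so a
# scheduling strategy can spread map workers across the cluster's nodes.
spread = ds.map_batches(
    lambda batch: batch,
    batch_size=256,
    scheduling_strategy="SPREAD",  # assumption: illustrative strategy choice
)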
4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/tests/test_marwil.py
@@ -158,7 +158,9 @@ def test_marwil_loss_function(self):
         batch = algo.offline_data.data.take_batch(2000)

         # Create the prelearner and compute advantages and values.
-        offline_prelearner = OfflinePreLearner(config, algo.learner_group._learner)
+        offline_prelearner = OfflinePreLearner(
+            config=config, learner=algo.learner_group._learner
+        )
         # Note, for `ray.data`'s pipeline everything has to be a dictionary
         # therefore the batch is embedded into another dictionary.
         batch = offline_prelearner(batch)["batch"][0]
@@ -35,7 +35,6 @@ def __init__(
         self,
         config: "AlgorithmConfig",
         learner: Union[Learner, List[ActorHandle]],
-        locality_hints: Optional[List[str]] = None,
         spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
         module_spec: Optional[MultiRLModuleSpec] = None,
         module_state: Optional[Dict[ModuleID, Any]] = None,
32 changes: 12 additions & 20 deletions rllib/offline/offline_data.py
@@ -33,9 +33,7 @@ def __init__(self, config: AlgorithmConfig):
         # Use `read_parquet` as default data read method.
         self.data_read_method = self.config.input_read_method
         # Override default arguments for the data read method.
-        self.data_read_method_kwargs = (
-            self.default_read_method_kwargs | self.config.input_read_method_kwargs
-        )
+        self.data_read_method_kwargs = self.config.input_read_method_kwargs
         # In case `EpisodeType` or `BatchType` batches are read the size
         # could differ from the final `train_batch_size_per_learner`.
         self.data_read_batch_size = self.config.input_read_batch_size
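Since the defaults (previously an `override_num_blocks` value derived from the learner count) no longer get merged in, users who relied on them now pass read kwargs explicitly through the config. A minimal sketch, assuming a BC setup; the path and block count are placeholders:

from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .offline_data(
        input_="local:///tmp/offline-data/",  # hypothetical data path
        # Previously merged in via `default_read_method_kwargs`; after this
        # change it is only applied when set explicitly here.
        input_read_method_kwargs={"override_num_blocks": 4},
    )
)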
@@ -75,11 +73,12 @@ def __init__(self, config: AlgorithmConfig):
"'gcs' for GCS, 's3' for S3, or 'abs'"
)
# Add the filesystem object to the write method kwargs.
self.data_read_method_kwargs.update(
{
"filesystem": self.filesystem_object,
}
)
if self.filesystem_object:
self.data_read_method_kwargs.update(
{
"filesystem": self.filesystem_object,
}
)

try:
# Load the dataset.
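With the guard, `filesystem` is only injected into the read kwargs when a filesystem object was actually built for a cloud path; plain local paths leave `input_read_method_kwargs` untouched. For intuition, a sketch of what the cloud branch amounts to (the bucket/prefix is a placeholder):

import pyarrow.fs
import ray

# For a GCS path, a pyarrow filesystem object is built and handed to the read;
# local paths skip this entirely.
fs = pyarrow.fs.GcsFileSystem()
ds = ray.data.read_parquet(
    "my-bucket/offline-data/",  # hypothetical bucket/prefix
    filesystem=fs,
)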
@@ -90,9 +89,11 @@
             if self.materialize_data:
                 self.data = self.data.materialize()
             stop_time = time.perf_counter()
-            logger.debug(f"Time for loading dataset: {stop_time - start_time}s.")
+            logger.debug(
+                "===> [OfflineData] - Time for loading dataset: "
+                f"{stop_time - start_time}s."
+            )
             logger.info("Reading data from {}".format(self.path))
-            logger.info(self.data.schema())
         except Exception as e:
             logger.error(e)
         # Avoids reinstantiating the batch iterator each time we sample.
@@ -146,8 +147,7 @@ def sample(
             # Add constructor `kwargs` when using remote learners.
             fn_constructor_kwargs.update(
                 {
-                    "learner": self.learner_handles,
-                    "locality_hints": self.locality_hints,
+                    "learner": None,

Contributor: Wait, what changed here?

simonsays1980 (collaborator, author): Locality hints are gone. The learner will go in a third PR, blocked by the Ray Data with Ray Tune issue.

"module_spec": self.module_spec,
"module_state": module_state,
}
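For orientation: these kwargs end up as the constructor arguments of the `OfflinePreLearner`, which Ray Data instantiates once per map worker when a class is passed to `map_batches`. A simplified sketch of that wiring (the batch size and concurrency are placeholders, not the PR's actual call):

# Simplified sketch of the pipeline wiring in `OfflineData.sample`: passing a
# class to `map_batches` makes Ray Data construct it once per worker with
# `fn_constructor_kwargs` and call it on every batch.
mapped = self.data.map_batches(
    OfflinePreLearner,
    fn_constructor_kwargs={
        "config": self.config,
        "learner": None,  # remote learner handles return in a follow-up PR
        "module_spec": self.module_spec,
        "module_state": module_state,
    },
    batch_size=256,  # placeholder
    concurrency=2,   # placeholder
)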
@@ -220,12 +220,6 @@ def sample(
             num_shards=num_shards,
         )

-    @property
-    def default_read_method_kwargs(self):
-        return {
-            "override_num_blocks": max(self.config.num_learners * 2, 2),
-        }
-
     @property
     def default_map_batches_kwargs(self):
         return {
@@ -237,6 +231,4 @@ def default_map_batches_kwargs(self):
     def default_iter_batches_kwargs(self):
         return {
             "prefetch_batches": 2,
-            "local_shuffle_buffer_size": self.config.train_batch_size_per_learner
-            or (self.config.train_batch_size // max(1, self.config.num_learners)) * 4,
         }
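After this change the default iteration only prefetches; local shuffling becomes opt-in, because the shuffle buffer was found to slow streaming down (commit 716fa69). A sketch of both modes, with placeholder sizes:

# Default after this PR: prefetch only, no local shuffle buffer.
for batch in ds.iter_batches(batch_size=512, prefetch_batches=2):
    ...  # feed the learner

# Local shuffling is still available, but must now be requested explicitly.
for batch in ds.iter_batches(
    batch_size=512,
    prefetch_batches=2,
    local_shuffle_buffer_size=4 * 512,  # hypothetical buffer size
):
    ...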
2 changes: 1 addition & 1 deletion rllib/offline/offline_prelearner.py
@@ -84,9 +84,9 @@ class OfflinePreLearner:
     @OverrideToImplementCustomLogic_CallToSuperRecommended
     def __init__(
         self,
+        *,

Contributor: thanks!

config: "AlgorithmConfig",
learner: Union[Learner, list[ActorHandle]],
locality_hints: Optional[List[str]] = None,
         spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
         module_spec: Optional[MultiRLModuleSpec] = None,
         module_state: Optional[Dict[ModuleID, Any]] = None,
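The bare `*` in the signature makes every argument after `self` keyword-only, which is why the positional calls in the tests below had to change. A minimal, standalone illustration of the pattern (toy class, not RLlib code):

class PreLearner:
    """Toy stand-in for OfflinePreLearner's keyword-only constructor."""

    def __init__(self, *, config, learner):
        self.config = config
        self.learner = learner


PreLearner(config={"lr": 1e-3}, learner=None)  # OK: keyword arguments.
PreLearner({"lr": 1e-3}, None)  # TypeError: takes 1 positional argument.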
5 changes: 2 additions & 3 deletions rllib/offline/tests/test_offline_prelearner.py
@@ -74,7 +74,7 @@ def test_offline_prelearner_buffer_class(self):
         algo = self.config.build()
         # Build the `OfflinePreLearner` and add the learner.
         oplr = OfflinePreLearner(
-            self.config,
+            config=self.config,

Contributor: nice!

             learner=algo.offline_data.learner_handles[0],
         )

@@ -164,7 +164,6 @@ def test_offline_prelearner_in_map_batches(self):
         ).iter_batches(
             batch_size=10,
             prefetch_batches=1,
-            local_shuffle_buffer_size=100,
         )

         # Now sample a single batch.
@@ -193,7 +192,7 @@ def test_offline_prelearner_sample_from_old_sample_batch_data(self):
         algo = self.config.build()
         # Build the `OfflinePreLearner` and add the learner.
         oplr = OfflinePreLearner(
-            self.config,
+            config=self.config,

Contributor: nice!

             learner=algo.offline_data.learner_handles[0],
         )
         # Now, pull a batch of defined size from the dataset.