Skip to content

Commit

Permalink
Add initialize_torch_settings function to torch_helpers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 22, 2024
1 parent f7c099f commit 8defbb6
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,67 @@ def set_device(
return device


def initialize_torch_settings(
device: Union[str, torch.device] = 'cuda:0',
benchmark: Optional[bool] = None,
enable_cudnn: Optional[bool] = None,
deterministic_cudnn: Optional[bool] = None,
deterministic_torch: Optional[bool] = None,
set_global_device: bool = True,
init_linalg: bool = True,
) -> None:
"""
Initalizes some CUDA libraries and sets some environment variables. \n
RH 2024
Args:
device (Union[str, torch.device]):
The device to use.
benchmark (Optional[bool]):
If ``True``, sets torch.backends.cudnn.benchmark to ``True``.\n
This results in the built-in cudnn auto-tuner to find the best
algorithm for the hardware. Good for when input sizes are the same
for each batch.
enable_cudnn (Optional[bool]):
If ``True``, sets torch.backends.cudnn.enabled to ``True``.\n
This enables the cudnn library.
deterministic_cudnn (Optional[bool]):
If ``True``, sets torch.backends.cudnn.deterministic to ``True``.\n
This makes cudnn deterministic. It may slow down operations.
deterministic_torch (Optional[bool]):
If ``True``, sets torch.set_deterministic to ``True``.\n
This makes torch deterministic. It may slow down operations.
set_global_device (bool):
If ``True``, sets the global device to the provided device.\n
This is discouraged in favor of explicit device setting, but useful
for when you want to set the device globally.
init_linalg (bool):
If ``True``, initializes the linalg library. This is necessary to
avoid a bug. Often solves the error: "RuntimeError: lazy wrapper
should be called at most once".
"""
if type(device) is str:
device = torch.device(device)

if benchmark is not None:
torch.backends.cudnn.benchmark = benchmark
if enable_cudnn:
torch.backends.cudnn.enabled = enable_cudnn
if deterministic_cudnn:
torch.backends.cudnn.deterministic = False
if deterministic_torch:
torch.set_deterministic(False)
if set_global_device:
torch.cuda.set_device(device)

## Initialize linalg libarary
## https://github.com/pytorch/pytorch/issues/90613
if init_linalg:
torch.inverse(torch.ones((1, 1), device=device))
torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=device))



######################################
############ DATA HELPERS ############
######################################
Expand Down

0 comments on commit 8defbb6

Please sign in to comment.