Skip to content

Commit

Permalink
Refactor torch_helpers.py: Remove init_linalg parameter and update in…
Browse files Browse the repository at this point in the history
…it_linalg_device default value
  • Loading branch information
RichieHakim committed Apr 6, 2024
1 parent 5fd5db6 commit 22c7c2c
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ def initialize_torch_settings(
deterministic_cudnn: Optional[bool] = None,
deterministic_torch: Optional[bool] = None,
set_global_device: Optional[Union[str, torch.device]] = None,
init_linalg: bool = True,
init_linalg_device: Union[str, torch.device] = 'cuda:0',
init_linalg_device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
    Initializes some CUDA libraries and sets some environment variables. \n
Expand All @@ -250,13 +249,11 @@ def initialize_torch_settings(
set_global_device (bool):
If ``False``, does not set the global device. If a string or torch.device,
sets the global device to the specified device.
init_linalg (bool):
If ``True``, initializes the linalg library. This is necessary to
avoid a bug. Often solves the error: "RuntimeError: lazy wrapper
should be called at most once". (Default is ``True``)
init_linalg_device (str):
The device to use for initializing the linalg library. Either a
string or a torch.device. (Default is ``'cuda:0'``)
string or a torch.device. This is necessary to avoid a bug. Often
solves the error: "RuntimeError: lazy wrapper should be called at
most once". (Default is ``None``)
"""
if benchmark is not None:
torch.backends.cudnn.benchmark = benchmark
Expand All @@ -271,9 +268,9 @@ def initialize_torch_settings(

    ## Initialize linalg library
## https://github.com/pytorch/pytorch/issues/90613
if type(init_linalg_device) is str:
init_linalg_device = torch.device(init_linalg_device)
if init_linalg:
if init_linalg_device is not None:
if type(init_linalg_device) is str:
init_linalg_device = torch.device(init_linalg_device)
torch.inverse(torch.ones((1, 1), device=init_linalg_device))
torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=init_linalg_device))

Expand Down

0 comments on commit 22c7c2c

Please sign in to comment.