Skip to content

Commit

Permalink
Switch to pynvml_utils.smi for PyNVML 12
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Jan 13, 2025
1 parent 8507cbf commit e305416
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -36,7 +36,7 @@
def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None:
import cupy
import rmm
from pynvml.smi import nvidia_smi
from pynvml_utils.smi import nvidia_smi

smi = nvidia_smi.getInstance()
pool_size = 16e9 # FIXME calculate this
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -201,7 +201,7 @@ def train(self):
)
logger.info(f"total time: {total_time_iter}")

# from pynvml.smi import nvidia_smi
# from pynvml_utils.smi import nvidia_smi
# mem_info = nvidia_smi.getInstance().DeviceQuery('memory.free, memory.total')['gpu'][self.rank % 8]['fb_memory_usage']
# logger.info(f"rank {self.rank} memory: {mem_info}")

Expand Down
4 changes: 2 additions & 2 deletions python/utils/gpu_metric_poller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
# Copyright (c) 2018-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -31,7 +31,7 @@
import os
import sys
import threading
from pynvml import smi
from pynvml_utils import smi


class GPUMetricPoller(threading.Thread):
Expand Down

0 comments on commit e305416

Please sign in to comment.