Skip to content

Commit

Permalink
Transform Policy API to ListenForPolicyViolations
Browse files Browse the repository at this point in the history
This adds a new API, ListenForPolicyViolations, for setting policies
on all GPUs at once instead of on individual GPUs.

The primary modifications include the deprecation of the existing
Policy API and the introduction of the ListenForPolicyViolations API.
The Policy API is deprecated due to usability constraints. Moreover,
listening to policy violations for one GPU at a time is not very useful.
The new API enables users to set policies for all GPUs collectively,
eliminating the need to configure individual GPUs separately.
Additionally, the ListenForPolicyViolations API allows users to register
and monitor policy violations across all GPUs concurrently. Policy
callbacks are required to be registered only once during the lifetime
of the program.

Signed-off-by: Dileep Ranganathan <[email protected]>
  • Loading branch information
dran-dev committed Jan 8, 2024
1 parent be68ae5 commit cae9969
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 85 deletions.
8 changes: 5 additions & 3 deletions pkg/dcgm/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dcgm

import (
"context"
"fmt"
"os"
"sync"
Expand Down Expand Up @@ -103,9 +104,10 @@ func HealthCheckByGpuId(gpuId uint) (DeviceHealth, error) {
return healthCheckByGpuId(gpuId)
}

// Policy sets GPU usage and error policies and notifies in case of any violations via callback functions
func Policy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation, error) {
return registerPolicy(gpuId, typ...)
// ListenForPolicyViolations sets GPU usage and error policies and notifies in case of any violations.
//
// The policies are registered on the group handle returned by GroupAllGPUs
// rather than on a single GPU. typ selects the policy conditions to register.
//
// ctx is passed through to registerPolicy: once ctx is done, the goroutine that
// forwards violations stops, closes the returned channel and unregisters the
// policy. A non-nil error is returned when registerPolicy fails to set or
// register the policy, in which case the returned channel is nil.
func ListenForPolicyViolations(ctx context.Context, typ ...policyCondition) (<-chan PolicyViolation, error) {
// Use the group handle from GroupAllGPUs instead of creating a per-GPU group.
groupId := GroupAllGPUs()
return registerPolicy(ctx, groupId, typ...)
}

// Introspect returns DCGM hostengine memory and CPU usage
Expand Down
3 changes: 3 additions & 0 deletions pkg/dcgm/bcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,8 @@ func (p *publisher) broadcast() {
}

// closePublisher calls p.remove on every subscriber returned by
// p.subscriberList() and then sends true on the p.close channel.
func (p *publisher) closePublisher() {
// Remove all current subscribers before signalling close.
for _, s := range p.subscriberList() {
p.remove(s)
}
// NOTE(review): presumably received by the broadcast loop to stop the
// publisher; the receiver is not visible here, so confirm that this
// send cannot block.
p.close <- true
}
63 changes: 32 additions & 31 deletions pkg/dcgm/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ extern int violationNotify(void* p);
*/
import "C"
import (
"context"
"encoding/binary"
"fmt"
"log"
"math/rand"
"sync"
"time"
"unsafe"
Expand Down Expand Up @@ -262,7 +262,7 @@ func setPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t, paramList
for _, key := range paramList {
conditionParam, exists := paramMap[policyIndex(key)]
if !exists {
return fmt.Errorf("Error: Invalid Policy condition, %v does not exist.\n", key)
return fmt.Errorf("Error: Invalid Policy condition, %v does not exist", key)
}
// set policy condition parameters
// set condition type (bool or longlong)
Expand All @@ -287,23 +287,11 @@ func setPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t, paramList
return
}

func registerPolicy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation, error) {
func registerPolicy(ctx context.Context, groupId GroupHandle, typ ...policyCondition) (<-chan PolicyViolation, error) {
// init policy globals for internal API
makePolicyChannels()
makePolicyParmsMap()

name := fmt.Sprintf("policy%d", rand.Uint64())
groupId, err := CreateGroup(name)
if err != nil {
return nil, err
}

if err = AddToGroup(groupId, gpuId); err != nil {
return nil, err
}

// make a list of all callback channels
var channels []chan PolicyViolation
// make a list of policy conditions for setting their parameters
var paramKeys []policyIndex
// get all conditions to be set in setPolicy()
Expand All @@ -313,54 +301,67 @@ func registerPolicy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation,
case DbePolicy:
paramKeys = append(paramKeys, dbePolicyIndex)
condition |= C.DCGM_POLICY_COND_DBE
channels = append(channels, callbacks["dbe"])
case PCIePolicy:
paramKeys = append(paramKeys, pciePolicyIndex)
condition |= C.DCGM_POLICY_COND_PCI
channels = append(channels, callbacks["pcie"])
case MaxRtPgPolicy:
paramKeys = append(paramKeys, maxRtPgPolicyIndex)
condition |= C.DCGM_POLICY_COND_MAX_PAGES_RETIRED
channels = append(channels, callbacks["maxrtpg"])
case ThermalPolicy:
paramKeys = append(paramKeys, thermalPolicyIndex)
condition |= C.DCGM_POLICY_COND_THERMAL
channels = append(channels, callbacks["thermal"])
case PowerPolicy:
paramKeys = append(paramKeys, powerPolicyIndex)
condition |= C.DCGM_POLICY_COND_POWER
channels = append(channels, callbacks["power"])
case NvlinkPolicy:
paramKeys = append(paramKeys, nvlinkPolicyIndex)
condition |= C.DCGM_POLICY_COND_NVLINK
channels = append(channels, callbacks["nvlink"])
case XidPolicy:
paramKeys = append(paramKeys, xidPolicyIndex)
condition |= C.DCGM_POLICY_COND_XID
channels = append(channels, callbacks["xid"])
}
}

var err error
if err = setPolicy(groupId, condition, paramKeys); err != nil {
return nil, err
}

result := C.dcgmPolicyRegister(handle.handle, groupId.handle, C.dcgmPolicyCondition_t(condition), C.fpRecvUpdates(C.violationNotify), C.fpRecvUpdates(C.violationNotify))
var finishCallback unsafe.Pointer
result := C.dcgmPolicyRegister(handle.handle, groupId.handle, C.dcgmPolicyCondition_t(condition), C.fpRecvUpdates(C.violationNotify), C.fpRecvUpdates(finishCallback))

if err = errorString(result); err != nil {
return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result}
}
log.Println("Listening for violations...")

// merge
violation := make(chan PolicyViolation, len(channels))
violation := make(chan PolicyViolation, len(typ))
go func() {
for _, c := range channels {
val := <-c
violation <- val
defer func() {
log.Println("unregister policy violation...")
close(violation)
unregisterPolicy(groupId, condition)
}()
for {
select {
case dbe := <-callbacks["dbe"]:
violation <- dbe
case pcie := <-callbacks["pcie"]:
violation <- pcie
case maxrtpg := <-callbacks["maxrtpg"]:
violation <- maxrtpg
case thermal := <-callbacks["thermal"]:
violation <- thermal
case power := <-callbacks["power"]:
violation <- power
case nvlink := <-callbacks["nvlink"]:
violation <- nvlink
case xid := <-callbacks["xid"]:
violation <- xid
case <-ctx.Done():
return
}
}
DestroyGroup(groupId)
close(violation)
}()

return violation, err
Expand All @@ -370,7 +371,7 @@ func unregisterPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t) {
result := C.dcgmPolicyUnregister(handle.handle, groupId.handle, condition)

if err := errorString(result); err != nil {
fmt.Errorf("Error unregistering policy: %s", err)
log.Println(fmt.Errorf("error unregistering policy: %s", err))
}
}

Expand Down
Loading

0 comments on commit cae9969

Please sign in to comment.