diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 4cb599b7..812abf73 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -75,6 +75,15 @@ func (cs *controller) LockVolume(volume string) func() { return func() { mtx.Unlock() } } +func (cs *controller) LockVolumeWithSnapshot(volume string, snapshot string) func() { + unlockVol := cs.LockVolume(volume) + unlockSnap := cs.LockVolume(snapshot) + return func() { + unlockVol() + unlockSnap() + } +} + // NewController returns a new instance // of CSI controller func NewController(d *CSIDriver) csi.ControllerServer { @@ -724,10 +733,8 @@ func (cs *controller) CreateSnapshot( if err != nil { return nil, err } - unlockVol := cs.LockVolume(volumeID) - defer unlockVol() - unlockSnap := cs.LockVolume(snapName) - defer unlockSnap() + unlock := cs.LockVolumeWithSnapshot(volumeID, snapName) + defer unlock() snapTimeStamp := time.Now().Unix() var state string @@ -824,10 +831,8 @@ func (cs *controller) DeleteSnapshot( // should succeed when an invalid snapshot id is used return &csi.DeleteSnapshotResponse{}, nil } - unlockVol := cs.LockVolume(snapshotID[0]) - defer unlockVol() - unlockSnap := cs.LockVolume(snapshotID[1]) - defer unlockSnap() + unlock := cs.LockVolumeWithSnapshot(snapshotID[0], snapshotID[1]) + defer unlock() if err := zfs.DeleteSnapshot(snapshotID[1]); err != nil { return nil, status.Errorf( codes.Internal,