Skip to content

Commit 02c19b1

Browse files
julianKatzhajiler
authored andcommitted
Add disk suppport label to NodeGetInfo topologies.
1 parent 24692d0 commit 02c19b1

File tree

6 files changed

+190
-13
lines changed

6 files changed

+190
-13
lines changed

cmd/gce-pd-csi-driver/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ var (
9999

100100
diskTopology = flag.Bool("disk-topology", false, "If set to true, the driver will add a disk-type.gke.io/[disk-type] topology label when the StorageClass has the use-allowed-disk-topology parameter set to true. That topology label is included in the Topologies returned in CreateVolumeResponse. This flag is disabled by default.")
101101

102+
dynamicVolumes = flag.Bool("dynamic-volumes", false, "If set to true, the CSI driver will automatically select a compatible disk type based on the presence of the dynamic-volume parameter and disk types defined in the StorageClass. Disabled by default.")
103+
102104
diskCacheSyncPeriod = flag.Duration("disk-cache-sync-period", 10*time.Minute, "Period for the disk cache to check the /dev/disk/by-id/ directory and evaluate the symlinks")
103105

104106
enableDiskSizeValidation = flag.Bool("enable-disk-size-validation", false, "If set to true, the driver will validate that the requested disk size is matches the physical disk size. This flag is disabled by default.")
@@ -257,6 +259,7 @@ func handle() {
257259
args := &driver.GCEControllerServerArgs{
258260
EnableDiskTopology: *diskTopology,
259261
EnableDiskSizeValidation: *enableDiskSizeValidation,
262+
EnableDynamicVolumes: *dynamicVolumes,
260263
}
261264

262265
controllerServer = driver.NewControllerServer(gceDriver, cloudProvider, initialBackoffDuration, maxBackoffDuration, fallbackRequisiteZones, *enableStoragePoolsFlag, *enableDataCacheFlag, multiZoneVolumeHandleConfig, listVolumesConfig, provisionableDisksConfig, *enableHdHAFlag, args)
@@ -299,6 +302,7 @@ func handle() {
299302
SysfsPath: "/sys",
300303
MetricsManager: metricsManager,
301304
DeviceCache: deviceCache,
305+
EnableDynamicVolumes: *dynamicVolumes,
302306
}
303307
nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs)
304308

pkg/common/utils.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ func MapNumber(vCPUs int64, limitMap []constants.MachineHyperdiskLimit) int64 {
492492
return 15
493493
}
494494

495+
// HasDiskTypeLabelKeyPrefix checks if the label key starts with the DiskTypeKeyPrefix.
496+
func HasDiskTypeLabelKeyPrefix(labelKey string) bool {
497+
return strings.HasPrefix(labelKey, constants.DiskTypeKeyPrefix)
498+
}
499+
495500
func DiskTypeLabelKey(diskType string) string {
496501
return fmt.Sprintf("%s/%s", constants.DiskTypeKeyPrefix, diskType)
497502
}

pkg/gce-pd-csi-driver/controller.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ type GCEControllerServer struct {
132132
type GCEControllerServerArgs struct {
133133
EnableDiskTopology bool
134134
EnableDiskSizeValidation bool
135+
EnableDynamicVolumes bool
135136
}
136137

137138
type MultiZoneVolumeHandleConfig struct {

pkg/gce-pd-csi-driver/gce-pd-driver.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi
160160
SysfsPath: args.SysfsPath,
161161
metricsManager: args.MetricsManager,
162162
DeviceCache: args.DeviceCache,
163+
EnableDynamicVolumes: args.EnableDynamicVolumes,
163164
}
164165
}
165166

pkg/gce-pd-csi-driver/node.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import (
3232

3333
csi "github.com/container-storage-interface/spec/lib/go/csi"
3434

35+
corev1 "k8s.io/api/core/v1"
3536
"k8s.io/klog/v2"
3637
"k8s.io/mount-utils"
37-
3838
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
3939
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants"
4040
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
@@ -83,6 +83,8 @@ type GCENodeServer struct {
8383
metricsManager *metrics.MetricsManager
8484
// A cache of the device paths for the volumes that are attached to the node.
8585
DeviceCache *linkcache.DeviceCache
86+
87+
EnableDynamicVolumes bool
8688
}
8789

8890
type NodeServerArgs struct {
@@ -101,6 +103,8 @@ type NodeServerArgs struct {
101103

102104
MetricsManager *metrics.MetricsManager
103105
DeviceCache *linkcache.DeviceCache
106+
107+
EnableDynamicVolumes bool
104108
}
105109

106110
var _ csi.NodeServer = &GCENodeServer{}
@@ -717,9 +721,26 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe
717721
Segments: map[string]string{constants.TopologyKeyZone: ns.MetadataService.GetZone()},
718722
}
719723

724+
node, err := k8sclient.GetNodeWithRetry(ctx, ns.MetadataService.GetName())
725+
if err != nil {
726+
klog.Errorf("Failed to get node %s: %v. The error is ignored so that the driver can register", ns.MetadataService.GetName(), err.Error())
727+
err = nil
728+
}
729+
730+
if ns.EnableDynamicVolumes {
731+
labels, err := ns.getDiskTypeLabels(node)
732+
if err != nil {
733+
return nil, fmt.Errorf("failed to fetch GKE topology labels: %v", err)
734+
}
735+
736+
for k, v := range labels {
737+
top.Segments[k] = v
738+
}
739+
}
740+
720741
nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName())
721742

722-
volumeLimits, err := ns.GetVolumeLimits(ctx)
743+
volumeLimits, err := ns.getVolumeLimits(ctx, node)
723744
if err != nil {
724745
klog.Errorf("GetVolumeLimits failed: %v. The error is ignored so that the driver can register", err.Error())
725746
// No error should be returned from NodeGetInfo, otherwise the driver will not register
@@ -881,7 +902,7 @@ func (ns *GCENodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpa
881902
}, nil
882903
}
883904

884-
func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
905+
func (ns *GCENodeServer) getVolumeLimits(ctx context.Context, node *corev1.Node) (int64, error) {
885906
// Machine-type format: n1-type-CPUS or custom-CPUS-RAM or f1/g1-type
886907
machineType := ns.MetadataService.GetMachineType()
887908

@@ -893,7 +914,7 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
893914
}
894915

895916
// Get attach limit override from label
896-
attachLimitOverride, err := GetAttachLimitsOverrideFromNodeLabel(ctx, ns.MetadataService.GetName())
917+
attachLimitOverride, err := getAttachLimitsOverrideFromNodeLabel(node)
897918
if err == nil && attachLimitOverride > 0 && attachLimitOverride < 128 {
898919
return attachLimitOverride, nil
899920
} else {
@@ -955,11 +976,12 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
955976
return volumeLimitBig, nil
956977
}
957978

958-
func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) (int64, error) {
959-
node, err := k8sclient.GetNodeWithRetry(ctx, nodeName)
960-
if err != nil {
961-
return 0, err
979+
func getAttachLimitsOverrideFromNodeLabel(node *corev1.Node) (int64, error) {
980+
// If then node is nil, return 0 which means there is no override
981+
if node == nil {
982+
return 0, fmt.Errorf("node is nil")
962983
}
984+
963985
if val, found := node.GetLabels()[fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel)]; found {
964986
attachLimitOverrideForNode, err := strconv.ParseInt(val, 10, 64)
965987
if err != nil {
@@ -970,3 +992,17 @@ func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string)
970992
}
971993
return 0, nil
972994
}
995+
996+
func (ns *GCENodeServer) getDiskTypeLabels(node *corev1.Node) (map[string]string, error) {
997+
if node == nil {
998+
return nil, fmt.Errorf("node is nil")
999+
}
1000+
topology := make(map[string]string)
1001+
for k, v := range node.GetLabels() {
1002+
if common.HasDiskTypeLabelKeyPrefix(k) {
1003+
topology[k] = v
1004+
}
1005+
}
1006+
1007+
return topology, nil
1008+
}

pkg/gce-pd-csi-driver/node_test.go

Lines changed: 135 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"os"
2222
"path"
2323
"path/filepath"
24+
"sort"
2425
"strings"
2526
"testing"
2627
"time"
@@ -32,6 +33,9 @@ import (
3233
"github.com/google/go-cmp/cmp"
3334
"google.golang.org/grpc/codes"
3435
"google.golang.org/grpc/status"
36+
"google.golang.org/protobuf/testing/protocmp"
37+
corev1 "k8s.io/api/core/v1"
38+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3539
"k8s.io/mount-utils"
3640
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants"
3741
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
@@ -237,11 +241,11 @@ func TestNodeGetVolumeStats(t *testing.T) {
237241
func TestNodeGetVolumeLimits(t *testing.T) {
238242
gceDriver := getTestGCEDriver(t)
239243
ns := gceDriver.ns
240-
req := &csi.NodeGetInfoRequest{}
241244

242245
testCases := []struct {
243246
name string
244247
machineType string
248+
node *corev1.Node
245249
expVolumeLimit int64
246250
expectError bool
247251
}{
@@ -431,25 +435,35 @@ func TestNodeGetVolumeLimits(t *testing.T) {
431435
name: "a4x-medgpu-nolssd", // does not exist, testing edge case
432436
machineType: "a4x-medgpu-nolssd",
433437
expVolumeLimit: volumeLimitBig,
438+
expectError: true,
439+
},
440+
{
441+
name: "Node Override",
442+
machineType: "n1-standard-1",
443+
node: &corev1.Node{
444+
ObjectMeta: metav1.ObjectMeta{
445+
Labels: map[string]string{
446+
fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "63",
447+
},
448+
},
449+
},
450+
expVolumeLimit: 63,
434451
},
435452
}
436453

437454
for _, tc := range testCases {
438455
t.Logf("Test case: %s", tc.name)
439456

440457
metadataservice.SetMachineType(tc.machineType)
441-
res, err := ns.NodeGetInfo(context.Background(), req)
458+
volumeLimit, err := ns.getVolumeLimits(context.Background(), tc.node)
442459
if err != nil && !tc.expectError {
443460
t.Fatalf("Failed to get node info: %v", err)
444461
}
445462

446-
volumeLimit := res.GetMaxVolumesPerNode()
447463
if volumeLimit != tc.expVolumeLimit {
448464
t.Fatalf("Expected volume limit: %v, got %v, for machine-type: %v",
449465
tc.expVolumeLimit, volumeLimit, tc.machineType)
450466
}
451-
452-
t.Logf("Get node info: %v", res)
453467
}
454468
}
455469

@@ -1702,3 +1716,119 @@ func TestBlockingFormatAndMount(t *testing.T) {
17021716
gceDriver := getTestBlockingFormatAndMountGCEDriver(t, readyToExecute)
17031717
runBlockingFormatAndMount(t, gceDriver, readyToExecute)
17041718
}
1719+
1720+
func TestGetDiskTypeLabels(t *testing.T) {
1721+
const (
1722+
nodeName = "test-node"
1723+
diskA = constants.DiskTypeKeyPrefix + "/disk-a"
1724+
diskB = constants.DiskTypeKeyPrefix + "/disk-b"
1725+
)
1726+
1727+
testCases := []struct {
1728+
desc string
1729+
node *corev1.Node
1730+
want []string
1731+
wantError bool
1732+
}{
1733+
{
1734+
desc: "no topology labels",
1735+
node: &corev1.Node{
1736+
ObjectMeta: metav1.ObjectMeta{
1737+
Name: nodeName,
1738+
Labels: map[string]string{"foo": "bar"},
1739+
},
1740+
},
1741+
want: nil,
1742+
},
1743+
{
1744+
desc: "multiple topology labels",
1745+
node: &corev1.Node{
1746+
ObjectMeta: metav1.ObjectMeta{
1747+
Name: nodeName,
1748+
Labels: map[string]string{
1749+
diskA: "true",
1750+
diskB: "true",
1751+
},
1752+
},
1753+
},
1754+
want: []string{diskA, diskB},
1755+
},
1756+
{
1757+
desc: "node not found",
1758+
node: nil,
1759+
wantError: true,
1760+
},
1761+
}
1762+
1763+
for _, tc := range testCases {
1764+
t.Run(tc.desc, func(t *testing.T) {
1765+
gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), &NodeServerArgs{})
1766+
ns := gceDriver.ns
1767+
1768+
lbls, err := ns.getDiskTypeLabels(tc.node)
1769+
if tc.wantError {
1770+
if err == nil {
1771+
t.Fatalf("expected error but got none")
1772+
}
1773+
return
1774+
}
1775+
if err != nil {
1776+
t.Fatalf("unexpected error: %v", err)
1777+
}
1778+
1779+
var got []string
1780+
for key := range lbls {
1781+
got = append(got, key)
1782+
}
1783+
sort.Strings(got)
1784+
1785+
if diff := cmp.Diff(tc.want, got); diff != "" {
1786+
t.Fatalf("unexpected topology labels (-want +got):\n%s", diff)
1787+
}
1788+
})
1789+
}
1790+
}
1791+
1792+
func TestNodeGetInfo(t *testing.T) {
1793+
const (
1794+
machineType = "n1-standard-4"
1795+
zone = "us-central1-b"
1796+
name = "test-node"
1797+
)
1798+
tests := []struct {
1799+
desc string
1800+
want *csi.NodeGetInfoResponse
1801+
}{
1802+
{
1803+
desc: "success",
1804+
want: &csi.NodeGetInfoResponse{
1805+
NodeId: fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, name),
1806+
MaxVolumesPerNode: volumeLimitBig,
1807+
AccessibleTopology: &csi.Topology{
1808+
Segments: map[string]string{
1809+
constants.TopologyKeyZone: zone,
1810+
},
1811+
},
1812+
},
1813+
},
1814+
}
1815+
for _, tc := range tests {
1816+
t.Run(tc.desc, func(t *testing.T) {
1817+
gceDriver := getTestGCEDriver(t)
1818+
ns := gceDriver.ns
1819+
req := &csi.NodeGetInfoRequest{}
1820+
metadataservice.SetMachineType(machineType)
1821+
metadataservice.SetZone(zone)
1822+
metadataservice.SetName(node)
1823+
1824+
got, err := ns.NodeGetInfo(context.Background(), req)
1825+
if err != nil {
1826+
t.Fatalf("Failed to get node info: %v", err)
1827+
}
1828+
1829+
if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" {
1830+
t.Fatalf("NodeGetInfo() returned unexpected diff (-want +got):\n%s", diff)
1831+
}
1832+
})
1833+
}
1834+
}

0 commit comments

Comments
 (0)