From e1d3d2424b6bf3ec8897c470fe491534ea6b60bf Mon Sep 17 00:00:00 2001 From: juliankatz Date: Fri, 10 Oct 2025 14:05:28 -0700 Subject: [PATCH] Add disk support label to NodeGetInfo topologies. --- cmd/gce-pd-csi-driver/main.go | 4 + pkg/common/utils.go | 5 + pkg/gce-pd-csi-driver/controller.go | 1 + pkg/gce-pd-csi-driver/gce-pd-driver.go | 1 + pkg/gce-pd-csi-driver/node.go | 52 ++++++-- pkg/gce-pd-csi-driver/node_test.go | 164 ++++++++++++++++++++++++- 6 files changed, 214 insertions(+), 13 deletions(-) diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index f2b2e884e..726a69506 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -99,6 +99,8 @@ var ( diskTopology = flag.Bool("disk-topology", false, "If set to true, the driver will add a disk-type.gke.io/[disk-type] topology label when the StorageClass has the use-allowed-disk-topology parameter set to true. That topology label is included in the Topologies returned in CreateVolumeResponse. This flag is disabled by default.") + dynamicVolumes = flag.Bool("dynamic-volumes", false, "If set to true, the CSI driver will automatically select a compatible disk type based on the presence of the dynamic-volume parameter and disk types defined in the StorageClass. Disabled by default.") + diskCacheSyncPeriod = flag.Duration("disk-cache-sync-period", 10*time.Minute, "Period for the disk cache to check the /dev/disk/by-id/ directory and evaluate the symlinks") enableDiskSizeValidation = flag.Bool("enable-disk-size-validation", false, "If set to true, the driver will validate that the requested disk size is matches the physical disk size. 
This flag is disabled by default.") @@ -257,6 +259,7 @@ func handle() { args := &driver.GCEControllerServerArgs{ EnableDiskTopology: *diskTopology, EnableDiskSizeValidation: *enableDiskSizeValidation, + EnableDynamicVolumes: *dynamicVolumes, } controllerServer = driver.NewControllerServer(gceDriver, cloudProvider, initialBackoffDuration, maxBackoffDuration, fallbackRequisiteZones, *enableStoragePoolsFlag, *enableDataCacheFlag, multiZoneVolumeHandleConfig, listVolumesConfig, provisionableDisksConfig, *enableHdHAFlag, args) @@ -299,6 +302,7 @@ func handle() { SysfsPath: "/sys", MetricsManager: metricsManager, DeviceCache: deviceCache, + EnableDynamicVolumes: *dynamicVolumes, } nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs) diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 74bbb59a1..923c93e0a 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -492,6 +492,11 @@ func MapNumber(vCPUs int64, limitMap []constants.MachineHyperdiskLimit) int64 { return 15 } +// HasDiskTypeLabelKeyPrefix checks if the label key starts with the DiskTypeKeyPrefix. 
+func HasDiskTypeLabelKeyPrefix(labelKey string) bool { + return strings.HasPrefix(labelKey, constants.DiskTypeKeyPrefix) +} + func DiskTypeLabelKey(diskType string) string { return fmt.Sprintf("%s/%s", constants.DiskTypeKeyPrefix, diskType) } diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index 6c226d786..cdc722368 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -132,6 +132,7 @@ type GCEControllerServer struct { type GCEControllerServerArgs struct { EnableDiskTopology bool EnableDiskSizeValidation bool + EnableDynamicVolumes bool } type MultiZoneVolumeHandleConfig struct { diff --git a/pkg/gce-pd-csi-driver/gce-pd-driver.go b/pkg/gce-pd-csi-driver/gce-pd-driver.go index c1740df80..9e4ff02ed 100644 --- a/pkg/gce-pd-csi-driver/gce-pd-driver.go +++ b/pkg/gce-pd-csi-driver/gce-pd-driver.go @@ -160,6 +160,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi SysfsPath: args.SysfsPath, metricsManager: args.MetricsManager, DeviceCache: args.DeviceCache, + EnableDynamicVolumes: args.EnableDynamicVolumes, } } diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go index a0577400b..0cb3b576c 100644 --- a/pkg/gce-pd-csi-driver/node.go +++ b/pkg/gce-pd-csi-driver/node.go @@ -32,9 +32,9 @@ import ( csi "github.com/container-storage-interface/spec/lib/go/csi" + corev1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/mount-utils" - "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils" @@ -83,6 +83,8 @@ type GCENodeServer struct { metricsManager *metrics.MetricsManager // A cache of the device paths for the volumes that are attached to the node. 
DeviceCache *linkcache.DeviceCache + + EnableDynamicVolumes bool } type NodeServerArgs struct { @@ -101,6 +103,8 @@ type NodeServerArgs struct { MetricsManager *metrics.MetricsManager DeviceCache *linkcache.DeviceCache + + EnableDynamicVolumes bool } var _ csi.NodeServer = &GCENodeServer{} @@ -717,9 +721,26 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe Segments: map[string]string{constants.TopologyKeyZone: ns.MetadataService.GetZone()}, } + node, err := k8sclient.GetNodeWithRetry(ctx, ns.MetadataService.GetName()) + if err != nil { + klog.Errorf("Failed to get node %s: %v. The error is ignored so that the driver can register", ns.MetadataService.GetName(), err.Error()) + err = nil + } + + if ns.EnableDynamicVolumes { + labels, err := ns.getDiskTypeLabels(node) + if err != nil { + return nil, fmt.Errorf("failed to fetch GKE topology labels: %v", err) + } + + for k, v := range labels { + top.Segments[k] = v + } + } + nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName()) - volumeLimits, err := ns.GetVolumeLimits(ctx) + volumeLimits, err := ns.getVolumeLimits(ctx, node) if err != nil { klog.Errorf("GetVolumeLimits failed: %v. 
The error is ignored so that the driver can register", err.Error()) // No error should be returned from NodeGetInfo, otherwise the driver will not register @@ -881,7 +902,7 @@ func (ns *GCENodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpa }, nil } -func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { +func (ns *GCENodeServer) getVolumeLimits(ctx context.Context, node *corev1.Node) (int64, error) { // Machine-type format: n1-type-CPUS or custom-CPUS-RAM or f1/g1-type machineType := ns.MetadataService.GetMachineType() @@ -893,7 +914,7 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { } // Get attach limit override from label - attachLimitOverride, err := GetAttachLimitsOverrideFromNodeLabel(ctx, ns.MetadataService.GetName()) + attachLimitOverride, err := getAttachLimitsOverrideFromNodeLabel(node) if err == nil && attachLimitOverride > 0 && attachLimitOverride < 128 { return attachLimitOverride, nil } else { @@ -955,11 +976,12 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { return volumeLimitBig, nil } -func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) (int64, error) { - node, err := k8sclient.GetNodeWithRetry(ctx, nodeName) - if err != nil { - return 0, err +func getAttachLimitsOverrideFromNodeLabel(node *corev1.Node) (int64, error) { + // If the node is nil, the override label cannot be read; return 0 together with an error. + if node == nil { + return 0, fmt.Errorf("node is nil") } + if val, found := node.GetLabels()[fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel)]; found { attachLimitOverrideForNode, err := strconv.ParseInt(val, 10, 64) if err != nil { @@ -970,3 +992,17 @@ func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) } return 0, nil } + +func (ns *GCENodeServer) getDiskTypeLabels(node *corev1.Node) (map[string]string, error) { + if node == nil { + return nil, fmt.Errorf("node is 
nil") + } + lbls := make(map[string]string) + for k, v := range node.GetLabels() { + if common.HasDiskTypeLabelKeyPrefix(k) { + lbls[k] = v + } + } + + return lbls, nil +} diff --git a/pkg/gce-pd-csi-driver/node_test.go b/pkg/gce-pd-csi-driver/node_test.go index 23347f228..d7a97adc1 100644 --- a/pkg/gce-pd-csi-driver/node_test.go +++ b/pkg/gce-pd-csi-driver/node_test.go @@ -21,6 +21,7 @@ import ( "os" "path" "path/filepath" + "sort" "strings" "testing" "time" @@ -32,6 +33,9 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/mount-utils" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils" @@ -237,11 +241,11 @@ func TestNodeGetVolumeStats(t *testing.T) { func TestNodeGetVolumeLimits(t *testing.T) { gceDriver := getTestGCEDriver(t) ns := gceDriver.ns - req := &csi.NodeGetInfoRequest{} testCases := []struct { name string machineType string + node *corev1.Node expVolumeLimit int64 expectError bool }{ @@ -431,6 +435,43 @@ func TestNodeGetVolumeLimits(t *testing.T) { name: "a4x-medgpu-nolssd", // does not exist, testing edge case machineType: "a4x-medgpu-nolssd", expVolumeLimit: volumeLimitBig, + expectError: true, + }, + { + name: "attach limit override", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "63", + }, + }, + }, + expVolumeLimit: 63, + }, + { + name: "invalid attach limit override", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "invalid", + }, + }, + }, + expVolumeLimit: 
volumeLimitBig, + }, + { + name: "attach limit override out of bounds", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "9999", + }, + }, + }, + expVolumeLimit: volumeLimitBig, }, } @@ -438,18 +479,15 @@ func TestNodeGetVolumeLimits(t *testing.T) { t.Logf("Test case: %s", tc.name) metadataservice.SetMachineType(tc.machineType) - res, err := ns.NodeGetInfo(context.Background(), req) + volumeLimit, err := ns.getVolumeLimits(context.Background(), tc.node) if err != nil && !tc.expectError { t.Fatalf("Failed to get node info: %v", err) } - volumeLimit := res.GetMaxVolumesPerNode() if volumeLimit != tc.expVolumeLimit { t.Fatalf("Expected volume limit: %v, got %v, for machine-type: %v", tc.expVolumeLimit, volumeLimit, tc.machineType) } - - t.Logf("Get node info: %v", res) } } @@ -1702,3 +1740,119 @@ func TestBlockingFormatAndMount(t *testing.T) { gceDriver := getTestBlockingFormatAndMountGCEDriver(t, readyToExecute) runBlockingFormatAndMount(t, gceDriver, readyToExecute) } + +func TestGetDiskTypeLabels(t *testing.T) { + const ( + nodeName = "test-node" + diskA = constants.DiskTypeKeyPrefix + "/disk-a" + diskB = constants.DiskTypeKeyPrefix + "/disk-b" + ) + + testCases := []struct { + desc string + node *corev1.Node + want []string + wantError bool + }{ + { + desc: "no topology labels", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{"foo": "bar"}, + }, + }, + want: nil, + }, + { + desc: "multiple topology labels", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + diskA: "true", + diskB: "true", + }, + }, + }, + want: []string{diskA, diskB}, + }, + { + desc: "node not found", + node: nil, + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + gceDriver := 
getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), &NodeServerArgs{}) + ns := gceDriver.ns + + lbls, err := ns.getDiskTypeLabels(tc.node) + if tc.wantError { + if err == nil { + t.Fatalf("expected error but got none") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var got []string + for key := range lbls { + got = append(got, key) + } + sort.Strings(got) + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("unexpected topology labels (-want +got):\n%s", diff) + } + }) + } +} + +func TestNodeGetInfo(t *testing.T) { + const ( + machineType = "n1-standard-4" + zone = "us-central1-b" + name = "test-node" + ) + tests := []struct { + desc string + want *csi.NodeGetInfoResponse + }{ + { + desc: "success", + want: &csi.NodeGetInfoResponse{ + NodeId: fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, name), + MaxVolumesPerNode: volumeLimitBig, + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + constants.TopologyKeyZone: zone, + }, + }, + }, + }, + } + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + gceDriver := getTestGCEDriver(t) + ns := gceDriver.ns + req := &csi.NodeGetInfoRequest{} + metadataservice.SetMachineType(machineType) + metadataservice.SetZone(zone) + metadataservice.SetName(node) + + got, err := ns.NodeGetInfo(context.Background(), req) + if err != nil { + t.Fatalf("Failed to get node info: %v", err) + } + + if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" { + t.Fatalf("NodeGetInfo() returned unexpected diff (-want +got):\n%s", diff) + } + }) + } +}