Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/gce-pd-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ var (

diskTopology = flag.Bool("disk-topology", false, "If set to true, the driver will add a disk-type.gke.io/[disk-type] topology label when the StorageClass has the use-allowed-disk-topology parameter set to true. That topology label is included in the Topologies returned in CreateVolumeResponse. This flag is disabled by default.")

dynamicVolumes = flag.Bool("dynamic-volumes", false, "If set to true, the CSI driver will automatically select a compatible disk type based on the presence of the dynamic-volume parameter and disk types defined in the StorageClass. Disabled by default.")

diskCacheSyncPeriod = flag.Duration("disk-cache-sync-period", 10*time.Minute, "Period for the disk cache to check the /dev/disk/by-id/ directory and evaluate the symlinks")

enableDiskSizeValidation = flag.Bool("enable-disk-size-validation", false, "If set to true, the driver will validate that the requested disk size is matches the physical disk size. This flag is disabled by default.")
Expand Down Expand Up @@ -257,6 +259,7 @@ func handle() {
args := &driver.GCEControllerServerArgs{
EnableDiskTopology: *diskTopology,
EnableDiskSizeValidation: *enableDiskSizeValidation,
EnableDynamicVolumes: *dynamicVolumes,
}

controllerServer = driver.NewControllerServer(gceDriver, cloudProvider, initialBackoffDuration, maxBackoffDuration, fallbackRequisiteZones, *enableStoragePoolsFlag, *enableDataCacheFlag, multiZoneVolumeHandleConfig, listVolumesConfig, provisionableDisksConfig, *enableHdHAFlag, args)
Expand Down Expand Up @@ -299,6 +302,7 @@ func handle() {
SysfsPath: "/sys",
MetricsManager: metricsManager,
DeviceCache: deviceCache,
EnableDynamicVolumes: *dynamicVolumes,
}
nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs)

Expand Down
5 changes: 5 additions & 0 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ func MapNumber(vCPUs int64, limitMap []constants.MachineHyperdiskLimit) int64 {
return 15
}

// HasDiskTypeLabelKeyPrefix reports whether labelKey begins with
// constants.DiskTypeKeyPrefix.
func HasDiskTypeLabelKeyPrefix(labelKey string) bool {
	prefix := constants.DiskTypeKeyPrefix
	return strings.HasPrefix(labelKey, prefix)
}

func DiskTypeLabelKey(diskType string) string {
return fmt.Sprintf("%s/%s", constants.DiskTypeKeyPrefix, diskType)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/gce-pd-csi-driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ type GCEControllerServer struct {
// GCEControllerServerArgs groups boolean feature toggles that main() fills in
// from command-line flags and hands to NewControllerServer.
type GCEControllerServerArgs struct {
// EnableDiskTopology is set from the disk-topology flag. Per that flag's help
// text, it makes the driver add a disk-type.gke.io/[disk-type] topology label
// when the StorageClass sets use-allowed-disk-topology to true.
EnableDiskTopology bool
// EnableDiskSizeValidation is set from the enable-disk-size-validation flag.
// Per that flag's help text, it makes the driver validate the requested disk
// size against the physical disk size.
EnableDiskSizeValidation bool
// EnableDynamicVolumes is set from the dynamic-volumes flag. Per that flag's
// help text, it lets the driver pick a compatible disk type automatically
// based on the dynamic-volume parameter and the StorageClass disk types.
EnableDynamicVolumes bool
}

type MultiZoneVolumeHandleConfig struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/gce-pd-csi-driver/gce-pd-driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi
SysfsPath: args.SysfsPath,
metricsManager: args.MetricsManager,
DeviceCache: args.DeviceCache,
EnableDynamicVolumes: args.EnableDynamicVolumes,
}
}

Expand Down
52 changes: 44 additions & 8 deletions pkg/gce-pd-csi-driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ import (

csi "github.com/container-storage-interface/spec/lib/go/csi"

corev1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
"k8s.io/mount-utils"

"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants"
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
Expand Down Expand Up @@ -83,6 +83,8 @@ type GCENodeServer struct {
metricsManager *metrics.MetricsManager
// A cache of the device paths for the volumes that are attached to the node.
DeviceCache *linkcache.DeviceCache

EnableDynamicVolumes bool
}

type NodeServerArgs struct {
Expand All @@ -101,6 +103,8 @@ type NodeServerArgs struct {

MetricsManager *metrics.MetricsManager
DeviceCache *linkcache.DeviceCache

EnableDynamicVolumes bool
}

var _ csi.NodeServer = &GCENodeServer{}
Expand Down Expand Up @@ -717,9 +721,26 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe
Segments: map[string]string{constants.TopologyKeyZone: ns.MetadataService.GetZone()},
}

node, err := k8sclient.GetNodeWithRetry(ctx, ns.MetadataService.GetName())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: You can rename err to nodeErr and avoid having to set it to nil in the if statement.

if err != nil {
klog.Errorf("Failed to get node %s: %v. The error is ignored so that the driver can register", ns.MetadataService.GetName(), err.Error())
err = nil
}

if ns.EnableDynamicVolumes {
labels, err := ns.getDiskTypeLabels(node)
if err != nil {
return nil, fmt.Errorf("failed to fetch GKE topology labels: %v", err)
}

for k, v := range labels {
top.Segments[k] = v
}
}

nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName())

volumeLimits, err := ns.GetVolumeLimits(ctx)
volumeLimits, err := ns.getVolumeLimits(ctx, node)
if err != nil {
klog.Errorf("GetVolumeLimits failed: %v. The error is ignored so that the driver can register", err.Error())
// No error should be returned from NodeGetInfo, otherwise the driver will not register
Expand Down Expand Up @@ -881,7 +902,7 @@ func (ns *GCENodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpa
}, nil
}

func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
func (ns *GCENodeServer) getVolumeLimits(ctx context.Context, node *corev1.Node) (int64, error) {
// Machine-type format: n1-type-CPUS or custom-CPUS-RAM or f1/g1-type
machineType := ns.MetadataService.GetMachineType()

Expand All @@ -893,7 +914,7 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
}

// Get attach limit override from label
attachLimitOverride, err := GetAttachLimitsOverrideFromNodeLabel(ctx, ns.MetadataService.GetName())
attachLimitOverride, err := getAttachLimitsOverrideFromNodeLabel(node)
if err == nil && attachLimitOverride > 0 && attachLimitOverride < 128 {
return attachLimitOverride, nil
} else {
Expand Down Expand Up @@ -955,11 +976,12 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) {
return volumeLimitBig, nil
}

func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) (int64, error) {
node, err := k8sclient.GetNodeWithRetry(ctx, nodeName)
if err != nil {
return 0, err
func getAttachLimitsOverrideFromNodeLabel(node *corev1.Node) (int64, error) {
// If the node is nil, return 0 together with an error; no override can be read.
if node == nil {
return 0, fmt.Errorf("node is nil")
}

if val, found := node.GetLabels()[fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel)]; found {
attachLimitOverrideForNode, err := strconv.ParseInt(val, 10, 64)
if err != nil {
Expand All @@ -970,3 +992,17 @@ func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string)
}
return 0, nil
}

// getDiskTypeLabels collects the labels on node whose keys pass
// common.HasDiskTypeLabelKeyPrefix and returns them as a fresh map.
//
// A nil node yields a nil map and an error. A node with no matching labels
// yields an empty, non-nil map.
func (ns *GCENodeServer) getDiskTypeLabels(node *corev1.Node) (map[string]string, error) {
	if node == nil {
		return nil, fmt.Errorf("node is nil")
	}

	diskTypeLabels := map[string]string{}
	for key, value := range node.GetLabels() {
		if !common.HasDiskTypeLabelKeyPrefix(key) {
			// Not a disk-type label; leave it out of the result.
			continue
		}
		diskTypeLabels[key] = value
	}
	return diskTypeLabels, nil
}
164 changes: 159 additions & 5 deletions pkg/gce-pd-csi-driver/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"os"
"path"
"path/filepath"
"sort"
"strings"
"testing"
"time"
Expand All @@ -32,6 +33,9 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/testing/protocmp"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/mount-utils"
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/constants"
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
Expand Down Expand Up @@ -237,11 +241,11 @@ func TestNodeGetVolumeStats(t *testing.T) {
func TestNodeGetVolumeLimits(t *testing.T) {
gceDriver := getTestGCEDriver(t)
ns := gceDriver.ns
req := &csi.NodeGetInfoRequest{}

testCases := []struct {
name string
machineType string
node *corev1.Node
expVolumeLimit int64
expectError bool
}{
Expand Down Expand Up @@ -431,25 +435,59 @@ func TestNodeGetVolumeLimits(t *testing.T) {
name: "a4x-medgpu-nolssd", // does not exist, testing edge case
machineType: "a4x-medgpu-nolssd",
expVolumeLimit: volumeLimitBig,
expectError: true,
},
{
name: "attach limit override",
machineType: "n1-standard-1",
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "63",
},
},
},
expVolumeLimit: 63,
},
{
name: "invalid attach limit override",
machineType: "n1-standard-1",
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "invalid",
},
},
},
expVolumeLimit: volumeLimitBig,
},
{
name: "attach limit override out of bounds",
machineType: "n1-standard-1",
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
fmt.Sprintf(constants.NodeRestrictionLabelPrefix, constants.AttachLimitOverrideLabel): "9999",
},
},
},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test case where node is nil as well, just to add coverage for the node == nil case

expVolumeLimit: volumeLimitBig,
},
}

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)

metadataservice.SetMachineType(tc.machineType)
res, err := ns.NodeGetInfo(context.Background(), req)
volumeLimit, err := ns.getVolumeLimits(context.Background(), tc.node)
if err != nil && !tc.expectError {
t.Fatalf("Failed to get node info: %v", err)
}

volumeLimit := res.GetMaxVolumesPerNode()
if volumeLimit != tc.expVolumeLimit {
t.Fatalf("Expected volume limit: %v, got %v, for machine-type: %v",
tc.expVolumeLimit, volumeLimit, tc.machineType)
}

t.Logf("Get node info: %v", res)
}
}

Expand Down Expand Up @@ -1702,3 +1740,119 @@ func TestBlockingFormatAndMount(t *testing.T) {
gceDriver := getTestBlockingFormatAndMountGCEDriver(t, readyToExecute)
runBlockingFormatAndMount(t, gceDriver, readyToExecute)
}

// TestGetDiskTypeLabels checks which label keys getDiskTypeLabels returns for
// a node without disk-type labels, a node with several of them, and a nil node.
// Only the returned keys are compared; label values are not inspected.
func TestGetDiskTypeLabels(t *testing.T) {
	const (
		nodeName = "test-node"
		diskA    = constants.DiskTypeKeyPrefix + "/disk-a"
		diskB    = constants.DiskTypeKeyPrefix + "/disk-b"
	)

	// nodeWithLabels builds a minimal Node fixture carrying the given labels.
	nodeWithLabels := func(labels map[string]string) *corev1.Node {
		return &corev1.Node{
			ObjectMeta: metav1.ObjectMeta{
				Name:   nodeName,
				Labels: labels,
			},
		}
	}

	scenarios := []struct {
		desc      string
		node      *corev1.Node
		want      []string
		wantError bool
	}{
		{
			desc: "no topology labels",
			node: nodeWithLabels(map[string]string{"foo": "bar"}),
			want: nil,
		},
		{
			desc: "multiple topology labels",
			node: nodeWithLabels(map[string]string{
				diskA: "true",
				diskB: "true",
			}),
			want: []string{diskA, diskB},
		},
		{
			desc:      "node not found",
			node:      nil,
			wantError: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.desc, func(t *testing.T) {
			nodeServer := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), &NodeServerArgs{}).ns

			labels, err := nodeServer.getDiskTypeLabels(sc.node)
			switch {
			case sc.wantError && err == nil:
				t.Fatalf("expected error but got none")
			case sc.wantError:
				// The expected failure happened; nothing further to compare.
				return
			case err != nil:
				t.Fatalf("unexpected error: %v", err)
			}

			// Start from a nil slice so that an empty result compares equal to
			// a nil want; map iteration order is random, hence the sort.
			var keys []string
			for k := range labels {
				keys = append(keys, k)
			}
			sort.Strings(keys)

			if diff := cmp.Diff(sc.want, keys); diff != "" {
				t.Fatalf("unexpected topology labels (-want +got):\n%s", diff)
			}
		})
	}
}

// TestNodeGetInfo configures the fake metadata service with a machine type,
// zone and instance name, calls NodeGetInfo, and compares the whole response
// (node ID, max volumes per node, accessible topology) against the expectation.
func TestNodeGetInfo(t *testing.T) {
	const (
		machineType = "n1-standard-4"
		zone        = "us-central1-b"
		name        = "test-node"
	)
	tests := []struct {
		desc string
		want *csi.NodeGetInfoResponse
	}{
		{
			desc: "success",
			want: &csi.NodeGetInfoResponse{
				NodeId:            fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, name),
				MaxVolumesPerNode: volumeLimitBig,
				AccessibleTopology: &csi.Topology{
					Segments: map[string]string{
						constants.TopologyKeyZone: zone,
					},
				},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.desc, func(t *testing.T) {
			gceDriver := getTestGCEDriver(t)
			ns := gceDriver.ns
			req := &csi.NodeGetInfoRequest{}
			metadataservice.SetMachineType(machineType)
			metadataservice.SetZone(zone)
			// Use the local name constant: the expected NodeId above is built
			// from it, so the fixture and the expectation must share one source.
			metadataservice.SetName(name)

			got, err := ns.NodeGetInfo(context.Background(), req)
			if err != nil {
				t.Fatalf("Failed to get node info: %v", err)
			}

			// protocmp.Transform lets cmp compare the protobuf messages.
			if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" {
				t.Fatalf("NodeGetInfo() returned unexpected diff (-want +got):\n%s", diff)
			}
		})
	}
}