Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tests/e2e/heterogeneous_clusters_oauth_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from time import sleep
import time
from codeflare_sdk import (
Cluster,
ClusterConfiguration,
Expand Down Expand Up @@ -55,22 +53,24 @@ def run_heterogeneous_clusters(
namespace=self.namespace,
name=cluster_name,
num_workers=1,
head_cpu_requests=1,
head_cpu_limits=1,
worker_cpu_requests=1,
head_cpu_requests="500m",
head_cpu_limits="500m",
head_memory_requests=2,
head_memory_limits=4,
worker_cpu_requests="500m",
worker_cpu_limits=1,
worker_memory_requests=1,
worker_memory_requests=2,
worker_memory_limits=4,
image=ray_image,
verify_tls=False,
local_queue=queue_name,
)
)
cluster.apply()
sleep(5)
# Wait for the cluster to be scheduled and ready; the dashboard is not needed for this check.
cluster.wait_ready(dashboard_check=False)
node_name = get_pod_node(self, self.namespace, cluster_name)
print(f"Cluster {cluster_name}-{flavor} is running on node: {node_name}")
sleep(5)
assert (
node_name in expected_nodes
), f"Node {node_name} is not in the expected nodes for flavor {flavor}."
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/local_interactive_sdk_oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from support import *


@pytest.mark.skip(reason="Remote ray.init() is temporarily unsupported")
@pytest.mark.openshift
class TestRayLocalInteractiveOauth:
def setup_method(self):
Expand Down
117 changes: 111 additions & 6 deletions tests/e2e/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from time import sleep
from codeflare_sdk import get_cluster
from kubernetes import client, config
from kubernetes.client import V1Toleration
from codeflare_sdk.common.kubernetes_cluster.kube_api_helpers import (
_kube_api_error_handling,
)
Expand Down Expand Up @@ -146,6 +147,92 @@ def random_choice():
return "".join(random.choices(alphabet, k=5))


def _parse_label_env(env_var, default):
"""Parse label from environment variable (format: 'key=value')."""
label_str = os.getenv(env_var, default)
return label_str.split("=")


def get_master_taint_key(self):
    """
    Resolve the taint key that marks master/control-plane nodes.

    Resolution order:
      1. A non-empty TOLERATION_KEY environment variable.
      2. The first master or control-plane taint key found on the nodes
         returned by ``self.api_instance.list_node()``.
      3. The control-plane taint key as the default.

    Any error raised while inspecting nodes is printed as a warning and the
    default is returned instead.
    """
    # An explicit override avoids querying the cluster at all.
    env_key = os.getenv("TOLERATION_KEY")
    if env_key:
        return env_key

    known_keys = (
        "node-role.kubernetes.io/master",
        "node-role.kubernetes.io/control-plane",
    )

    try:
        for node in self.api_instance.list_node().items:
            # Nodes without taints report a falsy value here; skip them.
            for taint in node.spec.taints or ():
                if taint.key in known_keys:
                    return taint.key
    except Exception as e:
        print(f"Warning: Could not detect master taint key: {e}")

    # Nothing detected (or detection failed): use the default key.
    return "node-role.kubernetes.io/control-plane"


def ensure_nodes_labeled_for_flavors(self, num_flavors, with_labels):
    """
    Report whether the node labels used for ResourceFlavor targeting are present.

    The worker label (WORKER_LABEL, default worker-1=true) is always checked; the
    control label (CONTROL_LABEL, default ingress-ready=true) is checked as well
    when more than one flavor is requested.

    NOTE: This is read-only. Nodes are never modified; a warning is printed for
    each expected label that no listed worker node carries. On shared clusters,
    set WORKER_LABEL and CONTROL_LABEL to labels that already exist.
    """
    if not with_labels:
        return

    worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true")
    control_label, control_value = _parse_label_env(
        "CONTROL_LABEL", "ingress-ready=true"
    )

    try:
        candidates = self.api_instance.list_node(
            label_selector="node-role.kubernetes.io/worker"
        ).items
        if not candidates:
            print("Warning: No worker nodes found")
            return

        # The control label only matters when a second flavor is in play.
        required = [("WORKER_LABEL", worker_label, worker_value)]
        if num_flavors > 1:
            required.append(("CONTROL_LABEL", control_label, control_value))

        for env_var, label, value in required:
            for node in candidates:
                node_labels = node.metadata.labels
                if node_labels and node_labels.get(label) == value:
                    break
            else:
                # No node carried the expected label/value pair.
                print(
                    f"Warning: Label {label}={value} not found (set {env_var} env var to match existing labels)"
                )
    except Exception as e:
        print(f"Warning: Could not check existing labels: {e}")


def create_namespace(self):
try:
self.namespace = f"test-ns-{random_choice()}"
Expand Down Expand Up @@ -280,14 +367,13 @@ def create_cluster_queue(self, cluster_queue, flavor):
def create_resource_flavor(
self, flavor, default=True, with_labels=False, with_tolerations=False
):
worker_label, worker_value = os.getenv("WORKER_LABEL", "worker-1=true").split("=")
control_label, control_value = os.getenv(
worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true")
control_label, control_value = _parse_label_env(
"CONTROL_LABEL", "ingress-ready=true"
).split("=")
toleration_key = os.getenv(
"TOLERATION_KEY", "node-role.kubernetes.io/control-plane"
)

toleration_key = os.getenv("TOLERATION_KEY") or get_master_taint_key(self)

node_labels = {}
if with_labels:
node_labels = (
Expand Down Expand Up @@ -451,6 +537,25 @@ def get_nodes_by_label(self, node_labels):
return [node.metadata.name for node in nodes.items]


def get_tolerations_from_flavor(self, flavor_name):
    """
    Build V1Toleration objects from the tolerations listed on a ResourceFlavor.

    The flavor is looked up with get_flavor_spec(); each entry under
    spec.tolerations becomes one V1Toleration, with the operator defaulting to
    "Equal" when the entry does not specify one. Returns an empty list when the
    flavor lists no tolerations.
    """
    spec_section = get_flavor_spec(self, flavor_name).get("spec", {})

    converted = []
    for entry in spec_section.get("tolerations", []):
        converted.append(
            V1Toleration(
                key=entry.get("key"),
                operator=entry.get("operator", "Equal"),
                value=entry.get("value"),
                effect=entry.get("effect"),
            )
        )
    return converted


def assert_get_cluster_and_jobsubmit(
self, cluster_name, accelerator=None, number_of_gpus=None
):
Expand Down Expand Up @@ -514,7 +619,7 @@ def wait_for_kueue_admission(self, job_api, job_name, namespace, timeout=120):
workload = get_kueue_workload_for_job(self, job_name, namespace)
if workload:
conditions = workload.get("status", {}).get("conditions", [])
print(f" DEBUG: Workload conditions for '{job_name}':")
print(f"DEBUG: Workload conditions for '{job_name}':")
for condition in conditions:
print(
f" - {condition.get('type')}: {condition.get('status')} - {condition.get('reason', '')} - {condition.get('message', '')}"
Expand Down
Loading