Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions src/codeflare_sdk/common/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import os
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509.oid import NameOID
import ipaddress
import datetime
from ..kubernetes_cluster.auth import (
config_check,
Expand Down Expand Up @@ -151,7 +153,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
os.makedirs(tls_dir)

# Similar to:
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "tls.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
config_check()
v1 = client.CoreV1Api(get_api_client())
Expand All @@ -161,10 +163,29 @@ def generate_tls_cert(cluster_name, namespace, days=30):
secret = v1.read_namespaced_secret(secret_name, namespace).data

ca_cert = secret.get("ca.crt")
ca_key = secret.get("ca.key")
ca_key = secret.get("tls.key")

if not ca_cert:
raise ValueError(
f"CA certificate (ca.crt or tls.crt) not found in secret {secret_name}. "
f"Available keys: {list(secret.keys())}"
)
if not ca_key:
raise ValueError(
f"CA private key (tls.key) not found in secret {secret_name}. "
f"Available keys: {list(secret.keys())}"
)

# Decode and write CA certificate
ca_cert_pem = base64.b64decode(ca_cert).decode("utf-8")
with open(os.path.join(tls_dir, "ca.crt"), "w") as f:
f.write(base64.b64decode(ca_cert).decode("utf-8"))
f.write(ca_cert_pem)

# Extract CA subject to use as issuer for client certificate
ca_cert_obj = x509.load_pem_x509_certificate(
ca_cert_pem.encode("utf-8"), default_backend()
)
ca_subject = ca_cert_obj.subject

# Generate tls.key and signed tls.cert locally for ray client
# Similar to running these commands:
Expand All @@ -191,16 +212,22 @@ def generate_tls_cert(cluster_name, namespace, days=30):
with open(os.path.join(tls_dir, "tls.key"), "w") as f:
f.write(tls_key.decode("utf-8"))

head_svc_name = f"{cluster_name}-head-svc"
service_dns = f"{head_svc_name}.{namespace}.svc"
service_dns_cluster_local = f"{head_svc_name}.{namespace}.svc.cluster.local"

san_list = [
x509.DNSName("localhost"),
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
x509.DNSName(head_svc_name),
x509.DNSName(service_dns),
x509.DNSName(service_dns_cluster_local),
]

one_day = datetime.timedelta(1, 0, 0)
tls_cert = (
x509.CertificateBuilder()
.issuer_name(
x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"),
]
)
)
.issuer_name(ca_subject)
.subject_name(
x509.Name(
[
Expand All @@ -213,9 +240,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
.not_valid_after(datetime.datetime.today() + (one_day * days))
.serial_number(x509.random_serial_number())
.add_extension(
x509.SubjectAlternativeName(
[x509.DNSName("localhost"), x509.DNSName("127.0.0.1")]
),
x509.SubjectAlternativeName(san_list),
False,
)
.sign(
Expand Down
2 changes: 1 addition & 1 deletion src/codeflare_sdk/common/utils/test_generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_generate_ca_cert():

def secret_ca_retreival(secret_name, namespace):
ca_private_key_bytes, ca_cert = generate_ca_cert()
data = {"ca.crt": ca_cert, "ca.key": ca_private_key_bytes}
data = {"ca.crt": ca_cert, "tls.key": ca_private_key_bytes}
assert secret_name == "ca-secret-cluster"
assert namespace == "namespace"
return client.models.V1Secret(data=data)
Expand Down
Loading