Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions cadence/signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""
Signal definition and registration for Cadence workflows.

This module provides functionality to define and register signal handlers
for workflows, similar to ActivityDefinition but for signals.
"""

import inspect
from dataclasses import dataclass
from functools import update_wrapper
from inspect import Parameter, signature
from typing import (
Callable,
Generic,
ParamSpec,
Type,
TypeVar,
TypedDict,
Unpack,
overload,
get_type_hints,
Any,
)

P = ParamSpec("P")
T = TypeVar("T")


@dataclass(frozen=True)
class SignalParameter:
"""Parameter metadata for a signal handler."""

name: str
type_hint: Type | None
has_default: bool
default_value: Any


class SignalDefinitionOptions(TypedDict, total=False):
"""Options for defining a signal."""

name: str


class SignalDefinition(Generic[P, T]):
"""
Definition of a signal handler with metadata.

Similar to ActivityDefinition but for signal handlers.
Provides type safety and metadata for signal handlers.
"""

def __init__(
self,
wrapped: Callable[P, T],
name: str,
params: list[SignalParameter],
is_async: bool,
):
self._wrapped = wrapped
self._name = name
self._params = params
self._is_async = is_async
update_wrapper(self, wrapped)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Call the wrapped signal handler function."""
return self._wrapped(*args, **kwargs)

@property
def name(self) -> str:
"""Get the signal name."""
return self._name

@property
def params(self) -> list[SignalParameter]:
"""Get the signal parameters."""
return self._params

@property
def is_async(self) -> bool:
"""Check if the signal handler is async."""
return self._is_async

@property
def wrapped(self) -> Callable[P, T]:
"""Get the wrapped signal handler function."""
return self._wrapped

@staticmethod
def wrap(
fn: Callable[P, T], opts: SignalDefinitionOptions
) -> "SignalDefinition[P, T]":
"""
Wrap a function as a SignalDefinition.

Args:
fn: The signal handler function to wrap
opts: Options for the signal definition

Returns:
A SignalDefinition instance

Raises:
ValueError: If name is not provided in options or return type is not None
"""
name = opts.get("name") or fn.__qualname__
is_async = inspect.iscoroutinefunction(fn)
params = _get_signal_signature(fn)
_validate_signal_return_type(fn)

return SignalDefinition(fn, name, params, is_async)


SignalDecorator = Callable[[Callable[P, T]], SignalDefinition[P, T]]


@overload
def defn(fn: Callable[P, T]) -> SignalDefinition[P, T]: ...


@overload
def defn(**kwargs: Unpack[SignalDefinitionOptions]) -> SignalDecorator: ...


def defn(
fn: Callable[P, T] | None = None, **kwargs: Unpack[SignalDefinitionOptions]
) -> SignalDecorator | SignalDefinition[P, T]:
"""
Decorator to define a signal handler.

Can be used with or without parentheses:
@signal.defn(name="approval")
async def handle_approval(self, approved: bool):
...

@signal.defn(name="approval")
def handle_approval(self, approved: bool):
...

Args:
fn: The signal handler function to decorate
**kwargs: Options for the signal definition (name is required)

Returns:
The decorated function as a SignalDefinition instance

Raises:
ValueError: If name is not provided
"""
options = SignalDefinitionOptions(**kwargs)

def decorator(inner_fn: Callable[P, T]) -> SignalDefinition[P, T]:
return SignalDefinition.wrap(inner_fn, options)

if fn is not None:
return decorator(fn)

return decorator


def _validate_signal_return_type(fn: Callable) -> None:
"""
Validate that signal handler returns None.

Args:
fn: The signal handler function

Raises:
ValueError: If return type is not None
"""
try:
hints = get_type_hints(fn)
ret_type = hints.get("return", inspect.Signature.empty)

if ret_type is not None and ret_type is not inspect.Signature.empty:
raise ValueError(
f"Signal handler '{fn.__qualname__}' must return None "
f"(signals cannot return values), got {ret_type}"
)
except NameError:
pass


def _get_signal_signature(fn: Callable[P, T]) -> list[SignalParameter]:
"""
Extract parameter information from a signal handler function.

Args:
fn: The signal handler function

Returns:
List of SignalParameter objects

Raises:
ValueError: If parameters are not positional
"""
sig = signature(fn)
args = sig.parameters
hints = get_type_hints(fn)
params = []

for name, param in args.items():
# Filter out the self parameter for instance methods
if param.name == "self":
continue

has_default = param.default != Parameter.empty
default = param.default if has_default else None

if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
type_hint = hints.get(name, None)
params.append(SignalParameter(name, type_hint, has_default, default))
else:
raise ValueError(
f"Signal handler '{fn.__qualname__}' parameter '{name}' must be positional, "
f"got {param.kind.name}"
)

return params
66 changes: 64 additions & 2 deletions cadence/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.data_converter import DataConverter
from cadence.signal import SignalDefinition, SignalDefinitionOptions

ResultType = TypeVar("ResultType")

Expand Down Expand Up @@ -60,10 +61,22 @@ class WorkflowDefinition:
Provides type safety and metadata for workflow classes.
"""

def __init__(self, cls: Type, name: str, run_method_name: str):
def __init__(
self,
cls: Type,
name: str,
run_method_name: str,
signals: dict[str, SignalDefinition[..., Any]],
):
self._cls = cls
self._name = name
self._run_method_name = run_method_name
self._signals = signals

@property
def signals(self) -> dict[str, SignalDefinition[..., Any]]:
"""Get the signal definitions."""
return self._signals

@property
def name(self) -> str:
Expand Down Expand Up @@ -99,6 +112,11 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
name = opts["name"]

# Validate that the class has exactly one run method and find it
# Also validate that class does not have multiple signal methods with the same name
signals: dict[str, SignalDefinition[..., Any]] = {}
signal_names: dict[
str, str
] = {} # Map signal name to method name for duplicate detection
run_method_name = None
for attr_name in dir(cls):
if attr_name.startswith("_"):
Expand All @@ -116,10 +134,24 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
)
run_method_name = attr_name

if hasattr(attr, "_workflow_signal"):
signal_name = getattr(attr, "_workflow_signal")
if signal_name in signal_names:
raise ValueError(
f"Multiple @workflow.signal methods found in class {cls.__name__} "
f"with signal name '{signal_name}': '{attr_name}' and '{signal_names[signal_name]}'"
)
# Create SignalDefinition from the decorated method
signal_def = SignalDefinition.wrap(
attr, SignalDefinitionOptions(name=signal_name)
)
signals[signal_name] = signal_def
signal_names[signal_name] = attr_name

if run_method_name is None:
raise ValueError(f"No @workflow.run method found in class {cls.__name__}")

return WorkflowDefinition(cls, name, run_method_name)
return WorkflowDefinition(cls, name, run_method_name, signals)


def run(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
Expand Down Expand Up @@ -163,6 +195,36 @@ def decorator(f: T) -> T:
return decorator(func)


def signal(name: str | None = None) -> Callable[[T], T]:
"""
Decorator to mark a method as a workflow signal handler.

Example:
@workflow.signal(name="approval_channel")
async def approve(self, approved: bool):
self.approved = approved

Args:
name: The name of the signal

Returns:
The decorated method with workflow signal metadata

Raises:
ValueError: If name is not provided

"""
if name is None:
raise ValueError("name is required")

def decorator(f: T) -> T:
f._workflow_signal = name # type: ignore
return f

# Only allow @workflow.signal(name), require name to be explicitly provided
return decorator


@dataclass
class WorkflowInfo:
workflow_type: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch
from unittest.mock import Mock, patch
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType
from cadence.api.v1.history_pb2 import (
Expand Down Expand Up @@ -250,7 +250,7 @@ async def test_extract_workflow_input_deserialization_error(
"""Test workflow input extraction with deserialization error."""

# Mock data converter to raise an exception
mock_client.data_converter.from_data = AsyncMock(
mock_client.data_converter.from_data = Mock(
side_effect=Exception("Deserialization error")
)

Expand Down
Loading