Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 122 additions & 22 deletions model/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,92 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
)

// TaskConstructor is a function that creates a new instance of a task type
type TaskConstructor func() Task

// TaskRegistry manages the registration of task types for JSON unmarshaling
type TaskRegistry struct {
mu sync.RWMutex
constructors map[string]TaskConstructor
}

// NewTaskRegistry creates a new task registry
func NewTaskRegistry() *TaskRegistry {
return &TaskRegistry{
constructors: make(map[string]TaskConstructor),
}
}

// RegisterTask registers a custom task type with a constructor function
func (r *TaskRegistry) RegisterTask(taskType string, constructor TaskConstructor) error {

if len(taskType) == 0 {
return fmt.Errorf("task type cannot be empty")
}

if constructor == nil {
return fmt.Errorf("constructor function cannot be nil")
}

r.mu.Lock()
defer r.mu.Unlock()

if _, exists := r.constructors[taskType]; exists {
return fmt.Errorf("task type '%s' is already registered", taskType)
}

r.constructors[taskType] = constructor
return nil
}

// GetConstructor returns the constructor function for a given task type
func (r *TaskRegistry) GetConstructor(taskType string) (TaskConstructor, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
constructor, exists := r.constructors[taskType]
return constructor, exists
}

// UnregisterTask removes a task type from the registry (mainly for testing)
func (r *TaskRegistry) UnregisterTask(taskType string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.constructors, taskType)
}

// ListRegisteredTypes returns all registered task types
func (r *TaskRegistry) ListRegisteredTypes() []string {
r.mu.RLock()
defer r.mu.RUnlock()

types := make([]string, 0, len(r.constructors))
for taskType := range r.constructors {
types = append(types, taskType)
}
return types
}

// Global task registry instance
var defaultRegistry = NewTaskRegistry()

// RegisterTask registers a custom task type with the global registry
func RegisterTask(taskType string, constructor TaskConstructor) error {
return defaultRegistry.RegisterTask(taskType, constructor)
}

// GetTaskConstructor returns the constructor function for a given task type from the global registry
func GetTaskConstructor(taskType string) (TaskConstructor, bool) {
return defaultRegistry.GetConstructor(taskType)
}

// ListRegisteredTaskTypes returns all registered task types from the global registry
func ListRegisteredTaskTypes() []string {
return defaultRegistry.ListRegisteredTypes()
}

type TaskBase struct {
// A runtime expression, if any, used to determine whether or not the task should be run.
If *RuntimeExpression `json:"if,omitempty" validate:"omitempty"`
Expand Down Expand Up @@ -118,23 +202,34 @@ func (tl *TaskList) UnmarshalJSON(data []byte) error {
return nil
}

var taskTypeRegistry = map[string]func() Task{
"call_http": func() Task { return &CallHTTP{} },
"call_openapi": func() Task { return &CallOpenAPI{} },
"call_grpc": func() Task { return &CallGRPC{} },
"call_asyncapi": func() Task { return &CallAsyncAPI{} },
"call": func() Task { return &CallFunction{} },
"do": func() Task { return &DoTask{} },
"fork": func() Task { return &ForkTask{} },
"emit": func() Task { return &EmitTask{} },
"for": func() Task { return &ForTask{} },
"listen": func() Task { return &ListenTask{} },
"raise": func() Task { return &RaiseTask{} },
"run": func() Task { return &RunTask{} },
"set": func() Task { return &SetTask{} },
"switch": func() Task { return &SwitchTask{} },
"try": func() Task { return &TryTask{} },
"wait": func() Task { return &WaitTask{} },
// Initialize built-in task types with the registry
func init() {

// Register all built-in task types
builtInTasks := map[string]TaskConstructor{
"call_http": func() Task { return &CallHTTP{} },
"call_openapi": func() Task { return &CallOpenAPI{} },
"call_grpc": func() Task { return &CallGRPC{} },
"call_asyncapi": func() Task { return &CallAsyncAPI{} },
"call": func() Task { return &CallFunction{} },
"do": func() Task { return &DoTask{} },
"fork": func() Task { return &ForkTask{} },
"emit": func() Task { return &EmitTask{} },
"for": func() Task { return &ForTask{} },
"listen": func() Task { return &ListenTask{} },
"raise": func() Task { return &RaiseTask{} },
"run": func() Task { return &RunTask{} },
"set": func() Task { return &SetTask{} },
"switch": func() Task { return &SwitchTask{} },
"try": func() Task { return &TryTask{} },
"wait": func() Task { return &WaitTask{} },
}

for taskType, constructor := range builtInTasks {
if err := defaultRegistry.RegisterTask(taskType, constructor); err != nil {
panic(fmt.Sprintf("failed to register built-in task type '%s': %v", taskType, err))
}
}
}

func unmarshalTask(key string, taskRaw json.RawMessage) (Task, error) {
Expand All @@ -150,27 +245,32 @@ func unmarshalTask(key string, taskRaw json.RawMessage) (Task, error) {
if callValue, hasCall := taskType["call"].(string); hasCall {
// Form composite key and check if it's in the registry
registryKey := fmt.Sprintf("call_%s", callValue)
if constructor, exists := taskTypeRegistry[registryKey]; exists {
if constructor, exists := defaultRegistry.GetConstructor(registryKey); exists {
task = constructor()
} else {
// Default to CallFunction for unrecognized call values
task = &CallFunction{}
if constructor, exists := defaultRegistry.GetConstructor("call"); exists {
task = constructor()
}
}
} else if _, hasFor := taskType["for"]; hasFor {
// Handle special case "for" that also has "do"
task = taskTypeRegistry["for"]()
if constructor, exists := defaultRegistry.GetConstructor("for"); exists {
task = constructor()
}
} else {
// Handle everything else (e.g., "do", "fork")
for typeKey := range taskType {
if constructor, exists := taskTypeRegistry[typeKey]; exists {
if constructor, exists := defaultRegistry.GetConstructor(typeKey); exists {
task = constructor()
break
}
}
}

if task == nil {
return nil, fmt.Errorf("unknown task type for key '%s'", key)
return nil, fmt.Errorf("unknown task type for key '%s'. Available types: %v",
key, defaultRegistry.ListRegisteredTypes())
}

// Populate the task with raw data
Expand Down
Loading
Loading