Skip to content

Commit 27de322

Browse files
committed
added support for registering custom tasks
Signed-off-by: hughneale <mail@hughneale.com>
1 parent 592f31d commit 27de322

File tree

2 files changed

+635
-22
lines changed

2 files changed

+635
-22
lines changed

model/task.go

Lines changed: 122 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,92 @@ import (
1818
"encoding/json"
1919
"errors"
2020
"fmt"
21+
"sync"
2122
)
2223

24+
// TaskConstructor is a function that creates a new instance of a task type
25+
type TaskConstructor func() Task
26+
27+
// TaskRegistry manages the registration of task types for JSON unmarshaling
28+
type TaskRegistry struct {
29+
mu sync.RWMutex
30+
constructors map[string]TaskConstructor
31+
}
32+
33+
// NewTaskRegistry creates a new task registry
34+
func NewTaskRegistry() *TaskRegistry {
35+
return &TaskRegistry{
36+
constructors: make(map[string]TaskConstructor),
37+
}
38+
}
39+
40+
// RegisterTask registers a custom task type with a constructor function
41+
func (r *TaskRegistry) RegisterTask(taskType string, constructor TaskConstructor) error {
42+
43+
if len(taskType) == 0 {
44+
return fmt.Errorf("task type cannot be empty")
45+
}
46+
47+
if constructor == nil {
48+
return fmt.Errorf("constructor function cannot be nil")
49+
}
50+
51+
r.mu.Lock()
52+
defer r.mu.Unlock()
53+
54+
if _, exists := r.constructors[taskType]; exists {
55+
return fmt.Errorf("task type '%s' is already registered", taskType)
56+
}
57+
58+
r.constructors[taskType] = constructor
59+
return nil
60+
}
61+
62+
// GetConstructor returns the constructor function for a given task type
63+
func (r *TaskRegistry) GetConstructor(taskType string) (TaskConstructor, bool) {
64+
r.mu.RLock()
65+
defer r.mu.RUnlock()
66+
constructor, exists := r.constructors[taskType]
67+
return constructor, exists
68+
}
69+
70+
// UnregisterTask removes a task type from the registry (mainly for testing)
71+
func (r *TaskRegistry) UnregisterTask(taskType string) {
72+
r.mu.Lock()
73+
defer r.mu.Unlock()
74+
delete(r.constructors, taskType)
75+
}
76+
77+
// ListRegisteredTypes returns all registered task types
78+
func (r *TaskRegistry) ListRegisteredTypes() []string {
79+
r.mu.RLock()
80+
defer r.mu.RUnlock()
81+
82+
types := make([]string, 0, len(r.constructors))
83+
for taskType := range r.constructors {
84+
types = append(types, taskType)
85+
}
86+
return types
87+
}
88+
89+
// Global task registry instance
90+
var defaultRegistry = NewTaskRegistry()
91+
92+
// RegisterTask registers a custom task type with the global registry
93+
func RegisterTask(taskType string, constructor TaskConstructor) error {
94+
return defaultRegistry.RegisterTask(taskType, constructor)
95+
}
96+
97+
// GetTaskConstructor returns the constructor function for a given task type from the global registry
98+
func GetTaskConstructor(taskType string) (TaskConstructor, bool) {
99+
return defaultRegistry.GetConstructor(taskType)
100+
}
101+
102+
// ListRegisteredTaskTypes returns all registered task types from the global registry
103+
func ListRegisteredTaskTypes() []string {
104+
return defaultRegistry.ListRegisteredTypes()
105+
}
106+
23107
type TaskBase struct {
24108
// A runtime expression, if any, used to determine whether or not the task should be run.
25109
If *RuntimeExpression `json:"if,omitempty" validate:"omitempty"`
@@ -118,23 +202,34 @@ func (tl *TaskList) UnmarshalJSON(data []byte) error {
118202
return nil
119203
}
120204

121-
var taskTypeRegistry = map[string]func() Task{
122-
"call_http": func() Task { return &CallHTTP{} },
123-
"call_openapi": func() Task { return &CallOpenAPI{} },
124-
"call_grpc": func() Task { return &CallGRPC{} },
125-
"call_asyncapi": func() Task { return &CallAsyncAPI{} },
126-
"call": func() Task { return &CallFunction{} },
127-
"do": func() Task { return &DoTask{} },
128-
"fork": func() Task { return &ForkTask{} },
129-
"emit": func() Task { return &EmitTask{} },
130-
"for": func() Task { return &ForTask{} },
131-
"listen": func() Task { return &ListenTask{} },
132-
"raise": func() Task { return &RaiseTask{} },
133-
"run": func() Task { return &RunTask{} },
134-
"set": func() Task { return &SetTask{} },
135-
"switch": func() Task { return &SwitchTask{} },
136-
"try": func() Task { return &TryTask{} },
137-
"wait": func() Task { return &WaitTask{} },
205+
// Initialize built-in task types with the registry
206+
func init() {
207+
208+
// Register all built-in task types
209+
builtInTasks := map[string]TaskConstructor{
210+
"call_http": func() Task { return &CallHTTP{} },
211+
"call_openapi": func() Task { return &CallOpenAPI{} },
212+
"call_grpc": func() Task { return &CallGRPC{} },
213+
"call_asyncapi": func() Task { return &CallAsyncAPI{} },
214+
"call": func() Task { return &CallFunction{} },
215+
"do": func() Task { return &DoTask{} },
216+
"fork": func() Task { return &ForkTask{} },
217+
"emit": func() Task { return &EmitTask{} },
218+
"for": func() Task { return &ForTask{} },
219+
"listen": func() Task { return &ListenTask{} },
220+
"raise": func() Task { return &RaiseTask{} },
221+
"run": func() Task { return &RunTask{} },
222+
"set": func() Task { return &SetTask{} },
223+
"switch": func() Task { return &SwitchTask{} },
224+
"try": func() Task { return &TryTask{} },
225+
"wait": func() Task { return &WaitTask{} },
226+
}
227+
228+
for taskType, constructor := range builtInTasks {
229+
if err := defaultRegistry.RegisterTask(taskType, constructor); err != nil {
230+
panic(fmt.Sprintf("failed to register built-in task type '%s': %v", taskType, err))
231+
}
232+
}
138233
}
139234

140235
func unmarshalTask(key string, taskRaw json.RawMessage) (Task, error) {
@@ -150,27 +245,32 @@ func unmarshalTask(key string, taskRaw json.RawMessage) (Task, error) {
150245
if callValue, hasCall := taskType["call"].(string); hasCall {
151246
// Form composite key and check if it's in the registry
152247
registryKey := fmt.Sprintf("call_%s", callValue)
153-
if constructor, exists := taskTypeRegistry[registryKey]; exists {
248+
if constructor, exists := defaultRegistry.GetConstructor(registryKey); exists {
154249
task = constructor()
155250
} else {
156251
// Default to CallFunction for unrecognized call values
157-
task = &CallFunction{}
252+
if constructor, exists := defaultRegistry.GetConstructor("call"); exists {
253+
task = constructor()
254+
}
158255
}
159256
} else if _, hasFor := taskType["for"]; hasFor {
160257
// Handle special case "for" that also has "do"
161-
task = taskTypeRegistry["for"]()
258+
if constructor, exists := defaultRegistry.GetConstructor("for"); exists {
259+
task = constructor()
260+
}
162261
} else {
163262
// Handle everything else (e.g., "do", "fork")
164263
for typeKey := range taskType {
165-
if constructor, exists := taskTypeRegistry[typeKey]; exists {
264+
if constructor, exists := defaultRegistry.GetConstructor(typeKey); exists {
166265
task = constructor()
167266
break
168267
}
169268
}
170269
}
171270

172271
if task == nil {
173-
return nil, fmt.Errorf("unknown task type for key '%s'", key)
272+
return nil, fmt.Errorf("unknown task type for key '%s'. Available types: %v",
273+
key, defaultRegistry.ListRegisteredTypes())
174274
}
175275

176276
// Populate the task with raw data

0 commit comments

Comments
 (0)