@@ -18,8 +18,92 @@ import (
1818 "encoding/json"
1919 "errors"
2020 "fmt"
21+ "sync"
2122)
2223
24+ // TaskConstructor is a function that creates a new instance of a task type
25+ type TaskConstructor func () Task
26+
27+ // TaskRegistry manages the registration of task types for JSON unmarshaling
28+ type TaskRegistry struct {
29+ mu sync.RWMutex
30+ constructors map [string ]TaskConstructor
31+ }
32+
33+ // NewTaskRegistry creates a new task registry
34+ func NewTaskRegistry () * TaskRegistry {
35+ return & TaskRegistry {
36+ constructors : make (map [string ]TaskConstructor ),
37+ }
38+ }
39+
40+ // RegisterTask registers a custom task type with a constructor function
41+ func (r * TaskRegistry ) RegisterTask (taskType string , constructor TaskConstructor ) error {
42+
43+ if len (taskType ) == 0 {
44+ return fmt .Errorf ("task type cannot be empty" )
45+ }
46+
47+ if constructor == nil {
48+ return fmt .Errorf ("constructor function cannot be nil" )
49+ }
50+
51+ r .mu .Lock ()
52+ defer r .mu .Unlock ()
53+
54+ if _ , exists := r .constructors [taskType ]; exists {
55+ return fmt .Errorf ("task type '%s' is already registered" , taskType )
56+ }
57+
58+ r .constructors [taskType ] = constructor
59+ return nil
60+ }
61+
62+ // GetConstructor returns the constructor function for a given task type
63+ func (r * TaskRegistry ) GetConstructor (taskType string ) (TaskConstructor , bool ) {
64+ r .mu .RLock ()
65+ defer r .mu .RUnlock ()
66+ constructor , exists := r .constructors [taskType ]
67+ return constructor , exists
68+ }
69+
70+ // UnregisterTask removes a task type from the registry (mainly for testing)
71+ func (r * TaskRegistry ) UnregisterTask (taskType string ) {
72+ r .mu .Lock ()
73+ defer r .mu .Unlock ()
74+ delete (r .constructors , taskType )
75+ }
76+
77+ // ListRegisteredTypes returns all registered task types
78+ func (r * TaskRegistry ) ListRegisteredTypes () []string {
79+ r .mu .RLock ()
80+ defer r .mu .RUnlock ()
81+
82+ types := make ([]string , 0 , len (r .constructors ))
83+ for taskType := range r .constructors {
84+ types = append (types , taskType )
85+ }
86+ return types
87+ }
88+
89+ // Global task registry instance
90+ var defaultRegistry = NewTaskRegistry ()
91+
92+ // RegisterTask registers a custom task type with the global registry
93+ func RegisterTask (taskType string , constructor TaskConstructor ) error {
94+ return defaultRegistry .RegisterTask (taskType , constructor )
95+ }
96+
97+ // GetTaskConstructor returns the constructor function for a given task type from the global registry
98+ func GetTaskConstructor (taskType string ) (TaskConstructor , bool ) {
99+ return defaultRegistry .GetConstructor (taskType )
100+ }
101+
102+ // ListRegisteredTaskTypes returns all registered task types from the global registry
103+ func ListRegisteredTaskTypes () []string {
104+ return defaultRegistry .ListRegisteredTypes ()
105+ }
106+
23107type TaskBase struct {
24108 // A runtime expression, if any, used to determine whether or not the task should be run.
25109 If * RuntimeExpression `json:"if,omitempty" validate:"omitempty"`
@@ -118,23 +202,34 @@ func (tl *TaskList) UnmarshalJSON(data []byte) error {
118202 return nil
119203}
120204
121- var taskTypeRegistry = map [string ]func () Task {
122- "call_http" : func () Task { return & CallHTTP {} },
123- "call_openapi" : func () Task { return & CallOpenAPI {} },
124- "call_grpc" : func () Task { return & CallGRPC {} },
125- "call_asyncapi" : func () Task { return & CallAsyncAPI {} },
126- "call" : func () Task { return & CallFunction {} },
127- "do" : func () Task { return & DoTask {} },
128- "fork" : func () Task { return & ForkTask {} },
129- "emit" : func () Task { return & EmitTask {} },
130- "for" : func () Task { return & ForTask {} },
131- "listen" : func () Task { return & ListenTask {} },
132- "raise" : func () Task { return & RaiseTask {} },
133- "run" : func () Task { return & RunTask {} },
134- "set" : func () Task { return & SetTask {} },
135- "switch" : func () Task { return & SwitchTask {} },
136- "try" : func () Task { return & TryTask {} },
137- "wait" : func () Task { return & WaitTask {} },
205+ // Initialize built-in task types with the registry
206+ func init () {
207+
208+ // Register all built-in task types
209+ builtInTasks := map [string ]TaskConstructor {
210+ "call_http" : func () Task { return & CallHTTP {} },
211+ "call_openapi" : func () Task { return & CallOpenAPI {} },
212+ "call_grpc" : func () Task { return & CallGRPC {} },
213+ "call_asyncapi" : func () Task { return & CallAsyncAPI {} },
214+ "call" : func () Task { return & CallFunction {} },
215+ "do" : func () Task { return & DoTask {} },
216+ "fork" : func () Task { return & ForkTask {} },
217+ "emit" : func () Task { return & EmitTask {} },
218+ "for" : func () Task { return & ForTask {} },
219+ "listen" : func () Task { return & ListenTask {} },
220+ "raise" : func () Task { return & RaiseTask {} },
221+ "run" : func () Task { return & RunTask {} },
222+ "set" : func () Task { return & SetTask {} },
223+ "switch" : func () Task { return & SwitchTask {} },
224+ "try" : func () Task { return & TryTask {} },
225+ "wait" : func () Task { return & WaitTask {} },
226+ }
227+
228+ for taskType , constructor := range builtInTasks {
229+ if err := defaultRegistry .RegisterTask (taskType , constructor ); err != nil {
230+ panic (fmt .Sprintf ("failed to register built-in task type '%s': %v" , taskType , err ))
231+ }
232+ }
138233}
139234
140235func unmarshalTask (key string , taskRaw json.RawMessage ) (Task , error ) {
@@ -150,27 +245,32 @@ func unmarshalTask(key string, taskRaw json.RawMessage) (Task, error) {
150245 if callValue , hasCall := taskType ["call" ].(string ); hasCall {
151246 // Form composite key and check if it's in the registry
152247 registryKey := fmt .Sprintf ("call_%s" , callValue )
153- if constructor , exists := taskTypeRegistry [ registryKey ] ; exists {
248+ if constructor , exists := defaultRegistry . GetConstructor ( registryKey ) ; exists {
154249 task = constructor ()
155250 } else {
156251 // Default to CallFunction for unrecognized call values
157- task = & CallFunction {}
252+ if constructor , exists := defaultRegistry .GetConstructor ("call" ); exists {
253+ task = constructor ()
254+ }
158255 }
159256 } else if _ , hasFor := taskType ["for" ]; hasFor {
160257 // Handle special case "for" that also has "do"
161- task = taskTypeRegistry ["for" ]()
258+ if constructor , exists := defaultRegistry .GetConstructor ("for" ); exists {
259+ task = constructor ()
260+ }
162261 } else {
163262 // Handle everything else (e.g., "do", "fork")
164263 for typeKey := range taskType {
165- if constructor , exists := taskTypeRegistry [ typeKey ] ; exists {
264+ if constructor , exists := defaultRegistry . GetConstructor ( typeKey ) ; exists {
166265 task = constructor ()
167266 break
168267 }
169268 }
170269 }
171270
172271 if task == nil {
173- return nil , fmt .Errorf ("unknown task type for key '%s'" , key )
272+ return nil , fmt .Errorf ("unknown task type for key '%s'. Available types: %v" ,
273+ key , defaultRegistry .ListRegisteredTypes ())
174274 }
175275
176276 // Populate the task with raw data
0 commit comments