22import glob
33import inspect
44import json
5- import logging
65import os
76import random
87import sys
8+ import re
99from typing import Dict , List , Any , Callable , Tuple
1010
1111import black
1212
13- from utils import import_custom_nodes , add_comfyui_directory_to_sys_path , get_value_at_index
1413
15- sys . path . append ( '../' )
14+ from utils import import_custom_nodes , find_path , add_comfyui_directory_to_sys_path , add_extra_model_paths , get_value_at_index
1615
16+ sys .path .append ('../' )
1717from nodes import NODE_CLASS_MAPPINGS
1818
1919
20- logging .basicConfig (level = logging .INFO )
21-
22-
2320class FileHandler :
2421 """Handles reading and writing files.
2522
@@ -217,7 +214,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
217214 continue
218215
219216 class_type , import_statement , class_code = self .get_class_info (class_type )
220- initialized_objects [class_type ] = class_type . lower (). strip ( )
217+ initialized_objects [class_type ] = self . clean_variable_name ( class_type )
221218 if class_type in self .base_node_class_mappings .keys ():
222219 import_statements .add (import_statement )
223220 if class_type not in self .base_node_class_mappings .keys ():
@@ -234,9 +231,9 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
234231 inputs ['unique_id' ] = random .randint (1 , 2 ** 64 )
235232
236233 # Create executed variable and generate code
237- executed_variables [idx ] = f'{ class_type . lower (). strip ( )} _{ idx } '
234+ executed_variables [idx ] = f'{ self . clean_variable_name ( class_type )} _{ idx } '
238235 inputs = self .update_inputs (inputs , executed_variables )
239-
236+
240237 if is_special_function :
241238 special_functions_code .append (self .create_function_call_code (initialized_objects [class_type ], class_def .FUNCTION , executed_variables [idx ], is_special_function , ** inputs ))
242239 else :
@@ -306,11 +303,11 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
306303 """
307304 # Get the source code of the utils functions as a string
308305 func_strings = []
309- for func in [add_comfyui_directory_to_sys_path , get_value_at_index ]:
306+ for func in [get_value_at_index , find_path , add_comfyui_directory_to_sys_path , add_extra_model_paths ]:
310307 func_strings .append (f'\n { inspect .getsource (func )} ' )
311308 # Define static import statements required for the script
312309 static_imports = ['import os' , 'import random' , 'import sys' , 'from typing import Sequence, Mapping, Any, Union' ,
313- 'import torch' ] + func_strings + ['\n \n add_comfyui_directory_to_sys_path()' ]
310+ 'import torch' ] + func_strings + ['\n \n add_comfyui_directory_to_sys_path()\n add_extra_model_paths() \n ' ]
314311 # Check if custom nodes should be included
315312 if custom_nodes :
316313 static_imports .append (f'\n { inspect .getsource (import_custom_nodes )} \n ' )
@@ -328,7 +325,7 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
328325 final_code = black .format_str (final_code , mode = black .Mode ())
329326
330327 return final_code
331-
328+
332329 def get_class_info (self , class_type : str ) -> Tuple [str , str , str ]:
333330 """Generates and returns necessary information about class type.
334331
@@ -339,12 +336,36 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
339336 Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
340337 """
341338 import_statement = class_type
339+ variable_name = self .clean_variable_name (class_type )
342340 if class_type in self .base_node_class_mappings .keys ():
343- class_code = f'{ class_type . lower (). strip () } = { class_type .strip ()} ()'
341+ class_code = f'{ variable_name } = { class_type .strip ()} ()'
344342 else :
345- class_code = f'{ class_type . lower (). strip () } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
343+ class_code = f'{ variable_name } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
346344
347345 return class_type , import_statement , class_code
346+
347+ @staticmethod
348+ def clean_variable_name (class_type : str ) -> str :
349+ """
350+ Remove any characters from variable name that could cause errors running the Python script.
351+
352+ Args:
353+ class_type (str): Class type.
354+
355+ Returns:
356+ str: Cleaned variable name with no special characters or spaces
357+ """
358+ # Convert to lowercase and replace spaces with underscores
359+ clean_name = class_type .lower ().strip ().replace ("-" , "_" ).replace (" " , "_" )
360+
361+ # Remove characters that are not letters, numbers, or underscores
362+ clean_name = re .sub (r'[^a-z0-9_]' , '' , clean_name )
363+
364+ # Ensure that it doesn't start with a number
365+ if clean_name [0 ].isdigit ():
366+ clean_name = "_" + clean_name
367+
368+ return clean_name
348369
349370 def get_function_parameters (self , func : Callable ) -> List :
350371 """Get the names of a function's parameters.
0 commit comments