@@ -45,20 +45,52 @@ def train(
         self,
         G: Graph,
         model_name: str,
-        scoring_function,
-        num_epochs,
-        embedding_dimension,
-        epochs_per_checkpoint,
+        *,
+        num_epochs: int,
+        embedding_dimension: int,
+        epochs_per_checkpoint: Optional[int] = None,
+        load_from_checkpoint: Optional[tuple[str, int]] = None,
+        split_ratios=None,
+        scoring_function: str = "transe",
+        p_norm: float = 1.0,
+        batch_size: int = 512,
+        test_batch_size: int = 512,
+        optimizer: str = "adam",
+        optimizer_kwargs=None,
+        lr_scheduler: str = "ConstantLR",
+        lr_scheduler_kwargs=None,
+        loss_function: str = "MarginRanking",
+        loss_function_kwargs=None,
+        negative_sampling_size: int = 1,
+        use_node_type_aware_sampler: bool = False,
+        k_value: int = 10,
+        do_validation: bool = True,
+        do_test: bool = True,
+        filtered_metrics: bool = False,
+        epochs_per_val: int = 50,
+        inner_norm: bool = True,
+        init_bound: Optional[float] = None,
         mlflow_experiment_name: Optional[str] = None,
     ) -> Series:
-        graph_config = {"name": G.name()}
+        if epochs_per_checkpoint is None:
+            epochs_per_checkpoint = max(num_epochs / 10, 1)
+        if loss_function_kwargs is None:
+            loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0)
+        if lr_scheduler_kwargs is None:
+            lr_scheduler_kwargs = dict(factor=1, total_iters=1000)
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {"lr": 0.01, "weight_decay": 0.0005}
+        if split_ratios is None:
+            split_ratios = {"TRAIN": 0.8, "TEST": 0.2}

         algo_config = {
-            "scoring_function": scoring_function,
-            "num_epochs": num_epochs,
-            "embedding_dimension": embedding_dimension,
-            "epochs_per_checkpoint": epochs_per_checkpoint,
+            key: value
+            for key, value in locals().items()
+            if (key not in ["self", "G", "mlflow_experiment_name", "model_name"]) and (value is not None)
         }
+        print(algo_config)
+
+        graph_config = {"name": G.name()}

         config = {
             "user_name": "DUMMY_USER",
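The `algo_config` dict comprehension relies on `locals()` holding exactly the function's bound parameters at that point, minus the explicitly excluded names; note that any helper variable introduced before the comprehension would also be picked up. A self-contained sketch of the same pattern, with illustrative parameter names:

```python
# Standalone sketch of the locals()-based config collection; the function and
# parameter names are illustrative, not taken from the diff.
from typing import Optional


def build_algo_config(
    num_epochs: int,
    embedding_dimension: int,
    scoring_function: str = "transe",
    batch_size: int = 512,
    optimizer_kwargs: Optional[dict] = None,
) -> dict:
    # At this point locals() holds exactly the bound parameters, so the
    # comprehension copies every non-None argument into the config.
    return {key: value for key, value in locals().items() if value is not None}


print(build_algo_config(num_epochs=100, embedding_dimension=64))
# {'num_epochs': 100, 'embedding_dimension': 64,
#  'scoring_function': 'transe', 'batch_size': 512}
```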