66import neat
77import pytest
88import torch
9+ from torch .fx .passes .utils .matcher_utils import SubgraphMatcher
910
1011# allow imports from repo root
1112sys .path .insert (0 , str (pathlib .Path (__file__ ).resolve ().parents [1 ]))
@@ -27,6 +28,104 @@ def make_config():
2728 )
2829
2930
31+ def get_node_signature (node ):
32+ # simple signature includes kind (operator name), types of inputs, and output type
33+ # TODO: for robust comparison, also need to compare attributes and potentially canonicalize constant values
34+ input_kinds = [inp .node ().kind () for inp in node .inputs ()]
35+
36+ attributes = {}
37+ if node .kind () == "prim::Constant" :
38+ if node .hasAttribute ("value" ):
39+ attributes ["value" ] = node .t ("value" )
40+ elif node .hasAttribute ("i" ):
41+ attributes ["value" ] = node .i ("i" )
42+ elif node .hasAttribute ("f" ):
43+ attributes ["value" ] = node .f ("f" )
44+ # TODO: finish
45+
46+ return (node .kind (), tuple (input_kinds ), node .output ().type (), tuple (sorted (attributes .items ())))
47+
48+
49+ def compare_jit_graphs_structural (original : torch .jit .ScriptModule , rebuilt : torch .jit .ScriptModule ) -> bool :
50+ original_inputs = list (original .graph .inputs ())
51+ rebuilt_inputs = list (rebuilt .graph .inputs ())
52+ original_outputs = list (original .graph .outputs ())
53+ rebuilt_outputs = list (rebuilt .graph .outputs ())
54+ if len (original_inputs ) != len (rebuilt_inputs ) or len (original_outputs ) != len (rebuilt_outputs ):
55+ print (
56+ f"Input/output counts differ: original.graph inputs={ len (original_inputs )} , outputs={ len (original_outputs )} vs rebuilt inputs={ len (rebuilt_inputs )} , outputs={ len (rebuilt_outputs )} " ,
57+ file = sys .stderr ,
58+ )
59+ return False
60+
61+ # default iterator for graph.nodes() is typically a topological sort
62+ original_nodes = list (original .graph .nodes ())
63+ rebuilt_nodes = list (rebuilt .graph .nodes ())
64+
65+ if len (original_nodes ) != len (rebuilt_nodes ):
66+ print (
67+ f"Number of nodes differ: original.graph has { len (original_nodes )} nodes, rebuilt has { len (rebuilt_nodes )} nodes" ,
68+ file = sys .stderr ,
69+ )
70+ return False
71+
72+ # create mapping from nodes to canonical representation based on signature + inputs
73+ original_node_map = {}
74+ rebuilt_node_map = {}
75+ for i , (original_node , rebuilt_node ) in enumerate (zip (original_nodes , rebuilt_nodes )):
76+ signature1 = get_node_signature (original_node )
77+ signature2 = get_node_signature (rebuilt_node )
78+
79+ if signature1 != signature2 :
80+ print (f"Signatures differ at node { i } :" , file = sys .stderr )
81+ print (f" original.graph Node Kind: { original_node .kind ()} " , file = sys .stderr )
82+ print (f" rebuilt Node Kind: { rebuilt_node .kind ()} " , file = sys .stderr )
83+ # TODO: add more detailed diffing here
84+ return False
85+
86+ # assumes a consistent order of inputs and that corresponding inputs have corresponding nodes
87+ for input_idx , (original_input_val , rebeuilt_input_val ) in enumerate (
88+ zip (original_node .inputs (), rebuilt_node .inputs ())
89+ ):
90+ if original_input_val .node ().kind () != rebeuilt_input_val .node ().kind ():
91+ print (f"Input kind differs for node { i } , input { input_idx } " , file = sys .stderr )
92+ return False
93+ # TODO: need to further compare value properties if they are constants or recursively
94+ # check if the input nodes themselves are structurally equivalent up to that point
95+
96+ original_params = dict (original .named_parameters ())
97+ rebuilt_params = dict (rebuilt .named_parameters ())
98+ if len (original_params ) != len (rebuilt_params ):
99+ print ("Parameter counts differ" , file = sys .stderr )
100+ return False
101+ for name , original_param in original_params .items ():
102+ if name not in rebuilt_params :
103+ print (f"Parameter '{ name } ' missing in rebuilt graph" , file = sys .stderr )
104+ return False
105+ rebuilt_param = rebuilt_params [name ]
106+ if not torch .equal (original_param , rebuilt_param ):
107+ print (f"Parameter '{ name } ' values differ" , file = sys .stderr )
108+ return False
109+
110+ if not compare_custom_data (original , rebuilt ):
111+ print ("Custom data attributes differ" , file = sys .stderr )
112+ return False
113+
114+ return True
115+
116+
117+ def compare_custom_data (original : torch .jit .ScriptModule , rebuilt : torch .jit .ScriptModule ) -> bool :
118+ if hasattr (original , "node_types" ) and hasattr (rebuilt , "node_types" ):
119+ if original .node_types != rebuilt .node_types :
120+ print ("node_types differ" , file = sys .stderr )
121+ return False
122+ if hasattr (original , "edge_index" ) and hasattr (rebuilt , "edge_index" ):
123+ if not torch .equal (original .edge_index , rebuilt .edge_index ):
124+ print ("edge_index differ" , file = sys .stderr )
125+ return False
126+ return True
127+
128+
30129@pytest .mark .parametrize ("pt_path" , glob .glob (os .path .join ("computation_graphs" , "optimizers" , "*.pt" )))
31130def test_graph_builder_rebuilds_pt (pt_path ):
32131 original = torch .jit .load (pt_path )
@@ -51,11 +150,5 @@ def test_graph_builder_rebuilds_pt(pt_path):
51150 assert len (list (rebuilt .parameters ())) == len (expected_edges )
52151 assert len (rebuilt .node_types ) == len (data .node_types )
53152
54- # Verify that the rebuilt computation graph is identical to the original
55- if str (rebuilt .graph ) != str (original .graph ):
56- print ("Original graph:\n " , original .graph )
57- print ("Rebuilt graph:\n " , rebuilt .graph )
58- assert str (rebuilt .graph ) == str (original .graph ), (
59- "\n Original graph:\n " + str (original .graph ) +
60- "\n Rebuilt graph:\n " + str (rebuilt .graph )
61- )
153+ # Verify that the rebuilt computation graph is structurally identical to the original
154+ assert compare_jit_graphs_structural (rebuilt , original )
0 commit comments