Skip to content

Commit f0da027

Browse files
committed
use structural comparison in topological order for rebuilt computation graph equality tests
1 parent 7ed87b2 commit f0da027

File tree

3 files changed

+106
-17
lines changed

3 files changed

+106
-17
lines changed

graph_builder.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
"""
2-
Generated by ChatGPT
3-
"""
41
from typing import Dict, List, Tuple
52

63
import torch
@@ -147,9 +144,9 @@ def forward(
147144
self,
148145
loss: torch.Tensor,
149146
prev_loss: torch.Tensor,
150-
named_params: List[Tuple[str, torch.Tensor]],
147+
named_parameters: List[Tuple[str, torch.Tensor]],
151148
) -> Dict[str, torch.Tensor]:
152-
params = [p for _, p in named_params]
149+
params = [p for _, p in named_parameters]
153150
all_inputs = [loss, prev_loss] + params
154151
features = torch.stack(all_inputs, 0)
155152

@@ -198,8 +195,6 @@ def rebuild_and_script(graph_dict, config, key) -> DynamicOptimizerModule:
198195

199196
# --- build a Python module and script it ---
200197
if genome.connections:
201-
module = DynamicOptimizerModule(
202-
genome, config.input_keys, config.output_keys, graph_dict
203-
)
198+
module = DynamicOptimizerModule(genome, config.input_keys, config.output_keys, graph_dict)
204199
return torch.jit.script(module)
205200
return None

population.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ def evaluate_optimizer(self, optimizer, model, task, steps=10):
292292
task: The task on which to evaluate the optimizer.
293293
steps: Number of update iterations.
294294
"""
295-
# TODO: find way to correct for time improvements that are solely due to RAM cache tiers
295+
# TODO: clear all levels of RAM caches in between every run to create fair starting point
296+
# for comparison
296297
tracemalloc.start()
297298
start = time.perf_counter()
298299
prev_metrics_values = torch.tensor([0.0] * len(task.metrics))

tests/test_graph_builder.py

Lines changed: 101 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import neat
77
import pytest
88
import torch
9+
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
910

1011
# allow imports from repo root
1112
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
@@ -27,6 +28,104 @@ def make_config():
2728
)
2829

2930

31+
def get_node_signature(node):
32+
# simple signature includes kind (operator name), types of inputs, and output type
33+
# TODO: for robust comparison, also need to compare attributes and potentially canonicalize constant values
34+
input_kinds = [inp.node().kind() for inp in node.inputs()]
35+
36+
attributes = {}
37+
if node.kind() == "prim::Constant":
38+
if node.hasAttribute("value"):
39+
attributes["value"] = node.t("value")
40+
elif node.hasAttribute("i"):
41+
attributes["value"] = node.i("i")
42+
elif node.hasAttribute("f"):
43+
attributes["value"] = node.f("f")
44+
# TODO: finish
45+
46+
return (node.kind(), tuple(input_kinds), node.output().type(), tuple(sorted(attributes.items())))
47+
48+
49+
def compare_jit_graphs_structural(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
50+
original_inputs = list(original.graph.inputs())
51+
rebuilt_inputs = list(rebuilt.graph.inputs())
52+
original_outputs = list(original.graph.outputs())
53+
rebuilt_outputs = list(rebuilt.graph.outputs())
54+
if len(original_inputs) != len(rebuilt_inputs) or len(original_outputs) != len(rebuilt_outputs):
55+
print(
56+
f"Input/output counts differ: original.graph inputs={len(original_inputs)}, outputs={len(original_outputs)} vs rebuilt inputs={len(rebuilt_inputs)}, outputs={len(rebuilt_outputs)}",
57+
file=sys.stderr,
58+
)
59+
return False
60+
61+
# default iterator for graph.nodes() is typically a topological sort
62+
original_nodes = list(original.graph.nodes())
63+
rebuilt_nodes = list(rebuilt.graph.nodes())
64+
65+
if len(original_nodes) != len(rebuilt_nodes):
66+
print(
67+
f"Number of nodes differ: original.graph has {len(original_nodes)} nodes, rebuilt has {len(rebuilt_nodes)} nodes",
68+
file=sys.stderr,
69+
)
70+
return False
71+
72+
# create mapping from nodes to canonical representation based on signature + inputs
73+
original_node_map = {}
74+
rebuilt_node_map = {}
75+
for i, (original_node, rebuilt_node) in enumerate(zip(original_nodes, rebuilt_nodes)):
76+
signature1 = get_node_signature(original_node)
77+
signature2 = get_node_signature(rebuilt_node)
78+
79+
if signature1 != signature2:
80+
print(f"Signatures differ at node {i}:", file=sys.stderr)
81+
print(f" original.graph Node Kind: {original_node.kind()}", file=sys.stderr)
82+
print(f" rebuilt Node Kind: {rebuilt_node.kind()}", file=sys.stderr)
83+
# TODO: add more detailed diffing here
84+
return False
85+
86+
# assumes a consistent order of inputs and that corresponding inputs have corresponding nodes
87+
for input_idx, (original_input_val, rebeuilt_input_val) in enumerate(
88+
zip(original_node.inputs(), rebuilt_node.inputs())
89+
):
90+
if original_input_val.node().kind() != rebeuilt_input_val.node().kind():
91+
print(f"Input kind differs for node {i}, input {input_idx}", file=sys.stderr)
92+
return False
93+
# TODO: need to further compare value properties if they are constants or recursively
94+
# check if the input nodes themselves are structurally equivalent up to that point
95+
96+
original_params = dict(original.named_parameters())
97+
rebuilt_params = dict(rebuilt.named_parameters())
98+
if len(original_params) != len(rebuilt_params):
99+
print("Parameter counts differ", file=sys.stderr)
100+
return False
101+
for name, original_param in original_params.items():
102+
if name not in rebuilt_params:
103+
print(f"Parameter '{name}' missing in rebuilt graph", file=sys.stderr)
104+
return False
105+
rebuilt_param = rebuilt_params[name]
106+
if not torch.equal(original_param, rebuilt_param):
107+
print(f"Parameter '{name}' values differ", file=sys.stderr)
108+
return False
109+
110+
if not compare_custom_data(original, rebuilt):
111+
print("Custom data attributes differ", file=sys.stderr)
112+
return False
113+
114+
return True
115+
116+
117+
def compare_custom_data(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
118+
if hasattr(original, "node_types") and hasattr(rebuilt, "node_types"):
119+
if original.node_types != rebuilt.node_types:
120+
print("node_types differ", file=sys.stderr)
121+
return False
122+
if hasattr(original, "edge_index") and hasattr(rebuilt, "edge_index"):
123+
if not torch.equal(original.edge_index, rebuilt.edge_index):
124+
print("edge_index differ", file=sys.stderr)
125+
return False
126+
return True
127+
128+
30129
@pytest.mark.parametrize("pt_path", glob.glob(os.path.join("computation_graphs", "optimizers", "*.pt")))
31130
def test_graph_builder_rebuilds_pt(pt_path):
32131
original = torch.jit.load(pt_path)
@@ -51,11 +150,5 @@ def test_graph_builder_rebuilds_pt(pt_path):
51150
assert len(list(rebuilt.parameters())) == len(expected_edges)
52151
assert len(rebuilt.node_types) == len(data.node_types)
53152

54-
# Verify that the rebuilt computation graph is identical to the original
55-
if str(rebuilt.graph) != str(original.graph):
56-
print("Original graph:\n", original.graph)
57-
print("Rebuilt graph:\n", rebuilt.graph)
58-
assert str(rebuilt.graph) == str(original.graph), (
59-
"\nOriginal graph:\n" + str(original.graph) +
60-
"\nRebuilt graph:\n" + str(rebuilt.graph)
61-
)
153+
# Verify that the rebuilt computation graph is structurally identical to the original
154+
assert compare_jit_graphs_structural(rebuilt, original)

0 commit comments

Comments
 (0)