Skip to content

Commit 54a60c7

Browse files
committed
Fix GraphViz visualization to handle FIRRTL arrays and vectors
Problem: visualize() crashed with "syntax error near ']'" when rendering circuits with FIRRTL array types like `reg zs : UInt<8>[3]`. Root cause: Array element accesses (zs[0], zs[1], products[2]) were being converted to DOT identifiers (zs_0, zs_1, products_2) but never declared as nodes. GraphViz requires all referenced nodes to be declared before use. Fix: - Collect all referenced nodes from connections (including array elements) - Categorize nodes by type (input/output/reg/wire) matching base names - Declare all referenced nodes in DOT output - Add wire nodes for internal signals - Skip self-loops and verify both endpoints exist before creating edges - Add debug output to show generated DOT on error This allows FirFilter and other array-heavy circuits to visualize correctly.
1 parent 2c266d8 commit 54a60c7

File tree

1 file changed

+92
-22
lines changed

1 file changed

+92
-22
lines changed

source/load-ivy.sc

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,53 @@ def visualize(gen: () => chisel3.RawModule): Unit = {
137137
// Parse FIRRTL to extract structure
138138
val lines = firrtlString.split("\n")
139139
val moduleName = lines.find(_.trim.startsWith("module ")).map(_.trim.split(" ")(1).replace(":", "")).getOrElse("Module")
140-
val inputs = lines.filter(_.trim.startsWith("input ")).map(l => sanitizeName(l.trim.split(" ")(1).split(":")(0)))
141-
val outputs = lines.filter(_.trim.startsWith("output ")).map(l => sanitizeName(l.trim.split(" ")(1).split(":")(0)))
142-
val regs = lines.filter(_.trim.startsWith("reg ")).map(l => sanitizeName(l.trim.split(" ")(1).split(":")(0)))
143-
val wires = lines.filter(_.trim.startsWith("wire ")).map(l => sanitizeName(l.trim.split(" ")(1).split(":")(0)))
140+
141+
// Collect all referenced nodes from connections (this includes array elements)
142+
val referencedNodes = scala.collection.mutable.Set[String]()
143+
lines.filter(_.contains("<=")).foreach { line =>
144+
val parts = line.trim.split("<=").map(_.trim)
145+
if (parts.length == 2) {
146+
val targetRaw = parts(0).split("\\.")(0).split("\\(")(0).trim
147+
val sourceRaw = parts(1).split("\\.")(0).split("\\(")(0).split(" ")(0).trim
148+
referencedNodes += sanitizeName(targetRaw)
149+
referencedNodes += sanitizeName(sourceRaw)
150+
}
151+
}
152+
153+
// Extract declared nodes (base names without array indices)
154+
val inputs = lines.filter(_.trim.startsWith("input ")).map { l =>
155+
val name = l.trim.split(" ")(1).split(":")(0).trim
156+
sanitizeName(name)
157+
}.toSet
158+
159+
val outputs = lines.filter(_.trim.startsWith("output ")).map { l =>
160+
val name = l.trim.split(" ")(1).split(":")(0).trim
161+
sanitizeName(name)
162+
}.toSet
163+
164+
val regs = lines.filter(_.trim.startsWith("reg ")).map { l =>
165+
val name = l.trim.split(" ")(1).split(":")(0).trim
166+
sanitizeName(name)
167+
}.toSet
168+
169+
val wires = lines.filter(_.trim.startsWith("wire ")).map { l =>
170+
val name = l.trim.split(" ")(1).split(":")(0).trim
171+
sanitizeName(name)
172+
}.toSet
173+
174+
// Categorize all referenced nodes
175+
val inputNodes = referencedNodes.filter { node =>
176+
inputs.contains(node) || inputs.exists(i => node.startsWith(i + "_"))
177+
}
178+
val outputNodes = referencedNodes.filter { node =>
179+
outputs.contains(node) || outputs.exists(o => node.startsWith(o + "_"))
180+
}
181+
val regNodes = referencedNodes.filter { node =>
182+
regs.contains(node) || regs.exists(r => node.startsWith(r + "_"))
183+
}
184+
val wireNodes = referencedNodes.filter { node =>
185+
wires.contains(node) || wires.exists(w => node.startsWith(w + "_"))
186+
} -- inputNodes -- outputNodes -- regNodes
144187

145188
// Generate GraphViz DOT
146189
val dot = new StringBuilder
@@ -149,39 +192,58 @@ def visualize(gen: () => chisel3.RawModule): Unit = {
149192
dot ++= " node [shape=box, style=rounded];\n\n"
150193

151194
// Input nodes
152-
dot ++= " subgraph cluster_inputs {\n"
153-
dot ++= " label=\"Inputs\";\n"
154-
dot ++= " style=filled; color=lightblue;\n"
155-
inputs.foreach(i => dot ++= s" $i [shape=circle, fillcolor=lightgreen, style=filled];\n")
156-
dot ++= " }\n\n"
195+
if (inputNodes.nonEmpty) {
196+
dot ++= " subgraph cluster_inputs {\n"
197+
dot ++= " label=\"Inputs\";\n"
198+
dot ++= " style=filled; color=lightblue;\n"
199+
inputNodes.foreach(i => dot ++= s" $i [shape=circle, fillcolor=lightgreen, style=filled];\n")
200+
dot ++= " }\n\n"
201+
}
157202

158203
// Output nodes
159-
dot ++= " subgraph cluster_outputs {\n"
160-
dot ++= " label=\"Outputs\";\n"
161-
dot ++= " style=filled; color=lightblue;\n"
162-
outputs.foreach(o => dot ++= s" $o [shape=doublecircle, fillcolor=lightcoral, style=filled];\n")
163-
dot ++= " }\n\n"
204+
if (outputNodes.nonEmpty) {
205+
dot ++= " subgraph cluster_outputs {\n"
206+
dot ++= " label=\"Outputs\";\n"
207+
dot ++= " style=filled; color=lightblue;\n"
208+
outputNodes.foreach(o => dot ++= s" $o [shape=doublecircle, fillcolor=lightcoral, style=filled];\n")
209+
dot ++= " }\n\n"
210+
}
164211

165212
// Register nodes
166-
if (regs.nonEmpty) {
213+
if (regNodes.nonEmpty) {
167214
dot ++= " subgraph cluster_regs {\n"
168215
dot ++= " label=\"Registers\";\n"
169216
dot ++= " style=filled; color=lightyellow;\n"
170-
regs.foreach { r =>
217+
regNodes.foreach { r =>
171218
dot ++= s" $r [shape=box, fillcolor=yellow, style=filled];\n"
172219
}
173220
dot ++= " }\n\n"
174221
}
175222

223+
// Wire nodes (internal signals)
224+
if (wireNodes.nonEmpty) {
225+
wireNodes.foreach { w =>
226+
dot ++= s" $w [shape=box, style=filled, fillcolor=lightgray];\n"
227+
}
228+
dot ++= "\n"
229+
}
230+
176231
// Parse connections from FIRRTL
177232
lines.filter(l => l.contains("<=") && !l.trim.startsWith("reset")).foreach { line =>
178233
val parts = line.trim.split("<=").map(_.trim)
179234
if (parts.length == 2) {
180-
val targetRaw = parts(0).split("\\.")(0).split("\\(")(0)
181-
val sourceRaw = parts(1).split("\\.")(0).split("\\(")(0).split(" ")(0)
182-
if (!sourceRaw.startsWith("UInt") && !sourceRaw.contains("\"") && !sourceRaw.startsWith("_")) {
183-
val target = sanitizeName(targetRaw)
184-
val source = sanitizeName(sourceRaw)
235+
// Extract base name, handling array indices and field accesses
236+
val targetRaw = parts(0).split("\\.")(0).split("\\(")(0).trim
237+
val sourceRaw = parts(1).split("\\.")(0).split("\\(")(0).split(" ")(0).trim
238+
239+
// Sanitize array indices in both target and source
240+
val target = sanitizeName(targetRaw)
241+
val source = sanitizeName(sourceRaw)
242+
243+
// Skip constants, temporaries, self-loops, and invalid nodes
244+
if (!sourceRaw.startsWith("UInt") && !sourceRaw.contains("\"") &&
245+
!sourceRaw.startsWith("_") && source != target &&
246+
referencedNodes.contains(source) && referencedNodes.contains(target)) {
185247
dot ++= s" $source -> $target;\n"
186248
}
187249
}
@@ -193,12 +255,20 @@ def visualize(gen: () => chisel3.RawModule): Unit = {
193255
val dotFile = File.createTempFile("circuit", ".dot")
194256
val svgFile = File.createTempFile("circuit", ".svg")
195257
val pw = new PrintWriter(dotFile)
196-
pw.write(dot.toString)
258+
val dotContent = dot.toString
259+
pw.write(dotContent)
197260
pw.close()
198261

199262
// Generate SVG using graphviz
200263
val result = s"dot -Tsvg ${dotFile.getAbsolutePath} -o ${svgFile.getAbsolutePath}".!
201264

265+
if (result != 0) {
266+
// Print DOT content for debugging
267+
println("=== Generated DOT (debug) ===")
268+
println(dotContent)
269+
println("=== End DOT ===")
270+
}
271+
202272
if (result == 0 && svgFile.exists()) {
203273
val svgContent = scala.io.Source.fromFile(svgFile).mkString
204274
Display.html(svgContent)

0 commit comments

Comments
 (0)