pytact.graph_visualize
import inflection
import graphviz
from pytact.data_reader import ProofState, Node

# Load the cap'n proto library, and the communication specification in 'graph_api.capnp'
import capnp
import pytact.graph_api_capnp as graph_api_capnp

# TODO: Clean up and unify all the functions below

# Arrow-tail shapes handed out, in order, to the edge sorts of each group in
# `graph_api_capnp.groupedEdges`, so sorts within one group look different.
arrow_heads = [ "dot", "inv", "odot", "invdot", "invodot" ]
edge_arrow_map = {}
for group in graph_api_capnp.groupedEdges:
    count = 0
    for sort in group.conflatable:
        edge_arrow_map[sort] = arrow_heads[count]
        count += 1

# Same mapping, keyed by the raw integer value of each edge enum.
edge_arrow_map2 = {e.raw : arrow for (e, arrow) in edge_arrow_map.items()}

def visualize_proof_state(state: ProofState, filename='python_graph', max_nodes=100):
    """Render the graph reachable from ``state.root`` with graphviz.

    Nodes that carry a definition are labelled with the definition's name and
    are not expanded. All other nodes are labelled with their camel-cased
    label kind, and their children are followed. At most ``max_nodes`` nodes
    are drawn; every further node is replaced by a placeholder labelled
    'truncated'.

    :param state: proof state whose ``root`` node starts the traversal.
    :param filename: output path passed to ``graphviz.Digraph.render``.
    :param max_nodes: maximum number of real nodes to draw.
    """
    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")

    # str(node) of every visited node -> graphviz id it was drawn under.
    # Revisiting a truncated node therefore reuses its placeholder instead
    # of producing an unlabelled, auto-created node.
    rendered = {}
    nodes_left = max_nodes

    def recurse(node: Node, depth):
        nonlocal nodes_left

        node_id = str(node)
        if node_id in rendered:
            return rendered[node_id]
        nodes_left -= 1
        if nodes_left < 0:
            # Budget exhausted: draw a unique placeholder instead of the node.
            trunc_id = 'trunc' + str(nodes_left)
            dot.node(trunc_id, 'truncated')
            rendered[node_id] = trunc_id
            return trunc_id
        # Register before recursing so that cycles terminate.
        rendered[node_id] = node_id

        if d := node.definition:
            label = d.name
        else:
            label = inflection.camelize(str(node.label.which.name.lower()))
        dot.node(node_id, label=label)
        if d:
            # Definitions consume depth; starting from 0 they become leaves.
            depth -= 1
        if depth >= 0:
            for edge, child in node.children:
                child_id = recurse(child, depth)
                dot.edge(node_id, child_id, arrowtail=edge_arrow_map2[edge], dir="both")
        return node_id

    recurse(state.root, 0)

    dot.render(filename=filename, view=False, cleanup=True)

def visualize(graph, state, showLabel = False, graph1 = None,
              filename='python_graph', cleanup=True):
    """Render a graph message, highlighting a proof state's root and context.

    Every node is labelled with its label kind. Context nodes are shown as
    ``"<index>: [<kind>]"``, and the root node is prefixed with ``"Root: "``.

    :param graph: graph with ``nodes`` (each having ``label``,
        ``childrenIndex`` and ``childrenCount``) and a flat ``edges`` list.
    :param state: provides ``root`` and ``context`` node indices.
    :param showLabel: if true, draw ``str(edge.label)`` on every edge.
    :param graph1: optional second graph. Edges whose target has
        ``depIndex == 1`` point to a "Global:" node named after the
        definition at that index in ``graph1``.
    :param filename: output path passed to ``graphviz.Digraph.render``.
    :param cleanup: forwarded to ``render``.
    """
    nodes = graph.nodes
    root = state.root
    # A set gives O(1) membership tests in the node loop below.
    context = set(state.context)
    assert all(n < len(nodes) for n in context)

    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        label = value.label.which()
        if node in context:
            label = str(node) + ': [' + label + ']'
        if node == root:
            label = "Root: " + label
        dot.node(str(node), label)

    # Build the edge list once instead of once per node.
    edges = list(graph.edges)
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            if graph1 is not None and edge.target.depIndex == 1:
                target = '1#' + str(edge.target.nodeIndex)
                dot.node(target, "Global:"
                         + str(graph1.nodes[edge.target.nodeIndex].label.definition.name))
            else:
                target = str(edge.target.nodeIndex)
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), target, label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    dot.render(filename=filename, view=False, cleanup=cleanup)

def visualize_defs(graph, defs, showLabel=False, filename='python_grapn', cleanup=True):
    """Render a graph message, labelling the given definition nodes by name.

    :param graph: graph with ``nodes`` and a flat ``edges`` list.
    :param defs: node indices to label as ``"<index>: <definition name>"``.
    :param showLabel: if true, draw ``str(edge.label)`` on every edge.
    :param filename: output path passed to ``graphviz.Digraph.render``.
        NOTE(review): the default 'python_grapn' looks like a typo of
        'python_graph'; it is kept unchanged for compatibility.
    :param cleanup: forwarded to ``render``.
    """
    nodes = graph.nodes
    # A set gives O(1) membership tests in the node loop below.
    defs = set(defs)
    assert all(n < len(nodes) for n in defs)
    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        label = value.label.which()
        if node in defs:
            label = str(node) + ': ' + value.label.definition.name
        dot.node(str(node), label)

    # Build the edge list once instead of once per node.
    edges = list(graph.edges)
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), str(edge.target.nodeIndex), label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    # Honour the caller's `cleanup` argument (it was previously ignored).
    dot.render(filename=filename, view=False, cleanup=cleanup)

def visualize_exception(reason, filename='visualize_graph.pdf', cleanup=None):
    """Render a single-node graph whose id and label are ``str(reason)``.

    NOTE(review): graphviz appends the output format's extension when
    rendering, so the default filename likely produces
    'visualize_graph.pdf.pdf'. Confirm before relying on the name.
    """
    dot = graphviz.Digraph()
    dot.node(str(reason), str(reason))
    dot.render(filename, view=False, cleanup=cleanup)


class Visualizer:
    """Render successive results, optionally to numbered files.

    :param filename: base output path.
    :param count: if truthy, append a running counter to ``filename`` on
        every render.
    :param show_labels: forwarded as ``showLabel`` to ``visualize``.
    :param cleanup: forwarded to the render functions.
    """
    def __init__(self, filename, count, show_labels, cleanup):
        self.filename = filename
        self.cnt = 0  # number of numbered renders so far
        self.count = count
        self.show_labels = show_labels
        self.cleanup = cleanup

    def _visualize(self, filename, result):
        """Draw ``result`` to ``filename``: a graph for 'newState', otherwise a single node."""
        if result.which() == 'newState':
            visualize(result.newState.graph, result.newState.state,
                      filename=filename, showLabel=self.show_labels, cleanup=self.cleanup)
        else:
            # Use the per-call (possibly numbered) filename so that non-state
            # results do not all overwrite the same base file.
            visualize_exception(result, filename=filename, cleanup=self.cleanup)

    def render(self, result):
        """Render ``result``, numbering the output file when ``count`` is set."""
        filename = self.filename
        if self.count:
            filename += str(self.cnt)
            self.cnt += 1

        self._visualize(filename, result)
arrow_heads =
['dot', 'inv', 'odot', 'invdot', 'invodot']
edge_arrow_map =
{<contextElem enum>: 'dot', <contextSubject enum>: 'inv', <contextDefType enum>: 'dot', <contextDefTerm enum>: 'inv', <constType enum>: 'dot', <constUndef enum>: 'inv', <constDef enum>: 'odot', <constOpaqueDef enum>: 'invdot', <constPrimitive enum>: 'invodot', <indType enum>: 'dot', <indConstruct enum>: 'inv', <indProjection enum>: 'odot', <projTerm enum>: 'dot', <constructTerm enum>: 'dot', <castType enum>: 'dot', <castTerm enum>: 'inv', <prodType enum>: 'dot', <prodTerm enum>: 'inv', <lambdaType enum>: 'dot', <lambdaTerm enum>: 'inv', <letInDef enum>: 'dot', <letInTerm enum>: 'inv', <letInType enum>: 'odot', <appFun enum>: 'dot', <appArg enum>: 'inv', <caseTerm enum>: 'dot', <caseReturn enum>: 'inv', <caseBranchPointer enum>: 'odot', <caseInd enum>: 'invdot', <cBConstruct enum>: 'dot', <cBTerm enum>: 'inv', <fixMutual enum>: 'dot', <fixReturn enum>: 'inv', <fixFunType enum>: 'dot', <fixFunTerm enum>: 'inv', <coFixMutual enum>: 'dot', <coFixReturn enum>: 'inv', <coFixFunType enum>: 'dot', <coFixFunTerm enum>: 'inv', <relPointer enum>: 'dot', <evarSubject enum>: 'dot', <evarSubstPointer enum>: 'inv', <evarSubstTerm enum>: 'dot', <evarSubstTarget enum>: 'inv'}
edge_arrow_map2 =
{0: 'dot', 1: 'inv', 2: 'dot', 3: 'inv', 4: 'dot', 5: 'inv', 6: 'odot', 7: 'invdot', 8: 'invodot', 9: 'dot', 10: 'inv', 11: 'odot', 12: 'dot', 13: 'dot', 15: 'dot', 14: 'inv', 16: 'dot', 17: 'inv', 18: 'dot', 19: 'inv', 20: 'dot', 22: 'inv', 21: 'odot', 23: 'dot', 24: 'inv', 25: 'dot', 26: 'inv', 27: 'odot', 28: 'invdot', 29: 'dot', 30: 'inv', 31: 'dot', 32: 'inv', 33: 'dot', 34: 'inv', 35: 'dot', 36: 'inv', 37: 'dot', 38: 'inv', 39: 'dot', 43: 'dot', 40: 'inv', 41: 'dot', 42: 'inv'}
def visualize_proof_state(state: ProofState, filename='python_graph', max_nodes=100):
    """Render the graph reachable from ``state.root`` with graphviz.

    Nodes that carry a definition are labelled with the definition's name and
    are not expanded. All other nodes are labelled with their camel-cased
    label kind, and their children are followed. At most ``max_nodes`` nodes
    are drawn; every further node is replaced by a placeholder labelled
    'truncated'.

    :param state: proof state whose ``root`` node starts the traversal.
    :param filename: output path passed to ``graphviz.Digraph.render``.
    :param max_nodes: maximum number of real nodes to draw.
    """
    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")

    # str(node) of every visited node -> graphviz id it was drawn under.
    # Revisiting a truncated node therefore reuses its placeholder instead
    # of producing an unlabelled, auto-created node.
    rendered = {}
    nodes_left = max_nodes

    def recurse(node: Node, depth):
        nonlocal nodes_left

        node_id = str(node)
        if node_id in rendered:
            return rendered[node_id]
        nodes_left -= 1
        if nodes_left < 0:
            # Budget exhausted: draw a unique placeholder instead of the node.
            trunc_id = 'trunc' + str(nodes_left)
            dot.node(trunc_id, 'truncated')
            rendered[node_id] = trunc_id
            return trunc_id
        # Register before recursing so that cycles terminate.
        rendered[node_id] = node_id

        if d := node.definition:
            label = d.name
        else:
            label = inflection.camelize(str(node.label.which.name.lower()))
        dot.node(node_id, label=label)
        if d:
            # Definitions consume depth; starting from 0 they become leaves.
            depth -= 1
        if depth >= 0:
            for edge, child in node.children:
                child_id = recurse(child, depth)
                dot.edge(node_id, child_id, arrowtail=edge_arrow_map2[edge], dir="both")
        return node_id

    recurse(state.root, 0)

    dot.render(filename=filename, view=False, cleanup=True)
def
visualize(graph, state, showLabel=False, graph1=None, filename='python_graph', cleanup=True):
def visualize(graph, state, showLabel = False, graph1 = None,
              filename='python_graph', cleanup=True):
    """Render a graph message, highlighting a proof state's root and context.

    Every node is labelled with its label kind. Context nodes are shown as
    ``"<index>: [<kind>]"``, and the root node is prefixed with ``"Root: "``.

    :param graph: graph with ``nodes`` (each having ``label``,
        ``childrenIndex`` and ``childrenCount``) and a flat ``edges`` list.
    :param state: provides ``root`` and ``context`` node indices.
    :param showLabel: if true, draw ``str(edge.label)`` on every edge.
    :param graph1: optional second graph. Edges whose target has
        ``depIndex == 1`` point to a "Global:" node named after the
        definition at that index in ``graph1``.
    :param filename: output path passed to ``graphviz.Digraph.render``.
    :param cleanup: forwarded to ``render``.
    """
    nodes = graph.nodes
    root = state.root
    # A set gives O(1) membership tests in the node loop below.
    context = set(state.context)
    assert all(n < len(nodes) for n in context)

    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        label = value.label.which()
        if node in context:
            label = str(node) + ': [' + label + ']'
        if node == root:
            label = "Root: " + label
        dot.node(str(node), label)

    # Build the edge list once instead of once per node.
    edges = list(graph.edges)
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            if graph1 is not None and edge.target.depIndex == 1:
                target = '1#' + str(edge.target.nodeIndex)
                dot.node(target, "Global:"
                         + str(graph1.nodes[edge.target.nodeIndex].label.definition.name))
            else:
                target = str(edge.target.nodeIndex)
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), target, label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    dot.render(filename=filename, view=False, cleanup=cleanup)
def
visualize_defs(graph, defs, showLabel=False, filename='python_grapn', cleanup=True):
def visualize_defs(graph, defs, showLabel=False, filename='python_grapn', cleanup=True):
    """Render a graph message, labelling the given definition nodes by name.

    :param graph: graph with ``nodes`` and a flat ``edges`` list.
    :param defs: node indices to label as ``"<index>: <definition name>"``.
    :param showLabel: if true, draw ``str(edge.label)`` on every edge.
    :param filename: output path passed to ``graphviz.Digraph.render``.
        NOTE(review): the default 'python_grapn' looks like a typo of
        'python_graph'; it is kept unchanged for compatibility.
    :param cleanup: forwarded to ``render``.
    """
    nodes = graph.nodes
    # A set gives O(1) membership tests in the node loop below.
    defs = set(defs)
    assert all(n < len(nodes) for n in defs)
    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        label = value.label.which()
        if node in defs:
            label = str(node) + ': ' + value.label.definition.name
        dot.node(str(node), label)

    # Build the edge list once instead of once per node.
    edges = list(graph.edges)
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), str(edge.target.nodeIndex), label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    # Honour the caller's `cleanup` argument (it was previously ignored).
    dot.render(filename=filename, view=False, cleanup=cleanup)
def
visualize_exception(reason, filename='visualize_graph.pdf', cleanup=None):
class
Visualizer:
class Visualizer:
    """Render successive results, optionally to numbered files.

    :param filename: base output path.
    :param count: if truthy, append a running counter to ``filename`` on
        every render.
    :param show_labels: forwarded as ``showLabel`` to ``visualize``.
    :param cleanup: forwarded to the render functions.
    """
    def __init__(self, filename, count, show_labels, cleanup):
        self.filename = filename
        self.cnt = 0  # number of numbered renders so far
        self.count = count
        self.show_labels = show_labels
        self.cleanup = cleanup

    def _visualize(self, filename, result):
        """Draw ``result`` to ``filename``: a graph for 'newState', otherwise a single node."""
        if result.which() == 'newState':
            visualize(result.newState.graph, result.newState.state,
                      filename=filename, showLabel=self.show_labels, cleanup=self.cleanup)
        else:
            # Use the per-call (possibly numbered) filename so that non-state
            # results do not all overwrite the same base file.
            visualize_exception(result, filename=filename, cleanup=self.cleanup)

    def render(self, result):
        """Render ``result``, numbering the output file when ``count`` is set."""
        filename = self.filename
        if self.count:
            filename += str(self.cnt)
            self.cnt += 1

        self._visualize(filename, result)