Edit on GitHub

pytact.graph_visualize

  1import inflection
  2import graphviz
  3from pytact.data_reader import ProofState, Node
  4
  5# Load the cap'n proto library, and the communication specification in 'graph_api.capnp'
  6import capnp
  7import pytact.graph_api_capnp as graph_api_capnp
  8
  9# TODO: Clean up and unify all the functions below
 10
 11arrow_heads = [ "dot", "inv", "odot", "invdot", "invodot" ]
 12edge_arrow_map = {}
 13for group in graph_api_capnp.groupedEdges:
 14    count = 0
 15    for sort in group.conflatable:
 16        edge_arrow_map[sort] = arrow_heads[count]
 17        count += 1
 18
 19edge_arrow_map2 = {e.raw : arrow for (e, arrow) in edge_arrow_map.items()}
 20
 21def visualize_proof_state(state: ProofState):
 22
 23    dot = graphviz.Digraph()
 24    dot.attr('graph', ordering="out")
 25
 26    seen = set()
 27    nodes_left = 100
 28
 29    def recurse(node: Node, depth):
 30        nonlocal seen
 31        nonlocal nodes_left
 32
 33        id = str(node)
 34        if id in seen:
 35            return id
 36        seen.add(id)
 37        nodes_left -= 1
 38        if nodes_left < 0:
 39            id = 'trunc' + str(nodes_left)
 40            dot.node(id, 'truncated')
 41            return id
 42
 43        if d := node.definition:
 44            label = d.name
 45        else:
 46            label = inflection.camelize(str(node.label.which.name.lower()))
 47        dot.node(id, label=label)
 48        if node.definition:
 49            depth -= 1
 50        if depth >= 0:
 51            for edge, child in node.children:
 52                cid = recurse(child, depth)
 53                dot.edge(id, cid, arrowtail=edge_arrow_map2[edge], dir="both")
 54        return id
 55    recurse(state.root, 0)
 56
 57    dot.render(filename='python_graph', view=False, cleanup=True)
 58
 59def visualize(graph, state, showLabel = False, graph1 = None,
 60              filename='python_graph', cleanup=True):
 61    nodes = graph.nodes
 62    root = state.root
 63    context = state.context
 64    assert all(n < len(nodes) for n in context)
 65
 66    dot = graphviz.Digraph()
 67    dot.attr('graph', ordering="out")
 68    for node, value in enumerate(nodes):
 69        label = value.label.which()
 70        if node in context: label = str(node) + ': ['+label+']'
 71        if node == root:
 72            label = "Root: "+label
 73        dot.node(str(node), label)
 74    for node, value in enumerate(nodes):
 75        for edge in list(graph.edges)[value.childrenIndex:value.childrenIndex+value.childrenCount]:
 76            if graph1 != None and edge.target.depIndex == 1:
 77                target = '1#'+str(edge.target.nodeIndex)
 78                dot.node(target, "Global:"
 79                         + str(graph1.nodes[edge.target.nodeIndex].label.definition.name))
 80            else:
 81                target = str(edge.target.nodeIndex)
 82            if showLabel:
 83                label = str(edge.label)
 84            else:
 85                label = ""
 86            dot.edge(str(node), target, label=label,
 87                     arrowtail=edge_arrow_map[edge.label], dir="both")
 88
 89    dot.render(filename=filename, view=False, cleanup=cleanup)
 90
 91def visualize_defs(graph, defs, showLabel=False, filename='python_grapn', cleanup=True):
 92    nodes = graph.nodes
 93    assert all(n < len(nodes) for n in defs)
 94    dot = graphviz.Digraph()
 95    dot.attr('graph', ordering="out")
 96    for node, value in enumerate(nodes):
 97        label = value.label.which()
 98        if node in defs: label = str(node) + ': ' + value.label.definition.name
 99        dot.node(str(node), label)
100    for node, value in enumerate(nodes):
101        for edge in list(graph.edges)[value.childrenIndex:value.childrenIndex+value.childrenCount]:
102            if showLabel:
103                label = str(edge.label)
104            else:
105                label = ""
106            dot.edge(str(node), str(edge.target.nodeIndex), label=label,
107                     arrowtail=edge_arrow_map[edge.label], dir="both")
108
109    dot.render(filename=filename, view=False, cleanup=True)
110
def visualize_exception(reason, filename='visualize_graph.pdf', cleanup=None):
    """Render a Graphviz graph consisting of one node that shows ``str(reason)``.

    The node's id and its label are both ``str(reason)``.

    :param reason: Object describing the failure; only its ``str`` is used.
    :param filename: Passed positionally to ``graphviz.Digraph.render``.
    :param cleanup: Forwarded to ``graphviz.Digraph.render``.
    """
    # NOTE(review): graphviz appends the output format's extension when
    # rendering, so this default may produce 'visualize_graph.pdf.pdf'.
    # Confirm before relying on the name.
    text = str(reason)
    graph = graphviz.Digraph()
    graph.node(text, text)
    graph.render(filename, view=False, cleanup=cleanup)
115
116
class Visualizer:
    """Render a sequence of results to Graphviz output files.

    A result whose ``which()`` is ``'newState'`` is drawn with ``visualize``.
    Any other result is drawn with ``visualize_exception``.
    """

    def __init__(self, filename, count, show_labels, cleanup):
        """Store rendering options.

        :param filename: Base output file name.
        :param count: If truthy, each render appends an increasing counter,
            starting at 0, to ``filename``.
        :param show_labels: Forwarded as ``showLabel`` to ``visualize``.
        :param cleanup: Forwarded to the rendering functions.
        """
        self.filename = filename  # base output file name
        self.cnt = 0  # next counter suffix; only advanced when count is truthy
        self.count = count  # whether to suffix filename with cnt on each render
        self.show_labels = show_labels  # forwarded to visualize as showLabel
        self.cleanup = cleanup  # forwarded to the rendering functions

    def _visualize(self, filename, result):
        """Draw ``result`` into the output file ``filename``."""
        if result.which() == 'newState':
            visualize(result.newState.graph, result.newState.state,
                      filename=filename, showLabel=self.show_labels, cleanup=self.cleanup)
        else:
            # Use the per-render file name here as well. Previously
            # self.filename was used, so with count enabled every exception
            # overwrote the same un-numbered file.
            visualize_exception(result, filename=filename, cleanup=self.cleanup)

    def render(self, result):
        """Render ``result``, numbering the output file when ``count`` is set."""
        filename = self.filename
        if self.count:
            filename += str(self.cnt)
            self.cnt += 1

        self._visualize(filename, result)
arrow_heads = ['dot', 'inv', 'odot', 'invdot', 'invodot']
edge_arrow_map = {<contextElem enum>: 'dot', <contextSubject enum>: 'inv', <contextDefType enum>: 'dot', <contextDefTerm enum>: 'inv', <constType enum>: 'dot', <constUndef enum>: 'inv', <constDef enum>: 'odot', <constOpaqueDef enum>: 'invdot', <constPrimitive enum>: 'invodot', <indType enum>: 'dot', <indConstruct enum>: 'inv', <indProjection enum>: 'odot', <projTerm enum>: 'dot', <constructTerm enum>: 'dot', <castType enum>: 'dot', <castTerm enum>: 'inv', <prodType enum>: 'dot', <prodTerm enum>: 'inv', <lambdaType enum>: 'dot', <lambdaTerm enum>: 'inv', <letInDef enum>: 'dot', <letInTerm enum>: 'inv', <letInType enum>: 'odot', <appFun enum>: 'dot', <appArg enum>: 'inv', <caseTerm enum>: 'dot', <caseReturn enum>: 'inv', <caseBranchPointer enum>: 'odot', <caseInd enum>: 'invdot', <cBConstruct enum>: 'dot', <cBTerm enum>: 'inv', <fixMutual enum>: 'dot', <fixReturn enum>: 'inv', <fixFunType enum>: 'dot', <fixFunTerm enum>: 'inv', <coFixMutual enum>: 'dot', <coFixReturn enum>: 'inv', <coFixFunType enum>: 'dot', <coFixFunTerm enum>: 'inv', <relPointer enum>: 'dot', <evarSubject enum>: 'dot', <evarSubstPointer enum>: 'inv', <evarSubstTerm enum>: 'dot', <evarSubstTarget enum>: 'inv'}
edge_arrow_map2 = {0: 'dot', 1: 'inv', 2: 'dot', 3: 'inv', 4: 'dot', 5: 'inv', 6: 'odot', 7: 'invdot', 8: 'invodot', 9: 'dot', 10: 'inv', 11: 'odot', 12: 'dot', 13: 'dot', 15: 'dot', 14: 'inv', 16: 'dot', 17: 'inv', 18: 'dot', 19: 'inv', 20: 'dot', 22: 'inv', 21: 'odot', 23: 'dot', 24: 'inv', 25: 'dot', 26: 'inv', 27: 'odot', 28: 'invdot', 29: 'dot', 30: 'inv', 31: 'dot', 32: 'inv', 33: 'dot', 34: 'inv', 35: 'dot', 36: 'inv', 37: 'dot', 38: 'inv', 39: 'dot', 43: 'dot', 40: 'inv', 41: 'dot', 42: 'inv'}
def visualize_proof_state(state: ProofState, filename='python_graph', cleanup=True,
                          max_nodes=100):
    """Render the graph reachable from ``state.root`` with Graphviz.

    Nodes are visited depth-first from the root. A node that has a definition
    is labelled with the definition's name, and its children are not visited.
    Any other node is labelled with its camel-cased label kind, and all of its
    children are visited. Each edge's arrow tail is looked up in
    ``edge_arrow_map2``. Once ``max_nodes`` distinct nodes have been drawn,
    every further node is replaced by a placeholder labelled ``'truncated'``.

    :param state: Proof state whose graph is drawn, starting at ``state.root``.
    :param filename: Output file name passed to ``graphviz.Digraph.render``.
    :param cleanup: Forwarded to ``graphviz.Digraph.render``.
    :param max_nodes: Number of distinct nodes drawn before truncation starts.
    """
    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")

    # Maps str(node) to the id of the Graphviz node drawn for it. Storing the
    # id, rather than only a "seen" flag, means a node reached again after it
    # was truncated links to its 'truncated' placeholder. Otherwise it would
    # link to an id that was never declared, which creates a stray node.
    drawn = {}
    nodes_left = max_nodes

    def recurse(node: Node, depth):
        nonlocal nodes_left

        key = str(node)
        if key in drawn:
            return drawn[key]
        nodes_left -= 1
        if nodes_left < 0:
            node_id = 'trunc' + str(nodes_left)
            dot.node(node_id, 'truncated')
            drawn[key] = node_id
            return node_id

        # Register the node before visiting its children so that cycles terminate.
        node_id = key
        drawn[key] = node_id
        if d := node.definition:
            label = d.name
        else:
            label = inflection.camelize(str(node.label.which.name.lower()))
        dot.node(node_id, label=label)
        # Definition nodes drop the depth below zero, so their children are skipped.
        if d:
            depth -= 1
        if depth >= 0:
            for edge, child in node.children:
                child_id = recurse(child, depth)
                dot.edge(node_id, child_id, arrowtail=edge_arrow_map2[edge], dir="both")
        return node_id

    recurse(state.root, 0)

    dot.render(filename=filename, view=False, cleanup=cleanup)
def visualize(graph, state, showLabel=False, graph1=None,
              filename='python_graph', cleanup=True):
    """Render a graph together with a proof state using Graphviz.

    Every node of ``graph`` is drawn and labelled with the kind of its label
    (``label.which()``). Nodes listed in ``state.context`` are shown as
    ``'<index>: [<kind>]'``, and the root node is prefixed with ``'Root: '``.
    The outgoing edges of a node are the slice
    ``graph.edges[childrenIndex : childrenIndex + childrenCount]``. Each edge's
    arrow tail is looked up in ``edge_arrow_map``.

    :param graph: Graph whose ``nodes`` and ``edges`` are drawn.
    :param state: Proof state providing the ``root`` index and the ``context``
        indices into ``graph.nodes``.
    :param showLabel: If true, each edge is annotated with ``str(edge.label)``.
    :param graph1: Optional second graph. When given, edge targets with
        ``depIndex == 1`` are drawn as separate nodes labelled ``'Global:'``
        followed by the definition name of that node in ``graph1``.
    :param filename: Output file name passed to ``graphviz.Digraph.render``.
    :param cleanup: Forwarded to ``graphviz.Digraph.render``.
    """
    nodes = graph.nodes
    root = state.root
    # A set makes the per-node membership test O(1). The assert below shows
    # that context entries are integer node indices.
    context = set(state.context)
    assert all(n < len(nodes) for n in context)

    # Materialize the edge list once instead of once per node.
    edges = list(graph.edges)

    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        # str() so the concatenations below work whatever which() returns.
        label = str(value.label.which())
        if node in context:
            label = f'{node}: [{label}]'
        if node == root:
            label = f'Root: {label}'
        dot.node(str(node), label)

    # Global target nodes are declared only once, even if many edges reach them.
    declared_globals = set()
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            target_index = edge.target.nodeIndex
            if graph1 is not None and edge.target.depIndex == 1:
                target = f'1#{target_index}'
                if target not in declared_globals:
                    declared_globals.add(target)
                    dot.node(target, "Global:"
                             + str(graph1.nodes[target_index].label.definition.name))
            else:
                target = str(target_index)
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), target, label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    dot.render(filename=filename, view=False, cleanup=cleanup)
# NOTE(review): the default filename 'python_grapn' looks like a typo for
# 'python_graph' (the other visualizers' default). It is kept unchanged so
# existing callers still find their output file.
def visualize_defs(graph, defs, showLabel=False, filename='python_grapn', cleanup=True):
    """Render a graph with Graphviz, labelling the given nodes by definition name.

    Every node of ``graph`` is drawn. Nodes whose index is in ``defs`` are
    labelled ``'<index>: <definition name>'``. All other nodes are labelled
    with the kind of their label (``label.which()``). The outgoing edges of a
    node are the slice
    ``graph.edges[childrenIndex : childrenIndex + childrenCount]``. Each edge's
    arrow tail is looked up in ``edge_arrow_map``.

    :param graph: Graph whose ``nodes`` and ``edges`` are drawn.
    :param defs: Iterable of indices into ``graph.nodes`` whose nodes are
        labelled with their definition name.
    :param showLabel: If true, each edge is annotated with ``str(edge.label)``.
    :param filename: Output file name passed to ``graphviz.Digraph.render``.
    :param cleanup: Forwarded to ``graphviz.Digraph.render``. The argument was
        previously ignored and ``cleanup=True`` was always used.
    """
    nodes = graph.nodes
    # Build a set once. This gives O(1) membership tests, and a one-shot
    # iterator passed as defs is no longer exhausted by the assert.
    def_indices = set(defs)
    assert all(n < len(nodes) for n in def_indices)

    # Materialize the edge list once instead of once per node.
    edges = list(graph.edges)

    dot = graphviz.Digraph()
    dot.attr('graph', ordering="out")
    for node, value in enumerate(nodes):
        if node in def_indices:
            label = f'{node}: {value.label.definition.name}'
        else:
            label = str(value.label.which())
        dot.node(str(node), label)
    for node, value in enumerate(nodes):
        start = value.childrenIndex
        for edge in edges[start:start + value.childrenCount]:
            label = str(edge.label) if showLabel else ""
            dot.edge(str(node), str(edge.target.nodeIndex), label=label,
                     arrowtail=edge_arrow_map[edge.label], dir="both")

    dot.render(filename=filename, view=False, cleanup=cleanup)
def visualize_exception(reason, filename='visualize_graph.pdf', cleanup=None):
    """Render a Graphviz graph consisting of one node that shows ``str(reason)``.

    The node's id and its label are both ``str(reason)``.

    :param reason: Object describing the failure; only its ``str`` is used.
    :param filename: Passed positionally to ``graphviz.Digraph.render``.
    :param cleanup: Forwarded to ``graphviz.Digraph.render``.
    """
    # NOTE(review): graphviz appends the output format's extension when
    # rendering, so this default may produce 'visualize_graph.pdf.pdf'.
    # Confirm before relying on the name.
    text = str(reason)
    graph = graphviz.Digraph()
    graph.node(text, text)
    graph.render(filename, view=False, cleanup=cleanup)
class Visualizer:
    """Render a sequence of results to Graphviz output files.

    A result whose ``which()`` is ``'newState'`` is drawn with ``visualize``.
    Any other result is drawn with ``visualize_exception``.
    """

    def __init__(self, filename, count, show_labels, cleanup):
        """Store rendering options.

        :param filename: Base output file name.
        :param count: If truthy, each render appends an increasing counter,
            starting at 0, to ``filename``.
        :param show_labels: Forwarded as ``showLabel`` to ``visualize``.
        :param cleanup: Forwarded to the rendering functions.
        """
        self.filename = filename  # base output file name
        self.cnt = 0  # next counter suffix; only advanced when count is truthy
        self.count = count  # whether to suffix filename with cnt on each render
        self.show_labels = show_labels  # forwarded to visualize as showLabel
        self.cleanup = cleanup  # forwarded to the rendering functions

    def _visualize(self, filename, result):
        """Draw ``result`` into the output file ``filename``."""
        if result.which() == 'newState':
            visualize(result.newState.graph, result.newState.state,
                      filename=filename, showLabel=self.show_labels, cleanup=self.cleanup)
        else:
            # Use the per-render file name here as well. Previously
            # self.filename was used, so with count enabled every exception
            # overwrote the same un-numbered file.
            visualize_exception(result, filename=filename, cleanup=self.cleanup)

    def render(self, result):
        """Render ``result``, numbering the output file when ``count`` is set."""
        filename = self.filename
        if self.count:
            filename += str(self.cnt)
            self.cnt += 1

        self._visualize(filename, result)