Edit on GitHub

pytact.oracle_server

  1from collections import defaultdict
  2from dataclasses import dataclass
  3from pathlib import Path
  4import sys
  5import socket
  6import socketserver
  7import argparse
  8import contextlib
  9from typing import Union, Tuple
 10from pytact.data_reader import (data_reader, Original, capnp_message_generator, ProofState,
 11                                TacticPredictionGraph, TacticPredictionsGraph,
 12                                TacticPredictionText, TacticPredictionsText,
 13                                GlobalContextMessage, CheckAlignmentMessage, CheckAlignmentResponse)
 14
 15@dataclass(eq=True, frozen=True)
 16class GlobalArgument:
 17    identity: int
 18@dataclass(eq=True, frozen=True)
 19class LocalArgument:
 20    context_index: int
 21@dataclass(eq=True, frozen=True)
 22class OracleTactic:
 23    tactic_id: int
 24    arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
 25    clean: bool
 26
 27def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
 28    prediction_requests = context.prediction_requests
 29    for msg in prediction_requests:
 30        if isinstance(msg, ProofState):
 31            proof_state = msg
 32            if proof_state.text in text_oracle_data:
 33                preds = [TacticPredictionText(t, 1) for t in text_oracle_data[proof_state.text]]
 34            else:
 35                preds = []
 36            prediction_requests.send(TacticPredictionsText(preds))
 37        elif isinstance(msg, CheckAlignmentMessage):
 38            alignment = CheckAlignmentResponse([], [])
 39            prediction_requests.send(alignment)
 40        elif isinstance(msg, GlobalContextMessage):
 41            text_prediction_loop(text_oracle_data, msg)
 42        else:
 43            raise Exception("Capnp protocol error")
 44
 45def graph_prediction_loop(context: GlobalContextMessage, oracle_data, known_definitions, known_tactics):
 46    available_tacticids = set([ t.ident for t in context.tactics ])
 47    available_definitions = { d.node.identity : d.node for d in context.definitions.definitions() }
 48    prediction_requests = context.prediction_requests
 49    for msg in prediction_requests:
 50        if isinstance(msg, ProofState):
 51            proof_state = msg
 52            def resolve_arg(arg):
 53                if isinstance(arg, LocalArgument):
 54                    return proof_state.context[arg.context_index]
 55                elif isinstance(arg, GlobalArgument) and arg.identity in available_definitions:
 56                    return available_definitions[arg.identity]
 57                else:
 58                    return None
 59            possible_tactics = [
 60                TacticPredictionGraph(t.tactic_id,
 61                                    [resolve_arg(arg) for arg in t.arguments],
 62                                    1 if t.clean else 0.95)
 63                for t in sorted(oracle_data[proof_state.root.identity], key = lambda t: not t.clean)
 64                if t.tactic_id in available_tacticids and
 65                all([resolve_arg(arg) is not None for arg in t.arguments])]
 66            prediction_requests.send(TacticPredictionsGraph(possible_tactics))
 67        elif isinstance(msg, CheckAlignmentMessage):
 68            unknown_definitions = [ d for d in context.definitions.definitions()
 69                                    if d.node.identity not in known_definitions ]
 70            unknown_tactics = [ t.ident for t in context.tactics
 71                                if t.ident not in known_tactics ]
 72            alignment = CheckAlignmentResponse(unknown_definitions, unknown_tactics)
 73            prediction_requests.send(alignment)
 74        elif isinstance(msg, GlobalContextMessage):
 75            graph_prediction_loop(msg, oracle_data, known_definitions, known_tactics)
 76        else:
 77            raise Exception("Capnp protocol error")
 78
 79def run_session(oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
 80    messages_generator = capnp_message_generator(capnp_socket, record_file)
 81    if args.mode == 'text':
 82        print('Python server running in text mode')
 83        text_prediction_loop(text_oracle_data, messages_generator)
 84    elif args.mode == 'graph':
 85        print('Python server running in graph mode')
 86        graph_prediction_loop(messages_generator, oracle_data, known_definitions, known_tactics)
 87    else:
 88        raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")
 89
 90def main():
 91    sys.setrecursionlimit(10000)
 92    parser = argparse.ArgumentParser(
 93        description = 'A tactic prediction server acting as an oracle, retrieving it\'s information from a dataset',
 94        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 95
 96    parser.add_argument('mode',
 97                        type=str,
 98                        choices=['graph', 'text'],
 99                        help='"graph" to communicate in graph-mode, "text" to communicate in text-mode')
100    parser.add_argument('dataset',
101                        type=str,
102                        help=('The location of the dataset from which to extract the oracle information. ' +
103                              'Either a dataset directory, or a SquashFS image, ' +
104                              'which will be automatically mounted.'))
105    parser.add_argument('--tcp',
106                        dest='port',
107                        type = int,
108                        default = None,
109                        help='Run in tcp mode instead of stdin mode on the specified port.')
110    parser.add_argument('--record',
111                        dest="record_file",
112                        type = str,
113                        default = None,
114                        help='Record all exchanged messages to the specified file, so that they can later be ' +
115                        'replayed through "pytact-fake-coq"')
116    cmd_args = parser.parse_args()
117
118    print("Building oracle data...")
119    dataset_path = Path(cmd_args.dataset).resolve()
120    oracle_data = defaultdict(set)
121    text_oracle_data = defaultdict(set)
122    known_definitions = set()
123    known_tactics = set()
124    with data_reader(dataset_path) as data:
125        for datafile in data.values():
126            for d in datafile.definitions():
127                known_definitions.add(d.node.identity)
128                if proof := d.proof:
129                    for step in proof:
130                        for outcome in step.outcomes:
131                            if outcome.tactic is None:
132                                continue # If the tactic is unknown we are screwed
133                            known_tactics.add(outcome.tactic.ident)
134                            if not isinstance(d.status, Original):
135                                continue # For an oracle, we are not interested in non-original proofs
136                            if len(outcome.after) == 1:
137                                if outcome.before.id == outcome.after[0].id:
138                                    continue # This tactic didn't do anything, we can ignore it
139                                if outcome.before.root.identity == outcome.after[0].root.identity:
140                                    # This tactic did something, but very minimally, usually just an identity cast
141                                    continue
142                            text_oracle_data[outcome.before.text].add(outcome.tactic.text_non_anonymous)
143                            tactic_args = outcome.tactic_arguments
144                            if any(arg is None for arg in tactic_args):
145                                continue # If an argument is unknown we are screwed
146                            args = []
147                            for arg in tactic_args:
148                                if arg_def := arg.definition:
149                                    args.append(GlobalArgument(arg_def.node.identity))
150                                else:
151                                    args.append(LocalArgument(list(outcome.before.context).index(arg)))
152                            oracle_tactic = OracleTactic(outcome.tactic.ident, tuple(args),
153                                                         outcome.tactic.text == outcome.tactic.interm_text)
154                            oracle_data[outcome.before.root.identity].add(oracle_tactic)
155    print("Oracle data built, ready for incoming connections")
156
157    if cmd_args.record_file is not None:
158        record_context = open(cmd_args.record_file, 'wb')
159    else:
160        record_context = contextlib.nullcontext()
161    with record_context as record_file:
162        if cmd_args.port is not None:
163            class Handler(socketserver.BaseRequestHandler):
164                def handle(self):
165                    run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
166                                cmd_args, self.request, record_file)
167            class Server(socketserver.ForkingTCPServer):
168                def __init__(self, *kwargs):
169                    self.allow_reuse_address = True
170                    self.daemon_threads = True
171                    super().__init__(*kwargs)
172            addr = ('localhost', cmd_args.port)
173            with Server(addr, Handler) as server:
174                server.serve_forever()
175        else:
176            capnp_socket = socket.socket(fileno=sys.stdin.fileno())
177            run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
178                        cmd_args, capnp_socket, record_file)
179
# Run the server only when this module is executed as a script.
if __name__ == '__main__':
    main()
@dataclass(eq=True, frozen=True)
class GlobalArgument:
@dataclass(frozen=True, eq=True)
class GlobalArgument:
    """Tactic argument that refers to a definition by its node identity.

    The instance is immutable and compares/hashes by value, so it can be
    stored inside sets and used as part of other hashable records.
    """
    # Identity of the definition node this argument points at.
    identity: int
GlobalArgument(identity: int)
identity: int
@dataclass(eq=True, frozen=True)
class LocalArgument:
@dataclass(frozen=True, eq=True)
class LocalArgument:
    """Tactic argument that refers to an entry of the proof state's local context.

    The instance is immutable and compares/hashes by value.
    """
    # Position of the referenced entry within the proof state's context.
    context_index: int
LocalArgument(context_index: int)
context_index: int
@dataclass(eq=True, frozen=True)
class OracleTactic:
@dataclass(frozen=True, eq=True)
class OracleTactic:
    """A tactic recorded from the dataset, together with its arguments.

    The instance is immutable and compares/hashes by value, which lets
    equal records collapse when they are collected in a set.
    """
    # Identifier of the tactic.
    tactic_id: int
    # Arguments of the tactic, each either a global or a local reference.
    arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
    # Flag used by the graph prediction loop to rank the tactic
    # (clean tactics are listed first and get the higher confidence).
    clean: bool
OracleTactic( tactic_id: int, arguments: Tuple[Union[GlobalArgument, LocalArgument], ...], clean: bool)
tactic_id: int
arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
clean: bool
def text_prediction_loop(text_oracle_data, context: pytact.data_reader.GlobalContextMessage):
def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
    """Answer the text-mode requests arriving on one global context.

    Every message read from ``context.prediction_requests`` is answered on
    that same object:

    * ``ProofState``: a ``TacticPredictionsText`` is sent containing one
      ``TacticPredictionText`` (with value 1) for each entry stored in
      ``text_oracle_data`` under the proof state's text; the list is empty
      when the text is not a key of ``text_oracle_data``.
    * ``CheckAlignmentMessage``: a ``CheckAlignmentResponse`` with two empty
      lists is sent.
    * ``GlobalContextMessage``: the function recurses on the nested context.

    Raises:
        Exception: if a message of any other type is received.
    """
    requests = context.prediction_requests
    for message in requests:
        if isinstance(message, ProofState):
            # A membership test is used rather than plain indexing, so a
            # missing text never touches the mapping.
            if message.text in text_oracle_data:
                tactic_texts = text_oracle_data[message.text]
            else:
                tactic_texts = ()
            predictions = [TacticPredictionText(tactic_text, 1) for tactic_text in tactic_texts]
            requests.send(TacticPredictionsText(predictions))
            continue
        if isinstance(message, CheckAlignmentMessage):
            requests.send(CheckAlignmentResponse([], []))
            continue
        if isinstance(message, GlobalContextMessage):
            # Nested contexts are served by a nested invocation of this loop.
            text_prediction_loop(text_oracle_data, message)
            continue
        raise Exception("Capnp protocol error")
def graph_prediction_loop( context: pytact.data_reader.GlobalContextMessage, oracle_data, known_definitions, known_tactics):
def graph_prediction_loop(context: GlobalContextMessage, oracle_data, known_definitions, known_tactics):
    """Answer the graph-mode requests arriving on one global context.

    Every message read from ``context.prediction_requests`` is answered on
    that same object:

    * ``ProofState``: the tactics stored in ``oracle_data`` under the
      identity of the proof state's root are looked up. A tactic is kept
      only if its id is among the tactics of ``context`` and every one of
      its arguments resolves to a non-``None`` value. Kept tactics are sent
      as a ``TacticPredictionsGraph``, clean tactics first, with confidence
      1 for clean tactics and 0.95 otherwise.
    * ``CheckAlignmentMessage``: a ``CheckAlignmentResponse`` is sent
      listing the definitions of ``context`` whose node identity is not in
      ``known_definitions`` and the tactic idents of ``context`` that are
      not in ``known_tactics``.
    * ``GlobalContextMessage``: the function recurses on the nested context.

    Args:
        context: The global context whose requests are served.
        oracle_data: Mapping from proof-state root identity to a collection
            of ``OracleTactic`` records.
        known_definitions: Container of definition identities known to the
            oracle.
        known_tactics: Container of tactic idents known to the oracle.

    Raises:
        Exception: if a message of any other type is received.
    """
    available_tacticids = {t.ident for t in context.tactics}
    available_definitions = {d.node.identity: d.node for d in context.definitions.definitions()}
    prediction_requests = context.prediction_requests

    def resolve_arg(arg, proof_state):
        # Local arguments index into the context of the current proof state;
        # global arguments are looked up among this context's definitions.
        if isinstance(arg, LocalArgument):
            return proof_state.context[arg.context_index]
        if isinstance(arg, GlobalArgument):
            return available_definitions.get(arg.identity)
        return None

    for msg in prediction_requests:
        if isinstance(msg, ProofState):
            # .get() instead of indexing: main() passes a defaultdict, and
            # indexing it with an unseen identity would insert an empty set
            # for every unknown proof state, growing the table without bound.
            candidates = oracle_data.get(msg.root.identity, ())
            possible_tactics = []
            # False sorts before True, so clean tactics come first.
            for t in sorted(candidates, key=lambda t: not t.clean):
                if t.tactic_id not in available_tacticids:
                    continue
                # Resolve each argument exactly once and reuse the result
                # both for the filter and for the prediction itself.
                resolved = [resolve_arg(arg, msg) for arg in t.arguments]
                if any(r is None for r in resolved):
                    continue
                possible_tactics.append(
                    TacticPredictionGraph(t.tactic_id, resolved, 1 if t.clean else 0.95))
            prediction_requests.send(TacticPredictionsGraph(possible_tactics))
        elif isinstance(msg, CheckAlignmentMessage):
            unknown_definitions = [d for d in context.definitions.definitions()
                                   if d.node.identity not in known_definitions]
            unknown_tactics = [t.ident for t in context.tactics
                               if t.ident not in known_tactics]
            alignment = CheckAlignmentResponse(unknown_definitions, unknown_tactics)
            prediction_requests.send(alignment)
        elif isinstance(msg, GlobalContextMessage):
            # Nested contexts are served by a nested invocation of this loop.
            graph_prediction_loop(msg, oracle_data, known_definitions, known_tactics)
        else:
            raise Exception("Capnp protocol error")
def run_session( oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
def run_session(oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
    """Serve one client connection in the mode selected by ``args.mode``.

    A message generator is created from ``capnp_socket`` and ``record_file``
    and handed to ``text_prediction_loop`` (mode ``'text'``) or to
    ``graph_prediction_loop`` (mode ``'graph'``).

    Raises:
        Exception: if ``args.mode`` is neither ``'text'`` nor ``'graph'``.
    """
    messages_generator = capnp_message_generator(capnp_socket, record_file)
    mode = args.mode
    if mode == 'text':
        print('Python server running in text mode')
        text_prediction_loop(text_oracle_data, messages_generator)
        return
    if mode == 'graph':
        print('Python server running in graph mode')
        graph_prediction_loop(messages_generator, oracle_data, known_definitions, known_tactics)
        return
    raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")
def main():
 91def main():
 92    sys.setrecursionlimit(10000)
 93    parser = argparse.ArgumentParser(
 94        description = 'A tactic prediction server acting as an oracle, retrieving it\'s information from a dataset',
 95        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 96
 97    parser.add_argument('mode',
 98                        type=str,
 99                        choices=['graph', 'text'],
100                        help='"graph" to communicate in graph-mode, "text" to communicate in text-mode')
101    parser.add_argument('dataset',
102                        type=str,
103                        help=('The location of the dataset from which to extract the oracle information. ' +
104                              'Either a dataset directory, or a SquashFS image, ' +
105                              'which will be automatically mounted.'))
106    parser.add_argument('--tcp',
107                        dest='port',
108                        type = int,
109                        default = None,
110                        help='Run in tcp mode instead of stdin mode on the specified port.')
111    parser.add_argument('--record',
112                        dest="record_file",
113                        type = str,
114                        default = None,
115                        help='Record all exchanged messages to the specified file, so that they can later be ' +
116                        'replayed through "pytact-fake-coq"')
117    cmd_args = parser.parse_args()
118
119    print("Building oracle data...")
120    dataset_path = Path(cmd_args.dataset).resolve()
121    oracle_data = defaultdict(set)
122    text_oracle_data = defaultdict(set)
123    known_definitions = set()
124    known_tactics = set()
125    with data_reader(dataset_path) as data:
126        for datafile in data.values():
127            for d in datafile.definitions():
128                known_definitions.add(d.node.identity)
129                if proof := d.proof:
130                    for step in proof:
131                        for outcome in step.outcomes:
132                            if outcome.tactic is None:
133                                continue # If the tactic is unknown we are screwed
134                            known_tactics.add(outcome.tactic.ident)
135                            if not isinstance(d.status, Original):
136                                continue # For an oracle, we are not interested in non-original proofs
137                            if len(outcome.after) == 1:
138                                if outcome.before.id == outcome.after[0].id:
139                                    continue # This tactic didn't do anything, we can ignore it
140                                if outcome.before.root.identity == outcome.after[0].root.identity:
141                                    # This tactic did something, but very minimally, usually just an identity cast
142                                    continue
143                            text_oracle_data[outcome.before.text].add(outcome.tactic.text_non_anonymous)
144                            tactic_args = outcome.tactic_arguments
145                            if any(arg is None for arg in tactic_args):
146                                continue # If an argument is unknown we are screwed
147                            args = []
148                            for arg in tactic_args:
149                                if arg_def := arg.definition:
150                                    args.append(GlobalArgument(arg_def.node.identity))
151                                else:
152                                    args.append(LocalArgument(list(outcome.before.context).index(arg)))
153                            oracle_tactic = OracleTactic(outcome.tactic.ident, tuple(args),
154                                                         outcome.tactic.text == outcome.tactic.interm_text)
155                            oracle_data[outcome.before.root.identity].add(oracle_tactic)
156    print("Oracle data built, ready for incoming connections")
157
158    if cmd_args.record_file is not None:
159        record_context = open(cmd_args.record_file, 'wb')
160    else:
161        record_context = contextlib.nullcontext()
162    with record_context as record_file:
163        if cmd_args.port is not None:
164            class Handler(socketserver.BaseRequestHandler):
165                def handle(self):
166                    run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
167                                cmd_args, self.request, record_file)
168            class Server(socketserver.ForkingTCPServer):
169                def __init__(self, *kwargs):
170                    self.allow_reuse_address = True
171                    self.daemon_threads = True
172                    super().__init__(*kwargs)
173            addr = ('localhost', cmd_args.port)
174            with Server(addr, Handler) as server:
175                server.serve_forever()
176        else:
177            capnp_socket = socket.socket(fileno=sys.stdin.fileno())
178            run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
179                        cmd_args, capnp_socket, record_file)