  1from collections import defaultdict
  2from dataclasses import dataclass
  3from pathlib import Path
  4import sys
  5import socket
  6import socketserver
  7import argparse
  8import contextlib
  9from typing import Union, Tuple
 10from pytact.data_reader import (data_reader, Original, capnp_message_generator, ProofState,
 11                                TacticPredictionGraph, TacticPredictionsGraph,
 12                                TacticPredictionText, TacticPredictionsText,
 13                                GlobalContextMessage, CheckAlignmentMessage, CheckAlignmentResponse)
 15@dataclass(eq=True, frozen=True)
 16class GlobalArgument:
 17    identity: int
 18@dataclass(eq=True, frozen=True)
 19class LocalArgument:
 20    context_index: int
 21@dataclass(eq=True, frozen=True)
 22class OracleTactic:
 23    tactic_id: int
 24    arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
 25    clean: bool
 27def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
 28    prediction_requests = context.prediction_requests
 29    for msg in prediction_requests:
 30        if isinstance(msg, ProofState):
 31            proof_state = msg
 32            if proof_state.text in text_oracle_data:
 33                preds = [TacticPredictionText(t, 1) for t in text_oracle_data[proof_state.text]]
 34            else:
 35                preds = []
 36            prediction_requests.send(TacticPredictionsText(preds))
 37        elif isinstance(msg, CheckAlignmentMessage):
 38            alignment = CheckAlignmentResponse([], [])
 39            prediction_requests.send(alignment)
 40        elif isinstance(msg, GlobalContextMessage):
 41            text_prediction_loop(text_oracle_data, msg)
 42        else:
 43            raise Exception("Capnp protocol error")
 45def graph_prediction_loop(context: GlobalContextMessage, oracle_data, known_definitions, known_tactics):
 46    available_tacticids = set([ t.ident for t in context.tactics ])
 47    available_definitions = { d.node.identity : d.node for d in context.definitions.definitions() }
 48    prediction_requests = context.prediction_requests
 49    for msg in prediction_requests:
 50        if isinstance(msg, ProofState):
 51            proof_state = msg
 52            def resolve_arg(arg):
 53                if isinstance(arg, LocalArgument):
 54                    return proof_state.context[arg.context_index]
 55                elif isinstance(arg, GlobalArgument) and arg.identity in available_definitions:
 56                    return available_definitions[arg.identity]
 57                else:
 58                    return None
 59            possible_tactics = [
 60                TacticPredictionGraph(t.tactic_id,
 61                                    [resolve_arg(arg) for arg in t.arguments],
 62                                    1 if t.clean else 0.95)
 63                for t in sorted(oracle_data[proof_state.root.identity], key = lambda t: not t.clean)
 64                if t.tactic_id in available_tacticids and
 65                all([resolve_arg(arg) is not None for arg in t.arguments])]
 66            prediction_requests.send(TacticPredictionsGraph(possible_tactics))
 67        elif isinstance(msg, CheckAlignmentMessage):
 68            unknown_definitions = [ d for d in context.definitions.definitions()
 69                                    if d.node.identity not in known_definitions ]
 70            unknown_tactics = [ t.ident for t in context.tactics
 71                                if t.ident not in known_tactics ]
 72            alignment = CheckAlignmentResponse(unknown_definitions, unknown_tactics)
 73            prediction_requests.send(alignment)
 74        elif isinstance(msg, GlobalContextMessage):
 75            graph_prediction_loop(msg, oracle_data, known_definitions, known_tactics)
 76        else:
 77            raise Exception("Capnp protocol error")
 79def run_session(oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
 80    messages_generator = capnp_message_generator(capnp_socket, record_file)
 81    if args.mode == 'text':
 82        print('Python server running in text mode')
 83        text_prediction_loop(text_oracle_data, messages_generator)
 84    elif args.mode == 'graph':
 85        print('Python server running in graph mode')
 86        graph_prediction_loop(messages_generator, oracle_data, known_definitions, known_tactics)
 87    else:
 88        raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")
 90def main():
 91    sys.setrecursionlimit(10000)
 92    parser = argparse.ArgumentParser(
 93        description = 'A tactic prediction server acting as an oracle, retrieving it\'s information from a dataset',
 94        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 96    parser.add_argument('mode',
 97                        type=str,
 98                        choices=['graph', 'text'],
 99                        help='"graph" to communicate in graph-mode, "text" to communicate in text-mode')
100    parser.add_argument('dataset',
101                        type=str,
102                        help=('The location of the dataset from which to extract the oracle information. ' +
103                              'Either a dataset directory, or a SquashFS image, ' +
104                              'which will be automatically mounted.'))
105    parser.add_argument('--tcp',
106                        dest='port',
107                        type = int,
108                        default = None,
109                        help='Run in tcp mode instead of stdin mode on the specified port.')
110    parser.add_argument('--record',
111                        dest="record_file",
112                        type = str,
113                        default = None,
114                        help='Record all exchanged messages to the specified file, so that they can later be ' +
115                        'replayed through "pytact-fake-coq"')
116    cmd_args = parser.parse_args()
118    print("Building oracle data...")
119    dataset_path = Path(cmd_args.dataset).resolve()
120    oracle_data = defaultdict(set)
121    text_oracle_data = defaultdict(set)
122    known_definitions = set()
123    known_tactics = set()
124    with data_reader(dataset_path) as data:
125        for datafile in data.values():
126            for d in datafile.definitions():
127                known_definitions.add(d.node.identity)
128                if proof := d.proof:
129                    for step in proof:
130                        for outcome in step.outcomes:
131                            if outcome.tactic is None:
132                                continue # If the tactic is unknown we are screwed
133                            known_tactics.add(outcome.tactic.ident)
134                            if not isinstance(d.status, Original):
135                                continue # For an oracle, we are not interested in non-original proofs
136                            if len(outcome.after) == 1:
137                                if outcome.before.id == outcome.after[0].id:
138                                    continue # This tactic didn't do anything, we can ignore it
139                                if outcome.before.root.identity == outcome.after[0].root.identity:
140                                    # This tactic did something, but very minimally, usually just an identity cast
141                                    continue
142                            text_oracle_data[outcome.before.text].add(outcome.tactic.text_non_anonymous)
143                            tactic_args = outcome.tactic_arguments
144                            if any(arg is None for arg in tactic_args):
145                                continue # If an argument is unknown we are screwed
146                            args = []
147                            for arg in tactic_args:
148                                if arg_def := arg.definition:
149                                    args.append(GlobalArgument(arg_def.node.identity))
150                                else:
151                                    args.append(LocalArgument(list(outcome.before.context).index(arg)))
152                            oracle_tactic = OracleTactic(outcome.tactic.ident, tuple(args),
153                                                         outcome.tactic.text == outcome.tactic.interm_text)
154                            oracle_data[outcome.before.root.identity].add(oracle_tactic)
155    print("Oracle data built, ready for incoming connections")
157    if cmd_args.record_file is not None:
158        record_context = open(cmd_args.record_file, 'wb')
159    else:
160        record_context = contextlib.nullcontext()
161    with record_context as record_file:
162        if cmd_args.port is not None:
163            class Handler(socketserver.BaseRequestHandler):
164                def handle(self):
165                    run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
166                                cmd_args, self.request, record_file)
167            class Server(socketserver.ForkingTCPServer):
168                def __init__(self, *kwargs):
169                    self.allow_reuse_address = True
170                    self.daemon_threads = True
171                    super().__init__(*kwargs)
172            addr = ('localhost', cmd_args.port)
173            with Server(addr, Handler) as server:
174                server.serve_forever()
175        else:
176            capnp_socket = socket.socket(fileno=sys.stdin.fileno())
177            run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
178                        cmd_args, capnp_socket, record_file)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
