pytact.oracle_server
"""Tactic prediction server that answers requests from data recorded in a dataset."""

from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import sys
import socket
import socketserver
import argparse
import contextlib
from typing import Union, Tuple
from pytact.data_reader import (data_reader, Original, capnp_message_generator, ProofState,
                                TacticPredictionGraph, TacticPredictionsGraph,
                                TacticPredictionText, TacticPredictionsText,
                                GlobalContextMessage, CheckAlignmentMessage, CheckAlignmentResponse)


@dataclass(eq=True, frozen=True)
class GlobalArgument:
    """Tactic argument that refers to a definition through its node identity."""
    identity: int


@dataclass(eq=True, frozen=True)
class LocalArgument:
    """Tactic argument that refers to a position in a proof state's context."""
    context_index: int


@dataclass(eq=True, frozen=True)
class OracleTactic:
    """Hashable record of one tactic application stored in the oracle tables."""
    tactic_id: int
    arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
    # True when the tactic's `text` equals its `interm_text` in the dataset.
    clean: bool


def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
    """Serve text-mode requests from `context` until its request stream ends.

    `text_oracle_data` maps a proof state's text to the tactic texts recorded for it.
    A nested `GlobalContextMessage` is served by a recursive call.

    Raises:
        Exception: when a message of an unexpected type is received.
    """
    prediction_requests = context.prediction_requests
    for msg in prediction_requests:
        if isinstance(msg, ProofState):
            # Membership is tested first so that a defaultdict is never grown by a lookup.
            if msg.text in text_oracle_data:
                preds = [TacticPredictionText(t, 1) for t in text_oracle_data[msg.text]]
            else:
                preds = []
            prediction_requests.send(TacticPredictionsText(preds))
        elif isinstance(msg, CheckAlignmentMessage):
            prediction_requests.send(CheckAlignmentResponse([], []))
        elif isinstance(msg, GlobalContextMessage):
            text_prediction_loop(text_oracle_data, msg)
        else:
            raise Exception("Capnp protocol error")


def graph_prediction_loop(context: GlobalContextMessage, oracle_data, known_definitions, known_tactics):
    """Serve graph-mode requests from `context` until its request stream ends.

    `oracle_data` maps the identity of a proof state's root to a set of `OracleTactic`.
    `known_definitions` and `known_tactics` are the identities seen in the dataset; they
    are used to answer alignment checks. A nested `GlobalContextMessage` is served by a
    recursive call.

    Raises:
        Exception: when a message of an unexpected type is received.
    """
    available_tacticids = {t.ident for t in context.tactics}
    available_definitions = {d.node.identity: d.node for d in context.definitions.definitions()}
    prediction_requests = context.prediction_requests

    def resolve_args(proof_state, arguments):
        """Return the nodes for `arguments`, or None as soon as one cannot be resolved."""
        resolved = []
        for arg in arguments:
            if isinstance(arg, LocalArgument):
                node = proof_state.context[arg.context_index]
            elif isinstance(arg, GlobalArgument):
                node = available_definitions.get(arg.identity)
            else:
                node = None
            if node is None:
                return None
            resolved.append(node)
        return resolved

    for msg in prediction_requests:
        if isinstance(msg, ProofState):
            # `.get` keeps a defaultdict from gaining an empty entry per unknown state.
            candidates = oracle_data.get(msg.root.identity, ())
            possible_tactics = []
            # Clean tactics sort first because `not t.clean` is False for them.
            for t in sorted(candidates, key=lambda t: not t.clean):
                if t.tactic_id not in available_tacticids:
                    continue
                resolved = resolve_args(msg, t.arguments)
                if resolved is None:
                    continue
                possible_tactics.append(
                    TacticPredictionGraph(t.tactic_id, resolved, 1 if t.clean else 0.95))
            prediction_requests.send(TacticPredictionsGraph(possible_tactics))
        elif isinstance(msg, CheckAlignmentMessage):
            unknown_definitions = [d for d in context.definitions.definitions()
                                   if d.node.identity not in known_definitions]
            unknown_tactics = [t.ident for t in context.tactics
                               if t.ident not in known_tactics]
            prediction_requests.send(CheckAlignmentResponse(unknown_definitions, unknown_tactics))
        elif isinstance(msg, GlobalContextMessage):
            graph_prediction_loop(msg, oracle_data, known_definitions, known_tactics)
        else:
            raise Exception("Capnp protocol error")


def run_session(oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
    """Run one prediction session over `capnp_socket` in the mode given by `args.mode`.

    Raises:
        Exception: when `args.mode` is neither 'text' nor 'graph'.
    """
    messages_generator = capnp_message_generator(capnp_socket, record_file)
    if args.mode == 'text':
        print('Python server running in text mode')
        text_prediction_loop(text_oracle_data, messages_generator)
    elif args.mode == 'graph':
        print('Python server running in graph mode')
        graph_prediction_loop(messages_generator, oracle_data, known_definitions, known_tactics)
    else:
        raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")


def _parse_args():
    """Parse the command line and return the resulting namespace."""
    parser = argparse.ArgumentParser(
        description = 'A tactic prediction server acting as an oracle, retrieving it\'s information from a dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('mode',
                        type=str,
                        choices=['graph', 'text'],
                        help='"graph" to communicate in graph-mode, "text" to communicate in text-mode')
    parser.add_argument('dataset',
                        type=str,
                        help=('The location of the dataset from which to extract the oracle information. ' +
                              'Either a dataset directory, or a SquashFS image, ' +
                              'which will be automatically mounted.'))
    parser.add_argument('--tcp',
                        dest='port',
                        type = int,
                        default = None,
                        help='Run in tcp mode instead of stdin mode on the specified port.')
    parser.add_argument('--record',
                        dest="record_file",
                        type = str,
                        default = None,
                        help='Record all exchanged messages to the specified file, so that they can later be ' +
                        'replayed through "pytact-fake-coq"')
    return parser.parse_args()


def _build_oracle_data(dataset_path):
    """Scan the dataset and return (oracle_data, text_oracle_data, known_definitions, known_tactics)."""
    oracle_data = defaultdict(set)
    text_oracle_data = defaultdict(set)
    known_definitions = set()
    known_tactics = set()
    with data_reader(dataset_path) as data:
        for datafile in data.values():
            for d in datafile.definitions():
                known_definitions.add(d.node.identity)
                proof = d.proof
                if not proof:
                    continue
                for step in proof:
                    for outcome in step.outcomes:
                        tactic = outcome.tactic
                        if tactic is None:
                            continue  # Nothing can be recorded for an unknown tactic
                        known_tactics.add(tactic.ident)
                        if not isinstance(d.status, Original):
                            continue  # Only original proofs feed the oracle tables
                        before = outcome.before
                        after = outcome.after
                        if len(after) == 1:
                            if before.id == after[0].id:
                                continue  # The tactic did not change the state
                            if before.root.identity == after[0].root.identity:
                                continue  # Same root identity: treated as a no-op as well
                        text_oracle_data[before.text].add(tactic.text_non_anonymous)
                        tactic_args = outcome.tactic_arguments
                        if any(arg is None for arg in tactic_args):
                            continue  # An unknown argument cannot be replayed
                        local_context = None  # Materialised at most once per outcome
                        args = []
                        for arg in tactic_args:
                            if arg_def := arg.definition:
                                args.append(GlobalArgument(arg_def.node.identity))
                            else:
                                if local_context is None:
                                    local_context = list(before.context)
                                args.append(LocalArgument(local_context.index(arg)))
                        oracle_tactic = OracleTactic(tactic.ident, tuple(args),
                                                     tactic.text == tactic.interm_text)
                        oracle_data[before.root.identity].add(oracle_tactic)
    return oracle_data, text_oracle_data, known_definitions, known_tactics


def main():
    """Entry point: build the oracle tables, then serve over TCP or over stdin."""
    # NOTE(review): presumably raised because the prediction loops recurse on nested
    # global-context messages -- confirm.
    sys.setrecursionlimit(10000)
    cmd_args = _parse_args()

    print("Building oracle data...")
    dataset_path = Path(cmd_args.dataset).resolve()
    oracle_data, text_oracle_data, known_definitions, known_tactics = _build_oracle_data(dataset_path)
    print("Oracle data built, ready for incoming connections")

    if cmd_args.record_file is not None:
        record_context = open(cmd_args.record_file, 'wb')
    else:
        record_context = contextlib.nullcontext()
    with record_context as record_file:
        if cmd_args.port is not None:
            class Handler(socketserver.BaseRequestHandler):
                def handle(self):
                    run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
                                cmd_args, self.request, record_file)

            class Server(socketserver.ForkingTCPServer):
                # Class attributes replace the former __init__ override.
                allow_reuse_address = True
                # Kept from the original; the socketserver docs describe this
                # attribute for the threading mix-in only.
                daemon_threads = True

            addr = ('localhost', cmd_args.port)
            with Server(addr, Handler) as server:
                server.serve_forever()
        else:
            # Wrap the stdin file descriptor in a socket object.
            capnp_socket = socket.socket(fileno=sys.stdin.fileno())
            run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
                        cmd_args, capnp_socket, record_file)


if __name__ == '__main__':
    main()
@dataclass(eq=True, frozen=True)
class
GlobalArgument:
@dataclass(eq=True, frozen=True)
class
LocalArgument:
@dataclass(eq=True, frozen=True)
class
OracleTactic:
@dataclass(frozen=True, eq=True)
class OracleTactic:
    """Immutable, hashable record of one tactic application kept in the oracle tables."""

    # Identifier of the tactic.
    tactic_id: int
    # One entry per tactic argument, either a global or a local reference.
    arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
    # Flag set by the code that builds the oracle tables.
    clean: bool
OracleTactic( tactic_id: int, arguments: Tuple[Union[GlobalArgument, LocalArgument], ...], clean: bool)
arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
    """Serve text-mode requests from `context` until its request stream ends.

    `text_oracle_data` maps a proof state's text to the tactic texts recorded for it.
    Each proof state is answered with those texts (confidence 1), or with an empty
    prediction list when the text is unknown. Alignment checks get an empty response.
    A nested `GlobalContextMessage` is served by a recursive call.

    Raises:
        Exception: when a message of an unexpected type is received.
    """
    requests = context.prediction_requests
    for message in requests:
        if isinstance(message, ProofState):
            # Test membership before indexing so a defaultdict is not grown by the lookup.
            is_known = message.text in text_oracle_data
            tactic_texts = text_oracle_data[message.text] if is_known else ()
            predictions = [TacticPredictionText(text, 1) for text in tactic_texts]
            requests.send(TacticPredictionsText(predictions))
            continue
        if isinstance(message, CheckAlignmentMessage):
            requests.send(CheckAlignmentResponse([], []))
            continue
        if isinstance(message, GlobalContextMessage):
            text_prediction_loop(text_oracle_data, message)
            continue
        raise Exception("Capnp protocol error")
def
graph_prediction_loop( context: pytact.data_reader.GlobalContextMessage, oracle_data, known_definitions, known_tactics):
def graph_prediction_loop(context: GlobalContextMessage, oracle_data, known_definitions, known_tactics):
    """Serve graph-mode requests from `context` until its request stream ends.

    `oracle_data` maps the identity of a proof state's root to a set of `OracleTactic`.
    For each proof state, the recorded tactics whose id is available in `context` and
    whose arguments can all be resolved are sent back, clean tactics first.
    `known_definitions` and `known_tactics` are sets of identities; anything in
    `context` that is not in them is reported in reply to an alignment check.
    A nested `GlobalContextMessage` is served by a recursive call.

    Raises:
        Exception: when a message of an unexpected type is received.
    """
    available_tacticids = {t.ident for t in context.tactics}
    available_definitions = {d.node.identity: d.node for d in context.definitions.definitions()}
    prediction_requests = context.prediction_requests

    def resolve_args(proof_state, arguments):
        """Return the nodes for `arguments`, or None as soon as one cannot be resolved."""
        resolved = []
        for arg in arguments:
            if isinstance(arg, LocalArgument):
                node = proof_state.context[arg.context_index]
            elif isinstance(arg, GlobalArgument):
                node = available_definitions.get(arg.identity)
            else:
                node = None
            if node is None:
                return None
            resolved.append(node)
        return resolved

    for msg in prediction_requests:
        if isinstance(msg, ProofState):
            # `.get` keeps a defaultdict from gaining an empty entry per unknown state.
            candidates = oracle_data.get(msg.root.identity, ())
            possible_tactics = []
            # Clean tactics sort first because `not t.clean` is False for them.
            for t in sorted(candidates, key=lambda t: not t.clean):
                if t.tactic_id not in available_tacticids:
                    continue
                # Arguments are resolved once and reused for the prediction.
                resolved = resolve_args(msg, t.arguments)
                if resolved is None:
                    continue
                possible_tactics.append(
                    TacticPredictionGraph(t.tactic_id, resolved, 1 if t.clean else 0.95))
            prediction_requests.send(TacticPredictionsGraph(possible_tactics))
        elif isinstance(msg, CheckAlignmentMessage):
            unknown_definitions = [d for d in context.definitions.definitions()
                                   if d.node.identity not in known_definitions]
            unknown_tactics = [t.ident for t in context.tactics
                               if t.ident not in known_tactics]
            prediction_requests.send(CheckAlignmentResponse(unknown_definitions, unknown_tactics))
        elif isinstance(msg, GlobalContextMessage):
            graph_prediction_loop(msg, oracle_data, known_definitions, known_tactics)
        else:
            raise Exception("Capnp protocol error")
def
run_session( oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
def run_session(oracle_data, text_oracle_data, known_definitions, known_tactics, args, capnp_socket, record_file):
    """Run one prediction session over `capnp_socket` in the mode given by `args.mode`.

    The message generator is created from `capnp_socket` and `record_file` before the
    mode is inspected; it is then handed to the text or the graph prediction loop.

    Raises:
        Exception: when `args.mode` is neither 'text' nor 'graph'.
    """
    messages_generator = capnp_message_generator(capnp_socket, record_file)
    mode = args.mode
    if mode == 'text':
        print('Python server running in text mode')
        text_prediction_loop(text_oracle_data, messages_generator)
        return
    if mode == 'graph':
        print('Python server running in graph mode')
        graph_prediction_loop(messages_generator, oracle_data, known_definitions, known_tactics)
        return
    raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")
def
main():
def _parse_args():
    """Parse the command line and return the resulting namespace."""
    parser = argparse.ArgumentParser(
        description = 'A tactic prediction server acting as an oracle, retrieving it\'s information from a dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('mode',
                        type=str,
                        choices=['graph', 'text'],
                        help='"graph" to communicate in graph-mode, "text" to communicate in text-mode')
    parser.add_argument('dataset',
                        type=str,
                        help=('The location of the dataset from which to extract the oracle information. ' +
                              'Either a dataset directory, or a SquashFS image, ' +
                              'which will be automatically mounted.'))
    parser.add_argument('--tcp',
                        dest='port',
                        type = int,
                        default = None,
                        help='Run in tcp mode instead of stdin mode on the specified port.')
    parser.add_argument('--record',
                        dest="record_file",
                        type = str,
                        default = None,
                        help='Record all exchanged messages to the specified file, so that they can later be ' +
                        'replayed through "pytact-fake-coq"')
    return parser.parse_args()


def _build_oracle_data(dataset_path):
    """Scan the dataset and return (oracle_data, text_oracle_data, known_definitions, known_tactics)."""
    oracle_data = defaultdict(set)
    text_oracle_data = defaultdict(set)
    known_definitions = set()
    known_tactics = set()
    with data_reader(dataset_path) as data:
        for datafile in data.values():
            for d in datafile.definitions():
                known_definitions.add(d.node.identity)
                proof = d.proof
                if not proof:
                    continue
                for step in proof:
                    for outcome in step.outcomes:
                        tactic = outcome.tactic
                        if tactic is None:
                            continue  # Nothing can be recorded for an unknown tactic
                        known_tactics.add(tactic.ident)
                        if not isinstance(d.status, Original):
                            continue  # Only original proofs feed the oracle tables
                        before = outcome.before
                        after = outcome.after
                        if len(after) == 1:
                            if before.id == after[0].id:
                                continue  # The tactic did not change the state
                            if before.root.identity == after[0].root.identity:
                                continue  # Same root identity: treated as a no-op as well
                        text_oracle_data[before.text].add(tactic.text_non_anonymous)
                        tactic_args = outcome.tactic_arguments
                        if any(arg is None for arg in tactic_args):
                            continue  # An unknown argument cannot be replayed
                        local_context = None  # Materialised at most once per outcome
                        args = []
                        for arg in tactic_args:
                            if arg_def := arg.definition:
                                args.append(GlobalArgument(arg_def.node.identity))
                            else:
                                if local_context is None:
                                    local_context = list(before.context)
                                args.append(LocalArgument(local_context.index(arg)))
                        oracle_tactic = OracleTactic(tactic.ident, tuple(args),
                                                     tactic.text == tactic.interm_text)
                        oracle_data[before.root.identity].add(oracle_tactic)
    return oracle_data, text_oracle_data, known_definitions, known_tactics


def main():
    """Entry point: build the oracle tables, then serve over TCP or over stdin.

    With `--tcp PORT` a forking TCP server is started on localhost and each connection
    runs its own session. Otherwise a single session runs over the stdin file
    descriptor. With `--record FILE` the file is opened for binary writing and passed
    to every session.
    """
    # NOTE(review): presumably raised because the prediction loops recurse on nested
    # global-context messages -- confirm.
    sys.setrecursionlimit(10000)
    cmd_args = _parse_args()

    print("Building oracle data...")
    dataset_path = Path(cmd_args.dataset).resolve()
    oracle_data, text_oracle_data, known_definitions, known_tactics = _build_oracle_data(dataset_path)
    print("Oracle data built, ready for incoming connections")

    if cmd_args.record_file is not None:
        record_context = open(cmd_args.record_file, 'wb')
    else:
        record_context = contextlib.nullcontext()
    with record_context as record_file:
        if cmd_args.port is not None:
            class Handler(socketserver.BaseRequestHandler):
                def handle(self):
                    run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
                                cmd_args, self.request, record_file)

            class Server(socketserver.ForkingTCPServer):
                # Class attributes replace the former __init__ override.
                allow_reuse_address = True
                # Kept from the original; the socketserver docs describe this
                # attribute for the threading mix-in only.
                daemon_threads = True

            addr = ('localhost', cmd_args.port)
            with Server(addr, Handler) as server:
                server.serve_forever()
        else:
            # Wrap the stdin file descriptor in a socket object.
            capnp_socket = socket.socket(fileno=sys.stdin.fileno())
            run_session(oracle_data, text_oracle_data, known_definitions, known_tactics,
                        cmd_args, capnp_socket, record_file)