pytact.scripts.lemma_distance
For each lemma in some test packages, calculate the number of definitions it depends on that are not in the training set
"""For each lemma in some test packages, calculate the number of definitions it
depends on that are not in the training set."""

from pathlib import Path
import sys
from pytact.data_reader import data_reader, Original, definition_dependencies

# Packages held out for testing. A file belongs to one of them when the first
# component of its path is listed here. Every entry needs its own trailing
# comma: adjacent string literals are silently concatenated by Python.
test_packages = {
    "coq-bits.1.1.0",
    "coq-qcert.2.2.0",
    "coq-ceres.0.4.0",
    "coq-corn.8.16.0",
    "coq-bytestring.0.9.0",
    "coq-hammer.1.3.2+8.11",
    "coq-gaia-stern.1.15",
    "coq-mathcomp-apery.1.0.1",
    "coq-tlc.20200328",
    "coq-iris-heap-lang.3.4.0",
    "coq-printf.2.0.0",
    "coq-smtcoq.2.0+8.11",
    "coq-topology.10.0.1",
    "coq-haskell.1.0.0",
    "coq-bbv.1.3",
    "coq-poltac.0.8.11",
    "coq-mathcomp-odd-order.1.14.0",
    "coq-hott.8.11",
}


def main():
    """Report, per lemma in the test packages, how many other test-package
    definitions it transitively depends on.

    The dataset directory is taken from ``sys.argv[1]``. For every definition
    in a test package whose status is ``Original`` and that has a proof, one
    tab-separated line is printed: package name, definition name, and the
    number of dependencies found.

    Returns:
        ``2`` when no dataset path was given on the command line, otherwise
        ``None``.
    """
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} DATASET_PATH", file=sys.stderr)
        return 2

    dataset_path = Path(sys.argv[1]).resolve()
    with data_reader(dataset_path) as data:
        # Indexed by graph id: True when that graph's file is in a test package.
        # NOTE(review): relies on graph ids being exactly the positions 0..n-1
        # of the files sorted by ``graph`` -- confirm against data_reader.
        graphid_in_test = [
            f.filename.parts[0] in test_packages
            for f in sorted(data.values(), key=lambda f: f.graph)
        ]
        # Memo table: definition -> set of definitions collected for it.
        trans_deps = {}

        def calc_trans_deps(d):
            """Return d's cluster plus everything collected for the
            dependencies of that cluster; empty outside the test packages."""
            # Compare against None: a cached empty set is a valid cache hit.
            cached = trans_deps.get(d)
            if cached is not None:
                return cached
            if not graphid_in_test[d.node.graph]:
                # Definitions outside the test packages contribute nothing and
                # are not expanded any further.
                result = trans_deps[d] = set()
                return result
            dist = set(d.cluster)
            # Dependencies of any cluster member, minus the cluster itself so
            # the recursion does not re-enter the same cluster.
            direct_cluster_deps = {
                dep for c in d.cluster for dep in definition_dependencies(c)
            } - set(d.cluster)
            for dep in direct_cluster_deps:
                dist.update(calc_trans_deps(dep))
            trans_deps[d] = dist
            return dist

        for f in data.values():
            package = f.filename.parts[0]
            if package not in test_packages:
                continue
            for d in f.definitions():
                if not isinstance(d.status, Original):
                    continue
                if d.proof:
                    # The lemma's own cluster is not one of its dependencies.
                    deps = calc_trans_deps(d) - set(d.cluster)
                    print(f"{package}\t{d.name}\t{len(deps)}")


if __name__ == "__main__":
    sys.exit(main())
# Packages held out for testing. "coq-topology.10.0.1" and "coq-haskell.1.0.0"
# are two separate packages; they must not be fused into a single entry.
test_packages = {
    "coq-bits.1.1.0",
    "coq-qcert.2.2.0",
    "coq-ceres.0.4.0",
    "coq-corn.8.16.0",
    "coq-bytestring.0.9.0",
    "coq-hammer.1.3.2+8.11",
    "coq-gaia-stern.1.15",
    "coq-mathcomp-apery.1.0.1",
    "coq-tlc.20200328",
    "coq-iris-heap-lang.3.4.0",
    "coq-printf.2.0.0",
    "coq-smtcoq.2.0+8.11",
    "coq-topology.10.0.1",
    "coq-haskell.1.0.0",
    "coq-bbv.1.3",
    "coq-poltac.0.8.11",
    "coq-mathcomp-odd-order.1.14.0",
    "coq-hott.8.11",
}
def
main():
def main():
    """Report, per lemma in the test packages, how many other test-package
    definitions it transitively depends on.

    The dataset directory is taken from ``sys.argv[1]``. For every definition
    in a test package whose status is ``Original`` and that has a proof, one
    tab-separated line is printed: package name, definition name, and the
    number of dependencies found.

    Returns:
        ``2`` when no dataset path was given on the command line, otherwise
        ``None``.
    """
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} DATASET_PATH", file=sys.stderr)
        return 2

    dataset_path = Path(sys.argv[1]).resolve()
    with data_reader(dataset_path) as data:
        # Indexed by graph id: True when that graph's file is in a test package.
        # NOTE(review): relies on graph ids being exactly the positions 0..n-1
        # of the files sorted by ``graph`` -- confirm against data_reader.
        graphid_in_test = [
            f.filename.parts[0] in test_packages
            for f in sorted(data.values(), key=lambda f: f.graph)
        ]
        # Memo table: definition -> set of definitions collected for it.
        trans_deps = {}

        def calc_trans_deps(d):
            """Return d's cluster plus everything collected for the
            dependencies of that cluster; empty outside the test packages."""
            # Compare against None: a cached empty set is a valid cache hit.
            cached = trans_deps.get(d)
            if cached is not None:
                return cached
            if not graphid_in_test[d.node.graph]:
                # Definitions outside the test packages contribute nothing and
                # are not expanded any further.
                result = trans_deps[d] = set()
                return result
            dist = set(d.cluster)
            # Dependencies of any cluster member, minus the cluster itself so
            # the recursion does not re-enter the same cluster.
            direct_cluster_deps = {
                dep for c in d.cluster for dep in definition_dependencies(c)
            } - set(d.cluster)
            for dep in direct_cluster_deps:
                dist.update(calc_trans_deps(dep))
            trans_deps[d] = dist
            return dist

        for f in data.values():
            package = f.filename.parts[0]
            if package not in test_packages:
                continue
            for d in f.definitions():
                if not isinstance(d.status, Original):
                    continue
                if d.proof:
                    # The lemma's own cluster is not one of its dependencies.
                    deps = calc_trans_deps(d) - set(d.cluster)
                    print(f"{package}\t{d.name}\t{len(deps)}")