Edit on GitHub

pytact.scripts.lemma_distance

For each lemma in a fixed set of test packages, calculate the number of definitions it transitively depends on that are not in the training set.

 1"""For each lemma in some test packages, calculate the number of definitions it
 2depends on that are not in the training set"""
 3
 4from pathlib import Path
 5import sys
 6from pytact.data_reader import data_reader, Original, definition_dependencies
 7
 8test_packages = set([
 9    "coq-bits.1.1.0",
10    "coq-qcert.2.2.0",
11    "coq-ceres.0.4.0",
12    "coq-corn.8.16.0",
13    "coq-bytestring.0.9.0",
14    "coq-hammer.1.3.2+8.11",
15    "coq-gaia-stern.1.15",
16    "coq-mathcomp-apery.1.0.1",
17    "coq-tlc.20200328",
18    "coq-iris-heap-lang.3.4.0",
19    "coq-printf.2.0.0",
20    "coq-smtcoq.2.0+8.11",
21    "coq-topology.10.0.1"
22    "coq-haskell.1.0.0",
23    "coq-bbv.1.3",
24    "coq-poltac.0.8.11",
25    "coq-mathcomp-odd-order.1.14.0",
26    "coq-hott.8.11"
27])
28
def main():
    """Print a dependency count for every lemma in the test packages.

    The dataset directory is taken from ``sys.argv[1]``. For each definition
    that has a proof, has ``Original`` status and lives in a file whose first
    path component is in ``test_packages``, one tab-separated line is written
    to stdout: package, definition name, and the number of definitions
    collected by ``calc_trans_deps`` excluding the definition's own cluster.

    Returns:
        ``1`` when the dataset path argument is missing, otherwise ``None``.
    """
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <dataset-path>", file=sys.stderr)
        return 1

    dataset_path = Path(sys.argv[1]).resolve()
    with data_reader(dataset_path) as data:

        # graphid_in_test[g] tells whether graph id g belongs to a test package.
        # NOTE(review): indexing by graph id assumes the ids are contiguous and
        # start at 0 -- confirm against data_reader.
        graphid_in_test = [
            d.filename.parts[0] in test_packages
            for d in sorted(data.values(), key=lambda d: d.graph)
        ]
        # Memo table: definition -> set of definitions gathered for it.
        trans_deps = {}

        def calc_trans_deps(d):
            """Return the test-package definitions reachable from ``d``.

            The result contains the cluster of ``d`` plus, recursively, the
            results for every dependency of that cluster. Definitions outside
            the test packages contribute an empty set and are not expanded.
            """
            # Compare against None: a cached empty set is falsy but is still a
            # valid memoized answer.
            cached = trans_deps.get(d)
            if cached is not None:
                return cached
            if not graphid_in_test[d.node.graph]:
                trans_deps[d] = set()
                return trans_deps[d]
            dist = set(d.cluster)
            # Dependencies of any cluster member, minus the cluster itself so
            # the recursion never re-enters the current cluster.
            direct_cluster_deps = {
                dep for c in d.cluster for dep in definition_dependencies(c)
            } - set(d.cluster)
            for dep in direct_cluster_deps:
                dist.update(calc_trans_deps(dep))
            trans_deps[d] = dist
            return dist

        for f in data.values():
            package = f.filename.parts[0]
            if package not in test_packages:
                continue
            for d in f.definitions():
                if not isinstance(d.status, Original):
                    continue
                if d.proof:
                    # The definition's own cluster does not count as a dependency.
                    deps = calc_trans_deps(d) - set(d.cluster)
                    print(f"{package}\t{d.name}\t{len(deps)}")
59
if __name__ == "__main__":
    # Use sys.exit rather than the builtin exit: the latter is installed by the
    # site module for interactive use and is missing under `python -S`.
    sys.exit(main())
# 'coq-topology.10.0.1' and 'coq-haskell.1.0.0' are two separate packages; they must not be fused into one entry.
test_packages = {'coq-hott.8.11', 'coq-poltac.0.8.11', 'coq-smtcoq.2.0+8.11', 'coq-mathcomp-apery.1.0.1', 'coq-mathcomp-odd-order.1.14.0', 'coq-hammer.1.3.2+8.11', 'coq-printf.2.0.0', 'coq-qcert.2.2.0', 'coq-tlc.20200328', 'coq-bits.1.1.0', 'coq-corn.8.16.0', 'coq-iris-heap-lang.3.4.0', 'coq-ceres.0.4.0', 'coq-topology.10.0.1', 'coq-haskell.1.0.0', 'coq-gaia-stern.1.15', 'coq-bbv.1.3', 'coq-bytestring.0.9.0'}
def main():
def main():
    """Print a dependency count for every lemma in the test packages.

    The dataset directory is taken from ``sys.argv[1]``. For each definition
    that has a proof, has ``Original`` status and lives in a file whose first
    path component is in ``test_packages``, one tab-separated line is written
    to stdout: package, definition name, and the number of definitions
    collected by ``calc_trans_deps`` excluding the definition's own cluster.

    Returns:
        ``1`` when the dataset path argument is missing, otherwise ``None``.
    """
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <dataset-path>", file=sys.stderr)
        return 1

    dataset_path = Path(sys.argv[1]).resolve()
    with data_reader(dataset_path) as data:

        # graphid_in_test[g] tells whether graph id g belongs to a test package.
        # NOTE(review): indexing by graph id assumes the ids are contiguous and
        # start at 0 -- confirm against data_reader.
        graphid_in_test = [
            d.filename.parts[0] in test_packages
            for d in sorted(data.values(), key=lambda d: d.graph)
        ]
        # Memo table: definition -> set of definitions gathered for it.
        trans_deps = {}

        def calc_trans_deps(d):
            """Return the test-package definitions reachable from ``d``.

            The result contains the cluster of ``d`` plus, recursively, the
            results for every dependency of that cluster. Definitions outside
            the test packages contribute an empty set and are not expanded.
            """
            # Compare against None: a cached empty set is falsy but is still a
            # valid memoized answer.
            cached = trans_deps.get(d)
            if cached is not None:
                return cached
            if not graphid_in_test[d.node.graph]:
                trans_deps[d] = set()
                return trans_deps[d]
            dist = set(d.cluster)
            # Dependencies of any cluster member, minus the cluster itself so
            # the recursion never re-enters the current cluster.
            direct_cluster_deps = {
                dep for c in d.cluster for dep in definition_dependencies(c)
            } - set(d.cluster)
            for dep in direct_cluster_deps:
                dist.update(calc_trans_deps(dep))
            trans_deps[d] = dist
            return dist

        for f in data.values():
            package = f.filename.parts[0]
            if package not in test_packages:
                continue
            for d in f.definitions():
                if not isinstance(d.status, Original):
                    continue
                if d.proof:
                    # The definition's own cluster does not count as a dependency.
                    deps = calc_trans_deps(d) - set(d.cluster)
                    print(f"{package}\t{d.name}\t{len(deps)}")