|
| 1 | +import os |
| 2 | +from importlib import import_module |
| 3 | +from types import ModuleType |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from recirq.algorithmic_benchmark_library import BENCHMARKS, get_all_algo_configs, workflow |
| 8 | + |
| 9 | +RECIRQ_DIR = os.path.abspath(os.path.dirname(__file__) + '/../') |
| 10 | + |
| 11 | +if not workflow: |
| 12 | + pytestmark = pytest.mark.skip('algorithmic_benchmark_library requires pre-release of Cirq.') |
| 13 | + |
| 14 | + |
| 15 | +@pytest.mark.parametrize('algo', BENCHMARKS) |
| 16 | +def test_domain(algo): |
| 17 | + # By convention, the domain should be a recirq module. |
| 18 | + assert algo.domain.startswith('recirq.'), 'domain should be a recirq module.' |
| 19 | + mod = import_module(algo.domain) |
| 20 | + assert isinstance(mod, ModuleType), 'domain should be a recirq module.' |
| 21 | + |
| 22 | + |
| 23 | +def test_benchmark_name_unique_in_domain(): |
| 24 | + # In a given domain, all benchmark names should be unique |
| 25 | + pairs = [(algo.domain, algo.name) for algo in BENCHMARKS] |
| 26 | + assert len(set(pairs)) == len(pairs) |
| 27 | + |
| 28 | + |
| 29 | +@pytest.mark.parametrize('algo', BENCHMARKS) |
| 30 | +def test_executable_family_is_formulaic(algo): |
| 31 | + # Check consistency in the AlgorithmicBenchmark dataclass: |
| 32 | + assert algo.executable_family == algo.spec_class.executable_family, \ |
| 33 | + "benchmark's executable_family should match that of the spec_class" |
| 34 | + |
| 35 | + # By convention, we set this to be the module name. By further convention, |
| 36 | + # {algo.domain}.{algo.name} should be the module name. |
| 37 | + assert algo.executable_family == f'{algo.domain}.{algo.name}', \ |
| 38 | + "The executable family should be set to the benchmarks's domain.name" |
| 39 | + |
| 40 | + # Check the convention that it should give a module |
| 41 | + mod = import_module(algo.executable_family) |
| 42 | + assert isinstance(mod, ModuleType), \ |
| 43 | + "The executable family should specify an importable module." |
| 44 | + |
| 45 | + |
| 46 | +@pytest.mark.parametrize('algo', BENCHMARKS) |
| 47 | +def test_classes_and_funcs(algo): |
| 48 | + # The various class objects should exist in the module |
| 49 | + mod = import_module(algo.executable_family) |
| 50 | + assert algo.spec_class == getattr(mod, algo.spec_class.__name__), \ |
| 51 | + "The spec_class must exist in the benchmark's module" |
| 52 | + assert algo.data_class == getattr(mod, algo.data_class.__name__), \ |
| 53 | + "The data_class must exist in the benchmark's module" |
| 54 | + assert algo.executable_generator_func == getattr(mod, algo.executable_generator_func.__name__), \ |
| 55 | + "the executable_generator_func must exist in the benchmark's module" |
| 56 | + |
| 57 | + |
| 58 | +def test_globally_unique_executable_family(): |
| 59 | + # Each entry should have a unique executable family |
| 60 | + fams = [algo.executable_family for algo in BENCHMARKS] |
| 61 | + assert len(set(fams)) == len(fams) |
| 62 | + |
| 63 | + |
| 64 | +def test_globally_unique_config_full_name(): |
| 65 | + full_names = [config.full_name for algo, config in get_all_algo_configs()] |
| 66 | + assert len(set(full_names)) == len(full_names) |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.parametrize('algo_config', get_all_algo_configs()) |
| 70 | +def test_gen_script(algo_config): |
| 71 | + algo, config = algo_config |
| 72 | + |
| 73 | + # Make sure it's formulaic |
| 74 | + assert config.gen_script == f'gen-{config.short_name}.py', \ |
| 75 | + "The gen_script should be of the form 'gen-{short_name}'" |
| 76 | + |
| 77 | + # Make sure it exists |
| 78 | + gen_script_path = (f"{RECIRQ_DIR}/{algo.domain.replace('.', '/')}/" |
| 79 | + f"{algo.name.replace('.', '/')}/{config.gen_script}") |
| 80 | + assert os.path.exists(gen_script_path) |
0 commit comments