+
+void *
+radv_lookup_physical_device_entrypoint_checked(const char *name,
+ uint32_t core_version,
+ const struct radv_instance_extension_table *instance)
+{
+ int index = radv_lookup_entrypoint(name);
+ if (index < 0 || !radv_entrypoint_is_enabled_physical_device(index, core_version, instance))
+ return NULL;
+ return radv_resolve_entrypoint(index);
+}
+
+""", output_encoding='utf-8')
+
+U32_MASK = 2**32 - 1
+
+PRIME_FACTOR = 5024183
+PRIME_STEP = 19
+
+def round_to_pow2(x):
+ return 2**int(math.ceil(math.log(x, 2)))
+
+class StringIntMapEntry(object):
+ def __init__(self, string, num):
+ self.string = string
+ self.num = num
+
+ # Calculate the same hash value that we will calculate in C.
+ h = 0
+ for c in string:
+ h = ((h * PRIME_FACTOR) + ord(c)) & U32_MASK
+ self.hash = h
+
+ self.offset = None
+
+class StringIntMap(object):
+ def __init__(self):
+ self.baked = False
+ self.strings = dict()
+
+ def add_string(self, string, num):
+ assert not self.baked
+ assert string not in self.strings
+ assert num >= 0 and num < 2**31
+ self.strings[string] = StringIntMapEntry(string, num)
+
+ def bake(self):
+ self.sorted_strings = \
+ sorted(self.strings.values(), key=lambda x: x.string)
+ offset = 0
+ for entry in self.sorted_strings:
+ entry.offset = offset
+ offset += len(entry.string) + 1
+
+ # Save off some values that we'll need in C
+ self.hash_size = round_to_pow2(len(self.strings) * 1.25)
+ self.hash_mask = self.hash_size - 1
+ self.prime_factor = PRIME_FACTOR
+ self.prime_step = PRIME_STEP
+
+ self.mapping = [-1] * self.hash_size
+ self.collisions = [0] * 10
+ for idx, s in enumerate(self.sorted_strings):
+ level = 0
+ h = s.hash
+ while self.mapping[h & self.hash_mask] >= 0:
+ h = h + PRIME_STEP
+ level = level + 1
+ self.collisions[min(level, 9)] += 1
+ self.mapping[h & self.hash_mask] = idx
+
+EntrypointParam = namedtuple('EntrypointParam', 'type name decl')
+
+class EntrypointBase(object):
+ def __init__(self, name):
+ self.name = name
+ self.alias = None
+ self.guard = None
+ self.enabled = False
+ self.num = None
+ # Extensions which require this entrypoint
+ self.core_version = None
+ self.extensions = []
+
+class Entrypoint(EntrypointBase):
+ def __init__(self, name, return_type, params, guard = None):
+ super(Entrypoint, self).__init__(name)
+ self.return_type = return_type
+ self.params = params
+ self.guard = guard
+ self.device_command = len(params) > 0 and (params[0].type == 'VkDevice' or params[0].type == 'VkQueue' or params[0].type == 'VkCommandBuffer')
+ self.physical_device_command = len(params) > 0 and params[0].type == 'VkPhysicalDevice'
+
+ def prefixed_name(self, prefix):
+ assert self.name.startswith('vk')
+ return prefix + '_' + self.name[2:]
+
+ def decl_params(self):
+ return ', '.join(p.decl for p in self.params)
+
+ def call_params(self):
+ return ', '.join(p.name for p in self.params)
+
+class EntrypointAlias(EntrypointBase):
+ def __init__(self, name, entrypoint):
+ super(EntrypointAlias, self).__init__(name)
+ self.alias = entrypoint
+ self.device_command = entrypoint.device_command
+ self.physical_device_command = entrypoint.physical_device_command
+
+ def prefixed_name(self, prefix):
+ return self.alias.prefixed_name(prefix)
+
+def get_entrypoints(doc, entrypoints_to_defines, start_index):
+ """Extract the entry points from the registry."""
+ entrypoints = OrderedDict()
+
+ for command in doc.findall('./commands/command'):
+ if 'alias' in command.attrib:
+ alias = command.attrib['name']
+ target = command.attrib['alias']
+ entrypoints[alias] = EntrypointAlias(alias, entrypoints[target])
+ else:
+ name = command.find('./proto/name').text
+ ret_type = command.find('./proto/type').text
+ params = [EntrypointParam(
+ type = p.find('./type').text,
+ name = p.find('./name').text,
+ decl = ''.join(p.itertext())
+ ) for p in command.findall('./param')]
+ guard = entrypoints_to_defines.get(name)
+ # They really need to be unique
+ assert name not in entrypoints
+ entrypoints[name] = Entrypoint(name, ret_type, params, guard)
+
+ for feature in doc.findall('./feature'):
+ assert feature.attrib['api'] == 'vulkan'
+ version = VkVersion(feature.attrib['number'])
+ if version > MAX_API_VERSION:
+ continue
+
+ for command in feature.findall('./require/command'):
+ e = entrypoints[command.attrib['name']]
+ e.enabled = True
+ assert e.core_version is None
+ e.core_version = version
+
+ supported_exts = dict((ext.name, ext) for ext in EXTENSIONS)
+ for extension in doc.findall('.extensions/extension'):
+ ext_name = extension.attrib['name']
+ if ext_name not in supported_exts:
+ continue
+
+ ext = supported_exts[ext_name]
+ ext.type = extension.attrib['type']
+
+ for command in extension.findall('./require/command'):
+ e = entrypoints[command.attrib['name']]
+ e.enabled = True
+ assert e.core_version is None
+ e.extensions.append(ext)
+
+ # if the base command is not supported by the driver yet, don't alias aliases
+ for e in entrypoints.values():
+ if e.alias and not e.alias.enabled:
+ e_clone = copy.deepcopy(e.alias)
+ e_clone.enabled = True
+ e_clone.name = e.name
+ entrypoints[e.name] = e_clone
+
+ return [e for e in entrypoints.values() if e.enabled]
+
+
+def get_entrypoints_defines(doc):
+ """Maps entry points to extension defines."""
+ entrypoints_to_defines = {}
+
+ for extension in doc.findall('./extensions/extension[@protect]'):
+ define = extension.attrib['protect']
+
+ for entrypoint in extension.findall('./require/command'):
+ fullname = entrypoint.attrib['name']
+ entrypoints_to_defines[fullname] = define
+
+ platform_define = {}
+ for platform in doc.findall('./platforms/platform'):
+ name = platform.attrib['name']
+ define = platform.attrib['protect']
+ platform_define[name] = define
+
+ for extension in doc.findall('./extensions/extension[@platform]'):
+ platform = extension.attrib['platform']
+ define = platform_define[platform]
+
+ for entrypoint in extension.findall('./require/command'):
+ fullname = entrypoint.attrib['name']
+ entrypoints_to_defines[fullname] = define
+
+ return entrypoints_to_defines
+
+
+def gen_code(entrypoints):
+ """Generate the C code."""
+ strmap = StringIntMap()
+ for e in entrypoints:
+ strmap.add_string(e.name, e.num)
+ strmap.bake()
+
+ return TEMPLATE_C.render(entrypoints=entrypoints,
+ LAYERS=LAYERS,
+ strmap=strmap,
+ filename=os.path.basename(__file__))
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--outdir', help='Where to write the files.',
+ required=True)
+ parser.add_argument('--xml',
+ help='Vulkan API XML file.',
+ required=True,
+ action='append',
+ dest='xml_files')
+ args = parser.parse_args()
+
+ entrypoints = []
+
+ for filename in args.xml_files:
+ doc = et.parse(filename)
+ entrypoints += get_entrypoints(doc, get_entrypoints_defines(doc),
+ start_index=len(entrypoints))
+
+ for num, e in enumerate(entrypoints):
+ e.num = num
+
+ # For outputting entrypoints.h we generate a radv_EntryPoint() prototype
+ # per entry point.
+ with open(os.path.join(args.outdir, 'radv_entrypoints.h'), 'wb') as f:
+ f.write(TEMPLATE_H.render(entrypoints=entrypoints,
+ LAYERS=LAYERS,
+ filename=os.path.basename(__file__)))
+ with open(os.path.join(args.outdir, 'radv_entrypoints.c'), 'wb') as f:
+ f.write(gen_code(entrypoints))
+
+
+if __name__ == '__main__':
+ main()