vulkan: make generated enum to strings helpers available from c++
[mesa.git] / src / vulkan / util / gen_enum_to_str.py
index 3c9f260595b1e84f6599d213835a09f2b9d37621..06f74eb487c01eb421a12643f5fd8f7fbec5841f 100644 (file)
@@ -68,15 +68,15 @@ C_TEMPLATE = Template(textwrap.dedent(u"""\
     vk_${enum.name[2:]}_to_str(${enum.name} input)
     {
         switch(input) {
-        % for v in enum.values:
-            % if v in FOREIGN_ENUM_VALUES:
+        % for v in sorted(enum.values.keys()):
+            % if enum.values[v] in FOREIGN_ENUM_VALUES:
 
             #pragma GCC diagnostic push
             #pragma GCC diagnostic ignored "-Wswitch"
             % endif
             case ${v}:
-                return "${v}";
-            % if v in FOREIGN_ENUM_VALUES:
+                return "${enum.values[v]}";
+            % if enum.values[v] in FOREIGN_ENUM_VALUES:
             #pragma GCC diagnostic pop
 
             % endif
@@ -101,10 +101,22 @@ H_TEMPLATE = Template(textwrap.dedent(u"""\
     #include <vulkan/vulkan.h>
     #include <vulkan/vk_android_native_buffer.h>
 
+    #ifdef __cplusplus
+    extern "C" {
+    #endif
+
+    % for ext in extensions:
+    #define _${ext.name}_number (${ext.number})
+    % endfor
+
     % for enum in enums:
     const char * vk_${enum.name[2:]}_to_str(${enum.name} input);
     % endfor
 
+    #ifdef __cplusplus
+    } /* extern "C" */
+    #endif
+
     #endif"""),
     output_encoding='utf-8')
 
@@ -129,45 +141,88 @@ class NamedFactory(object):
             n = self.registry[name] = self.type(name, **kwargs)
         return n
 
+    def get(self, name):
+        return self.registry.get(name)
+
+
+class VkExtension(object):
+    """Simple struct-like class representing extensions"""
+
+    def __init__(self, name, number=None):
+        self.name = name
+        self.number = number
+
 
 class VkEnum(object):
     """Simple struct-like class representing a single Vulkan Enum."""
 
     def __init__(self, name, values=None):
         self.name = name
-        self.values = values or []
-
-
-def parse_xml(enum_factory, filename):
-    """Parse the XML file. Accumulate results into the efactory.
+        # Maps numbers to names
+        self.values = values or dict()
+        self.name_to_value = dict()
+
+    def add_value(self, name, value=None,
+                  extnum=None, offset=None,
+                  error=False):
+        assert value is not None or extnum is not None
+        if value is None:
+            value = 1000000000 + (extnum - 1) * 1000 + offset
+            if error:
+                value = -value
+
+        self.name_to_value[name] = value
+        if value not in self.values:
+            self.values[value] = name
+        elif len(self.values[value]) > len(name):
+            self.values[value] = name
+
+    def add_value_from_xml(self, elem, extension=None):
+        if 'value' in elem.attrib:
+            self.add_value(elem.attrib['name'],
+                           value=int(elem.attrib['value'], base=0))
+        elif 'alias' in elem.attrib:
+            self.add_value(elem.attrib['name'],
+                           value=self.name_to_value[elem.attrib['alias']])
+        else:
+            error = 'dir' in elem.attrib and elem.attrib['dir'] == '-'
+            if 'extnumber' in elem.attrib:
+                extnum = int(elem.attrib['extnumber'])
+            else:
+                extnum = extension.number
+            self.add_value(elem.attrib['name'],
+                           extnum=extnum,
+                           offset=int(elem.attrib['offset']),
+                           error=error)
+
+
+def parse_xml(enum_factory, ext_factory, filename):
+    """Parse the XML file. Accumulate results into the factories.
 
     This parser is a memory efficient iterative XML parser that returns a list
     of VkEnum objects.
     """
 
-    with open(filename, 'rb') as f:
-        context = iter(et.iterparse(f, events=('start', 'end')))
+    xml = et.parse(filename)
+
+    for enum_type in xml.findall('./enums[@type="enum"]'):
+        enum = enum_factory(enum_type.attrib['name'])
+        for value in enum_type.findall('./enum'):
+            enum.add_value_from_xml(value)
 
-        # This gives the root element, since goal is to iterate over the
-        # elements without building a tree, this allows the root to be cleared
-        # (erase the elements) after the children have been processed.
-        _, root = next(context)
+    for value in xml.findall('./feature/require/enum[@extends]'):
+        enum = enum_factory.get(value.attrib['extends'])
+        if enum is not None:
+            enum.add_value_from_xml(value)
 
-        for event, elem in context:
-            if event == 'end' and elem.tag == 'enums':
-                type_ = elem.attrib.get('type')
-                if type_ == 'enum':
-                    enum = enum_factory(elem.attrib['name'])
-                    enum.values.extend([e.attrib['name'] for e in elem
-                                        if e.tag == 'enum'])
-            elif event == 'end' and elem.tag == 'extension':
-                if elem.attrib['supported'] != 'vulkan':
-                    continue
-                for e in elem.findall('.//enum[@extends][@offset]'):
-                    enum = enum_factory(e.attrib['extends'])
-                    enum.values.append(e.attrib['name'])
+    for ext_elem in xml.findall('./extensions/extension[@supported="vulkan"]'):
+        extension = ext_factory(ext_elem.attrib['name'],
+                                number=int(ext_elem.attrib['number']))
 
-            root.clear()
+        for value in ext_elem.findall('./require/enum[@extends]'):
+            enum = enum_factory.get(value.attrib['extends'])
+            if enum is not None:
+                enum.add_value_from_xml(value, extension)
 
 
 def main():
@@ -183,9 +238,11 @@ def main():
     args = parser.parse_args()
 
     enum_factory = NamedFactory(VkEnum)
+    ext_factory = NamedFactory(VkExtension)
     for filename in args.xml_files:
-        parse_xml(enum_factory, filename)
+        parse_xml(enum_factory, ext_factory, filename)
     enums = sorted(enum_factory.registry.values(), key=lambda e: e.name)
+    extensions = sorted(ext_factory.registry.values(), key=lambda e: e.name)
 
     for template, file_ in [(C_TEMPLATE, os.path.join(args.outdir, 'vk_enum_to_str.c')),
                             (H_TEMPLATE, os.path.join(args.outdir, 'vk_enum_to_str.h'))]:
@@ -193,6 +250,7 @@ def main():
             f.write(template.render(
                 file=os.path.basename(__file__),
                 enums=enums,
+                extensions=extensions,
                 copyright=COPYRIGHT,
                 FOREIGN_ENUM_VALUES=FOREIGN_ENUM_VALUES))