util: Explicitly call the unpack functions from inside bptc pack/unpack.
[mesa.git] / src / util / format / u_format_pack.py
index 0653e5e4a9e87dcdf80aceff57d53f52802c892d..a9a7792309594c259e41c9a123d75790b0f2fbec 100644 (file)
@@ -49,7 +49,6 @@ if sys.version_info < (3, 0):
 else:
     integer_types = (int, )
 
-
 def inv_swizzles(swizzles):
     '''Return an array[4] of inverse swizzle terms'''
     '''Only pick the first matching value to avoid l8 getting blue and i8 getting alpha'''
@@ -437,7 +436,7 @@ def conversion_expr(src_channel,
             src_size = 32
 
         if dst_channel.size == 16:
-            value = 'util_float_to_half(%s)' % value
+            value = 'util_float_to_half_rtz(%s)' % value
         elif dst_channel.size == 64 and src_size < 64:
             value = '(double)%s' % value
 
@@ -593,7 +592,7 @@ def generate_pack_kernel(format, src_channel, src_native_type):
     def pack_into_struct(channels, swizzles):
         inv_swizzle = inv_swizzles(swizzles)
 
-        print('         struct util_format_%s pixel;' % format.short_name())
+        print('         struct util_format_%s pixel = {0};' % format.short_name())
     
         for i in range(4):
             dst_channel = channels[i]
@@ -624,8 +623,16 @@ def generate_format_unpack(format, dst_channel, dst_native_type, dst_suffix):
 
     name = format.short_name()
 
-    print('static inline void')
-    print('util_format_%s_unpack_%s(%s *dst_row, unsigned dst_stride, const uint8_t *src_row, unsigned src_stride, unsigned width, unsigned height)' % (name, dst_suffix, dst_native_type))
+    if "8unorm" in dst_suffix:
+        dst_proto_type = dst_native_type
+    else:
+        dst_proto_type = 'void'
+
+    proto = 'util_format_%s_unpack_%s(%s *dst_row, unsigned dst_stride, const uint8_t *src_row, unsigned src_stride, unsigned width, unsigned height)' % (name, dst_suffix, dst_proto_type)
+    print('void %s;' % proto, file=sys.stdout2)
+
+    print('void')
+    print(proto)
     print('{')
 
     if is_format_supported(format):
@@ -641,7 +648,7 @@ def generate_format_unpack(format, dst_channel, dst_native_type, dst_suffix):
         print('         dst += 4;')
         print('      }')
         print('      src_row += src_stride;')
-        print('      dst_row += dst_stride/sizeof(*dst_row);')
+        print('      dst_row = (uint8_t *)dst_row + dst_stride;')
         print('   }')
 
     print('}')
@@ -653,9 +660,11 @@ def generate_format_pack(format, src_channel, src_native_type, src_suffix):
 
     name = format.short_name()
 
-    print('static inline void')
+    print('void')
     print('util_format_%s_pack_%s(uint8_t *dst_row, unsigned dst_stride, const %s *src_row, unsigned src_stride, unsigned width, unsigned height)' % (name, src_suffix, src_native_type))
     print('{')
+
+    print('void util_format_%s_pack_%s(uint8_t *dst_row, unsigned dst_stride, const %s *src_row, unsigned src_stride, unsigned width, unsigned height);' % (name, src_suffix, src_native_type), file=sys.stdout2)
     
     if is_format_supported(format):
         print('   unsigned x, y;')
@@ -677,14 +686,19 @@ def generate_format_pack(format, src_channel, src_native_type, src_suffix):
     print()
     
 
-def generate_format_fetch(format, dst_channel, dst_native_type, dst_suffix):
+def generate_format_fetch(format, dst_channel, dst_native_type):
     '''Generate the function to unpack pixels from a particular format'''
 
     name = format.short_name()
 
-    print('static inline void')
-    print('util_format_%s_fetch_%s(%s *dst, const uint8_t *src, UNUSED unsigned i, UNUSED unsigned j)' % (name, dst_suffix, dst_native_type))
+    proto = 'util_format_%s_fetch_rgba(void *in_dst, const uint8_t *src, UNUSED unsigned i, UNUSED unsigned j)' % (name)
+    print('void %s;' % proto, file=sys.stdout2)
+
+    print('void')
+    print(proto)
+
     print('{')
+    print('   %s *dst = in_dst;' % dst_native_type)
 
     if is_format_supported(format):
         generate_unpack_kernel(format, dst_channel, dst_native_type)
@@ -707,6 +721,7 @@ def generate(formats):
     print('#include "util/format_srgb.h"')
     print('#include "u_format_yuv.h"')
     print('#include "u_format_zs.h"')
+    print('#include "u_format_pack.h"')
     print()
 
     for format in formats:
@@ -722,12 +737,11 @@ def generate(formats):
 
                 generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)
-                generate_format_fetch(format, channel, native_type, suffix)
+                generate_format_fetch(format, channel, native_type)
 
                 channel = Channel(SIGNED, False, True, 32)
                 native_type = 'int'
                 suffix = 'signed'
-                generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)   
             elif format.is_pure_signed():
                 native_type = 'int'
@@ -736,12 +750,11 @@ def generate(formats):
 
                 generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)   
-                generate_format_fetch(format, channel, native_type, suffix)
+                generate_format_fetch(format, channel, native_type)
 
                 native_type = 'unsigned'
                 suffix = 'unsigned'
                 channel = Channel(UNSIGNED, False, True, 32)
-                generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)   
             else:
                 channel = Channel(FLOAT, False, False, 32)
@@ -750,7 +763,7 @@ def generate(formats):
 
                 generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)
-                generate_format_fetch(format, channel, native_type, suffix)
+                generate_format_fetch(format, channel, native_type)
 
                 channel = Channel(UNSIGNED, True, False, 8)
                 native_type = 'uint8_t'
@@ -758,4 +771,3 @@ def generate(formats):
 
                 generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)
-