util/u_format: add _is_alpha()
[mesa.git] / src / gallium / auxiliary / util / u_format_pack.py
index 8072fdb138bbd63493105dcddec34658025979db..6ccf04c29e7754a778d3ed3193c4a3ca91c6bc7b 100644 (file)
 from u_format_parse import *
 
 
+def inv_swizzles(swizzles):
+    '''Return an array[4] of inverse swizzle terms'''
+    '''Only pick the first matching value to avoid l8 getting blue and i8 getting alpha'''
+    inv_swizzle = [None]*4
+    for i in range(4):
+        swizzle = swizzles[i]
+        if swizzle < 4 and inv_swizzle[swizzle] == None:
+            inv_swizzle[swizzle] = i
+    return inv_swizzle
+
+def print_channels(format, func):
+    if format.nr_channels() <= 1:
+        func(format.le_channels, format.le_swizzles)
+    else:
+        print '#ifdef PIPE_ARCH_BIG_ENDIAN'
+        func(format.be_channels, format.be_swizzles)
+        print '#else'
+        func(format.le_channels, format.le_swizzles)
+        print '#endif'
+
 def generate_format_type(format):
     '''Generate a structure that describes the format.'''
 
     assert format.layout == PLAIN
     
-    print 'union util_format_%s {' % format.short_name()
-    
-    if format.block_size() in (8, 16, 32, 64):
-        print '   uint%u_t value;' % (format.block_size(),)
-
-    use_bitfields = False
-    for channel in format.channels:
-        if channel.size % 8 or not is_pot(channel.size):
-            use_bitfields = True
-
-    print '   struct {'
-    for channel in format.channels:
-        if use_bitfields:
+    def generate_bitfields(channels, swizzles):
+        for channel in channels:
             if channel.type == VOID:
                 if channel.size:
                     print '      unsigned %s:%u;' % (channel.name, channel.size)
@@ -74,7 +83,9 @@ def generate_format_type(format):
                     print '      unsigned %s:%u;' % (channel.name, channel.size)
             else:
                 assert 0
-        else:
+
+    def generate_full_fields(channels, swizzles):
+        for channel in channels:
             assert channel.size % 8 == 0 and is_pot(channel.size)
             if channel.type == VOID:
                 if channel.size:
@@ -94,6 +105,22 @@ def generate_format_type(format):
                     assert 0
             else:
                 assert 0
+
+    print 'union util_format_%s {' % format.short_name()
+    
+    if format.block_size() in (8, 16, 32, 64):
+        print '   uint%u_t value;' % (format.block_size(),)
+
+    use_bitfields = False
+    for channel in format.le_channels:
+        if channel.size % 8 or not is_pot(channel.size):
+            use_bitfields = True
+
+    print '   struct {'
+    if use_bitfields:
+        print_channels(format, generate_bitfields)
+    else:
+        print_channels(format, generate_full_fields)
     print '   } chan;'
     print '};'
     print
@@ -109,7 +136,7 @@ def is_format_supported(format):
         return False
 
     for i in range(4):
-        channel = format.channels[i]
+        channel = format.le_channels[i]
         if channel.type not in (VOID, UNSIGNED, SIGNED, FLOAT, FIXED):
             return False
         if channel.type == FLOAT and channel.size not in (16, 32, 64):
@@ -117,27 +144,6 @@ def is_format_supported(format):
 
     return True
 
-def is_format_pure_unsigned(format):
-    for i in range(4):
-        channel = format.channels[i]
-        if channel.type not in (VOID, UNSIGNED):
-            return False
-        if channel.type == UNSIGNED and channel.pure == False:
-            return False
-
-    return True
-
-
-def is_format_pure_signed(format):
-    for i in range(4):
-        channel = format.channels[i]
-        if channel.type not in (VOID, SIGNED):
-            return False
-        if channel.type == SIGNED and channel.pure == False:
-            return False
-
-    return True
-
 def native_type(format):
     '''Get the native appropriate for a format.'''
 
@@ -152,7 +158,7 @@ def native_type(format):
             return 'uint%u_t' % format.block_size()
         else:
             # For array pixel formats return the integer type that matches the color channel
-            channel = format.channels[0]
+            channel = format.array_element()
             if channel.type in (UNSIGNED, VOID):
                 return 'uint%u_t' % channel.size
             elif channel.type in (SIGNED, FIXED):
@@ -410,13 +416,13 @@ def generate_unpack_kernel(format, dst_channel, dst_native_type):
 
     src_native_type = native_type(format)
 
-    if format.is_bitmask():
+    def unpack_from_bitmask(channels, swizzles):
         depth = format.block_size()
         print '         uint%u_t value = *(const uint%u_t *)src;' % (depth, depth) 
 
         # Declare the intermediate variables
         for i in range(format.nr_channels()):
-            src_channel = format.channels[i]
+            src_channel = channels[i]
             if src_channel.type == UNSIGNED:
                 print '         uint%u_t %s;' % (depth, src_channel.name)
             elif src_channel.type == SIGNED:
@@ -424,7 +430,7 @@ def generate_unpack_kernel(format, dst_channel, dst_native_type):
 
         # Compute the intermediate unshifted values 
         for i in range(format.nr_channels()):
-            src_channel = format.channels[i]
+            src_channel = channels[i]
             value = 'value'
             shift = src_channel.shift
             if src_channel.type == UNSIGNED:
@@ -451,9 +457,9 @@ def generate_unpack_kernel(format, dst_channel, dst_native_type):
                 
         # Convert, swizzle, and store final values
         for i in range(4):
-            swizzle = format.swizzles[i]
+            swizzle = swizzles[i]
             if swizzle < 4:
-                src_channel = format.channels[swizzle]
+                src_channel = channels[swizzle]
                 src_colorspace = format.colorspace
                 if src_colorspace == SRGB and i == 3:
                     # Alpha channel is linear
@@ -473,14 +479,14 @@ def generate_unpack_kernel(format, dst_channel, dst_native_type):
                 assert False
             print '         dst[%u] = %s; /* %s */' % (i, value, 'rgba'[i])
         
-    else:
+    def unpack_from_union(channels, swizzles):
         print '         union util_format_%s pixel;' % format.short_name()
         print '         memcpy(&pixel, src, sizeof pixel);'
     
         for i in range(4):
-            swizzle = format.swizzles[i]
+            swizzle = swizzles[i]
             if swizzle < 4:
-                src_channel = format.channels[swizzle]
+                src_channel = channels[swizzle]
                 src_colorspace = format.colorspace
                 if src_colorspace == SRGB and i == 3:
                     # Alpha channel is linear
@@ -500,6 +506,11 @@ def generate_unpack_kernel(format, dst_channel, dst_native_type):
                 assert False
             print '         dst[%u] = %s; /* %s */' % (i, value, 'rgba'[i])
     
+    if format.is_bitmask():
+        print_channels(format, unpack_from_bitmask)
+    else:
+        print_channels(format, unpack_from_union)
+
 
 def generate_pack_kernel(format, src_channel, src_native_type):
 
@@ -510,14 +521,14 @@ def generate_pack_kernel(format, src_channel, src_native_type):
 
     assert format.layout == PLAIN
 
-    inv_swizzle = format.inv_swizzles()
+    def pack_into_bitmask(channels, swizzles):
+        inv_swizzle = inv_swizzles(swizzles)
 
-    if format.is_bitmask():
         depth = format.block_size()
         print '         uint%u_t value = 0;' % depth 
 
         for i in range(4):
-            dst_channel = format.channels[i]
+            dst_channel = channels[i]
             shift = dst_channel.shift
             if inv_swizzle[i] is not None:
                 value ='src[%u]' % inv_swizzle[i]
@@ -544,11 +555,13 @@ def generate_pack_kernel(format, src_channel, src_native_type):
                 
         print '         *(uint%u_t *)dst = value;' % depth 
 
-    else:
+    def pack_into_union(channels, swizzles):
+        inv_swizzle = inv_swizzles(swizzles)
+
         print '         union util_format_%s pixel;' % format.short_name()
     
         for i in range(4):
-            dst_channel = format.channels[i]
+            dst_channel = channels[i]
             width = dst_channel.size
             if inv_swizzle[i] is None:
                 continue
@@ -565,6 +578,11 @@ def generate_pack_kernel(format, src_channel, src_native_type):
     
         print '         memcpy(dst, &pixel, sizeof pixel);'
     
+    if format.is_bitmask():
+        print_channels(format, pack_into_bitmask)
+    else:
+        print_channels(format, pack_into_union)
+
 
 def generate_format_unpack(format, dst_channel, dst_native_type, dst_suffix):
     '''Generate the function to unpack pixels from a particular format'''
@@ -641,7 +659,7 @@ def generate_format_fetch(format, dst_channel, dst_native_type, dst_suffix):
 
 
 def is_format_hand_written(format):
-    return format.layout in ('s3tc', 'rgtc', 'etc', 'subsampled', 'other') or format.colorspace == ZS
+    return format.layout in ('s3tc', 'rgtc', 'etc', 'bptc', 'subsampled', 'other') or format.colorspace == ZS
 
 
 def generate(formats):
@@ -651,7 +669,7 @@ def generate(formats):
     print '#include "u_half.h"'
     print '#include "u_format.h"'
     print '#include "u_format_other.h"'
-    print '#include "u_format_srgb.h"'
+    print '#include "util/format_srgb.h"'
     print '#include "u_format_yuv.h"'
     print '#include "u_format_zs.h"'
     print
@@ -662,7 +680,7 @@ def generate(formats):
             if is_format_supported(format):
                 generate_format_type(format)
 
-            if is_format_pure_unsigned(format):
+            if format.is_pure_unsigned():
                 native_type = 'unsigned'
                 suffix = 'unsigned'
                 channel = Channel(UNSIGNED, False, True, 32)
@@ -676,7 +694,7 @@ def generate(formats):
                 suffix = 'signed'
                 generate_format_unpack(format, channel, native_type, suffix)
                 generate_format_pack(format, channel, native_type, suffix)   
-            elif is_format_pure_signed(format):
+            elif format.is_pure_signed():
                 native_type = 'int'
                 suffix = 'signed'
                 channel = Channel(SIGNED, False, True, 32)