util: Generate correct format conversions for half floats.
authorMichal Krol <michal@vmware.com>
Thu, 1 Apr 2010 11:56:03 +0000 (13:56 +0200)
committerMichal Krol <michal@vmware.com>
Thu, 1 Apr 2010 11:56:03 +0000 (13:56 +0200)
src/gallium/auxiliary/util/u_format_pack.py

index d36c6377380955fcbb451771db74f7790ffc780a..390be0c3b2003f60d74a3e785778d51e0b5ac54d 100644 (file)
@@ -231,66 +231,86 @@ def conversion_expr(src_channel, dst_channel, dst_native_type, value, clamp=True
     if src_channel == dst_channel:
         return value
 
-    if src_channel.type == FLOAT and dst_channel.type == FLOAT:
-        if src_channel.size == 64:
-            value = '(float)%s' % (value)
-        elif src_channel.size == 16:
-            value = 'util_half_to_float(%s)' % (value)
+    src_type = src_channel.type
+    src_size = src_channel.size
+    src_norm = src_channel.norm
 
-        if dst_channel.size == 16:
-            value = 'util_float_to_half(%s)' % (value)
-        elif dst_channel.size == 64:
-            value = '(double)%s' % (value)
-
-        return value
+    # Promote half to float
+    if src_type == FLOAT and src_size == 16:
+        value = 'util_half_to_float(%s)' % value
+        src_size = 32
 
     if clamp:
-        value = clamp_expr(src_channel, dst_channel, dst_native_type, value)
-
-    if dst_channel.type == FLOAT:
-        if src_channel.norm:
-            one = get_one(src_channel)
-            if src_channel.size <= 23:
-                scale = '(1.0f/0x%x)' % one
-            else:
-                # bigger than single precision mantissa, use double
-                scale = '(1.0/0x%x)' % one
-            value = '(%s * %s)' % (value, scale)
-        return '(%s)%s' % (dst_native_type, value)
-
-    if src_channel.type == FLOAT:
-        if dst_channel.norm:
-            dst_one = get_one(dst_channel)
-            if dst_channel.size <= 23:
-                scale = '0x%x' % dst_one
-            else:
-                # bigger than single precision mantissa, use double
-                scale = '(double)0x%x' % dst_one
-            value = '(%s * %s)' % (value, scale)
-        return '(%s)%s' % (dst_native_type, value)
+        if dst_channel.type != FLOAT or src_type != FLOAT:
+            value = clamp_expr(src_channel, dst_channel, dst_native_type, value)
 
-    if src_channel.type in (SIGNED, UNSIGNED) and dst_channel.type in (SIGNED, UNSIGNED):
-        if not src_channel.norm and not dst_channel.norm:
+    if src_type in (SIGNED, UNSIGNED) and dst_channel.type in (SIGNED, UNSIGNED):
+        if not src_norm and not dst_channel.norm:
             # neither is normalized -- just cast
             return '(%s)%s' % (dst_native_type, value)
 
         src_one = get_one(src_channel)
         dst_one = get_one(dst_channel)
 
-        if src_one > dst_one and src_channel.norm and dst_channel.norm:
+        if src_one > dst_one and src_norm and dst_channel.norm:
             # We can just bitshift
             src_shift = get_one_shift(src_channel)
             dst_shift = get_one_shift(dst_channel)
             value = '(%s >> %s)' % (value, src_shift - dst_shift)
         else:
             # We need to rescale using an intermediate type big enough to hold the multiplication of both
-            tmp_native_type = intermediate_native_type(src_channel.size + dst_channel.size, src_channel.sign and dst_channel.sign)
+            tmp_native_type = intermediate_native_type(src_size + dst_channel.size, src_channel.sign and dst_channel.sign)
             value = '((%s)%s)' % (tmp_native_type, value)
             value = '(%s * 0x%x / 0x%x)' % (value, dst_one, src_one)
         value = '(%s)%s' % (dst_native_type, value)
         return value
 
-    assert False
+    # Promote to either float or double
+    if src_type != FLOAT:
+        if src_norm:
+            one = get_one(src_channel)
+            if src_size <= 23:
+                value = '(%s * (1.0f/0x%x))' % (value, one)
+                if dst_channel.size <= 32:
+                    value = '(float)%s' % value
+                src_size = 32
+            else:
+                # bigger than single precision mantissa, use double
+                value = '(%s * (1.0/0x%x))' % (value, one)
+                src_size = 64
+            src_norm = False
+        else:
+            if src_size <= 23 or dst_channel.size <= 32:
+                value = '(float)%s' % value
+                src_size = 32
+            else:
+                # bigger than single precision mantissa, use double
+                value = '(double)%s' % value
+                src_size = 64
+        src_type = FLOAT
+
+    # Convert double or float to non-float
+    if dst_channel.type != FLOAT:
+        if dst_channel.norm:
+            dst_one = get_one(dst_channel)
+            if dst_channel.size <= 23:
+                value = '(%s * 0x%x)' % (value, dst_one)
+            else:
+                # bigger than single precision mantissa, use double
+                value = '(%s * (double)0x%x)' % (value, dst_one)
+        value = '(%s)%s' % (dst_native_type, value)
+    else:
+        # Cast double to float when converting to either half or float
+        if dst_channel.size <= 32 and src_size > 32:
+            value = '(float)%s' % value
+            src_size = 32
+
+        if dst_channel.size == 16:
+            value = 'util_float_to_half(%s)' % value
+        elif dst_channel.size == 64 and src_size < 64:
+            value = '(double)%s' % value
+
+    return value
 
 
 def generate_unpack_kernel(format, dst_channel, dst_native_type):