i965/vec4: fix optimize predicate for doubles
[mesa.git] / src / mesa / drivers / dri / i965 / brw_vec4_nir.cpp
index 4aeb8cb48d3a4e434664c71c6b74803d9952a40e..8c9b98347bfa086f845b3c128af4b6ecdb506a42 100644 (file)
@@ -1003,8 +1003,10 @@ vec4_visitor::optimize_predicate(nir_alu_instr *instr,
    src_reg op[2];
    assert(nir_op_infos[cmp_instr->op].num_inputs == 2);
    for (unsigned i = 0; i < 2; i++) {
-      op[i] = get_nir_src(cmp_instr->src[i].src,
-                          nir_op_infos[cmp_instr->op].input_types[i], 4);
+      nir_alu_type type = nir_op_infos[cmp_instr->op].input_types[i];
+      unsigned bit_size = nir_src_bit_size(cmp_instr->src[i].src);
+      type = (nir_alu_type) (((unsigned) type) | bit_size);
+      op[i] = get_nir_src(cmp_instr->src[i].src, type, 4);
       unsigned base_swizzle =
          brw_swizzle_for_nir_swizzle(cmp_instr->src[i].swizzle);
       op[i].swizzle = brw_compose_swizzle(size_swizzle, base_swizzle);
@@ -1071,6 +1073,17 @@ vec4_visitor::emit_conversion_from_double(dst_reg dst, src_reg src,
                                           bool saturate,
                                           brw_reg_type single_type)
 {
+   /* BDW PRM vol 15 - workarounds:
+    * DF->f format conversion for Align16 has wrong emask calculation when
+    * source is immediate.
+    */
+   if (devinfo->gen == 8 && single_type == BRW_REGISTER_TYPE_F &&
+       src.file == BRW_IMMEDIATE_VALUE) {
+      vec4_instruction *inst = emit(MOV(dst, brw_imm_f(src.df)));
+      inst->saturate = saturate;
+      return;
+   }
+
    dst_reg temp = dst_reg(this, glsl_type::dvec4_type);
    emit(MOV(temp, src));
 
@@ -1150,6 +1163,20 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
                                 BRW_REGISTER_TYPE_F);
       break;
 
+   case nir_op_d2i:
+   case nir_op_d2u:
+      emit_conversion_from_double(dst, op[0], instr->dest.saturate,
+                                  instr->op == nir_op_d2i ? BRW_REGISTER_TYPE_D :
+                                                            BRW_REGISTER_TYPE_UD);
+      break;
+
+   case nir_op_i2d:
+   case nir_op_u2d:
+      emit_conversion_to_double(dst, op[0], instr->dest.saturate,
+                                instr->op == nir_op_i2d ? BRW_REGISTER_TYPE_D :
+                                                          BRW_REGISTER_TYPE_UD);
+      break;
+
    case nir_op_iadd:
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
    case nir_op_fadd:
@@ -1518,6 +1545,24 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       emit(CMP(dst, op[0], brw_imm_f(0.0f), BRW_CONDITIONAL_NZ));
       break;
 
+   case nir_op_d2b: {
+      /* We use a MOV with conditional_mod to check if the provided value is
+       * 0.0. We want this to flush denormalized numbers to zero, so we set a
+       * source modifier on the source operand to trigger this, as source
+       * modifiers don't affect the result of the testing against 0.0.
+       */
+      src_reg value = op[0];
+      value.abs = true;
+      vec4_instruction *inst = emit(MOV(dst_null_df(), value));
+      inst->conditional_mod = BRW_CONDITIONAL_NZ;
+
+      src_reg one = src_reg(this, glsl_type::ivec4_type);
+      emit(MOV(dst_reg(one), brw_imm_d(~0)));
+      inst = emit(BRW_OPCODE_SEL, dst, one, brw_imm_d(0));
+      inst->predicate = BRW_PREDICATE_NORMAL;
+      break;
+   }
+
    case nir_op_i2b:
       emit(CMP(dst, op[0], brw_imm_d(0), BRW_CONDITIONAL_NZ));
       break;
@@ -1726,24 +1771,58 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       unreachable("not reached: should have been lowered");
 
    case nir_op_fsign:
-      /* AND(val, 0x80000000) gives the sign bit.
-       *
-       * Predicated OR ORs 1.0 (0x3f800000) with the sign bit if val is not
-       * zero.
-       */
-      emit(CMP(dst_null_f(), op[0], brw_imm_f(0.0f), BRW_CONDITIONAL_NZ));
+      if (type_sz(op[0].type) < 8) {
+         /* AND(val, 0x80000000) gives the sign bit.
+          *
+          * Predicated OR ORs 1.0 (0x3f800000) with the sign bit if val is not
+          * zero.
+          */
+         emit(CMP(dst_null_f(), op[0], brw_imm_f(0.0f), BRW_CONDITIONAL_NZ));
 
-      op[0].type = BRW_REGISTER_TYPE_UD;
-      dst.type = BRW_REGISTER_TYPE_UD;
-      emit(AND(dst, op[0], brw_imm_ud(0x80000000u)));
+         op[0].type = BRW_REGISTER_TYPE_UD;
+         dst.type = BRW_REGISTER_TYPE_UD;
+         emit(AND(dst, op[0], brw_imm_ud(0x80000000u)));
 
-      inst = emit(OR(dst, src_reg(dst), brw_imm_ud(0x3f800000u)));
-      inst->predicate = BRW_PREDICATE_NORMAL;
-      dst.type = BRW_REGISTER_TYPE_F;
+         inst = emit(OR(dst, src_reg(dst), brw_imm_ud(0x3f800000u)));
+         inst->predicate = BRW_PREDICATE_NORMAL;
+         dst.type = BRW_REGISTER_TYPE_F;
+
+         if (instr->dest.saturate) {
+            inst = emit(MOV(dst, src_reg(dst)));
+            inst->saturate = true;
+         }
+      } else {
+         /* For doubles we do the same but we need to consider:
+          *
+          * - We use a MOV with conditional_mod instead of a CMP so that we can
+          *   skip loading a 0.0 immediate. We use a source modifier on the
+          *   source of the MOV so that we flush denormalized values to 0.
+          *   Since we want to compare against 0, this won't alter the result.
+          * - We need to extract the high 32-bit of each DF where the sign
+          *   is stored.
+          * - We need to produce a DF result.
+          */
+
+         /* Check for zero */
+         src_reg value = op[0];
+         value.abs = true;
+         inst = emit(MOV(dst_null_df(), value));
+         inst->conditional_mod = BRW_CONDITIONAL_NZ;
+
+         /* AND each high 32-bit channel with 0x80000000u */
+         dst_reg tmp = dst_reg(this, glsl_type::uvec4_type);
+         emit(VEC4_OPCODE_PICK_HIGH_32BIT, tmp, op[0]);
+         emit(AND(tmp, src_reg(tmp), brw_imm_ud(0x80000000u)));
+
+         /* Add 1.0 to each channel, predicated to skip the cases where the
+          * channel's value was 0
+          */
+         inst = emit(OR(tmp, src_reg(tmp), brw_imm_ud(0x3f800000u)));
+         inst->predicate = BRW_PREDICATE_NORMAL;
 
-      if (instr->dest.saturate) {
-         inst = emit(MOV(dst, src_reg(dst)));
-         inst->saturate = true;
+         /* Now convert the result from float to double */
+         emit_conversion_to_double(dst, src_reg(tmp), instr->dest.saturate,
+                                   BRW_REGISTER_TYPE_F);
       }
       break;