i965/fs: Fix extract_i8/u8 to a 64-bit destination
[mesa.git] / src / intel / compiler / brw_fs_nir.cpp
index 201a23f83bda790736515d2e0beb1a92340e026c..ca3707db90f76ed4b955d10d96d162a486c4e5a0 100644 (file)
@@ -81,10 +81,23 @@ fs_visitor::nir_setup_outputs()
 void
 fs_visitor::nir_setup_uniforms()
 {
-   if (dispatch_width != min_dispatch_width)
+   /* Only the first compile gets to set up uniforms. */
+   if (push_constant_loc) {
+      assert(pull_constant_loc);
       return;
+   }
 
    uniforms = nir->num_uniforms / 4;
+
+   if (stage == MESA_SHADER_COMPUTE) {
+      /* Add a uniform for the thread local id.  It must be the last uniform
+       * on the list.
+       */
+      assert(uniforms == prog_data->nr_params);
+      uint32_t *param = brw_stage_prog_data_add_params(prog_data, 1);
+      *param = BRW_PARAM_BUILTIN_SUBGROUP_ID;
+      subgroup_id = fs_reg(UNIFORM, uniforms++, BRW_REGISTER_TYPE_UD);
+   }
 }
 
 static bool
@@ -218,6 +231,24 @@ fs_visitor::nir_emit_system_values()
       nir_system_values[i] = fs_reg();
    }
 
+   /* Always emit SUBGROUP_INVOCATION.  Dead code will clean it up if we
+    * never end up using it.
+    */
+   {
+      const fs_builder abld = bld.annotate("gl_SubgroupInvocation", NULL);
+      fs_reg &reg = nir_system_values[SYSTEM_VALUE_SUBGROUP_INVOCATION];
+      reg = abld.vgrf(BRW_REGISTER_TYPE_W);
+
+      const fs_builder allbld8 = abld.group(8, 0).exec_all();
+      allbld8.MOV(reg, brw_imm_v(0x76543210));
+      if (dispatch_width > 8)
+         allbld8.ADD(byte_offset(reg, 16), reg, brw_imm_uw(8u));
+      if (dispatch_width > 16) {
+         const fs_builder allbld16 = abld.group(16, 0).exec_all();
+         allbld16.ADD(byte_offset(reg, 32), reg, brw_imm_uw(16u));
+      }
+   }
+
    nir_foreach_function(function, nir) {
       assert(strcmp(function->name, "main") == 0);
       assert(function->impl);
@@ -264,7 +295,7 @@ brw_reg_type_from_bit_size(const unsigned bit_size,
       case 32:
          return BRW_REGISTER_TYPE_D;
       case 64:
-         return BRW_REGISTER_TYPE_DF;
+         return BRW_REGISTER_TYPE_Q;
       default:
          unreachable("Invalid bit size");
       }
@@ -277,7 +308,7 @@ brw_reg_type_from_bit_size(const unsigned bit_size,
       case 32:
          return BRW_REGISTER_TYPE_UD;
       case 64:
-         return BRW_REGISTER_TYPE_DF;
+         return BRW_REGISTER_TYPE_UQ;
       default:
          unreachable("Invalid bit size");
       }
@@ -694,8 +725,12 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
       break;
 
    case nir_op_f2f64:
+   case nir_op_f2i64:
+   case nir_op_f2u64:
    case nir_op_i2f64:
+   case nir_op_i2i64:
    case nir_op_u2f64:
+   case nir_op_u2u64:
       /* CHV PRM, vol07, 3D Media GPGPU Engine, Register Region Restrictions:
        *
        *    "When source or destination is 64b (...), regioning in Align1
@@ -723,12 +758,8 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
    case nir_op_f2f32:
    case nir_op_f2i32:
    case nir_op_f2u32:
-   case nir_op_f2i64:
-   case nir_op_f2u64:
    case nir_op_i2i32:
-   case nir_op_i2i64:
    case nir_op_u2u32:
-   case nir_op_u2u64:
       inst = bld.MOV(result, op[0]);
       inst->saturate = instr->dest.saturate;
       break;
@@ -1094,12 +1125,13 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
          if (instr->op == nir_op_f2b) {
             zero = vgrf(glsl_type::double_type);
             tmp = vgrf(glsl_type::double_type);
+            bld.MOV(zero, setup_imm_df(bld, 0.0));
          } else {
             zero = vgrf(glsl_type::int64_t_type);
             tmp = vgrf(glsl_type::int64_t_type);
+            bld.MOV(zero, brw_imm_q(0));
          }
 
-         bld.MOV(zero, setup_imm_df(bld, 0.0));
          /* A SIMD16 execution needs to be split in two instructions, so use
           * a vgrf instead of the flag register as dst so instruction splitting
           * works
@@ -1363,10 +1395,31 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
 
    case nir_op_extract_u8:
    case nir_op_extract_i8: {
-      const brw_reg_type type = brw_int_type(1, instr->op == nir_op_extract_i8);
       nir_const_value *byte = nir_src_as_const_value(instr->src[1].src);
       assert(byte != NULL);
-      bld.MOV(result, subscript(op[0], type, byte->u32[0]));
+
+      /* The PRMs say:
+       *
+       *    BDW+
+       *    There is no direct conversion from B/UB to Q/UQ or Q/UQ to B/UB.
+       *    Use two instructions and a word or DWord intermediate integer type.
+       */
+      if (nir_dest_bit_size(instr->dest.dest) == 64) {
+         const brw_reg_type type = brw_int_type(2, instr->op == nir_op_extract_i8);
+
+         if (instr->op == nir_op_extract_i8) {
+            /* If we need to sign extend, extract to a word first */
+            fs_reg w_temp = bld.vgrf(BRW_REGISTER_TYPE_W);
+            bld.MOV(w_temp, subscript(op[0], type, byte->u32[0]));
+            bld.MOV(result, w_temp);
+         } else {
+            /* Otherwise use an AND with 0xff and a word type */
+            bld.AND(result, subscript(op[0], type, byte->u32[0] / 2), brw_imm_uw(0xff));
+         }
+      } else {
+         const brw_reg_type type = brw_int_type(1, instr->op == nir_op_extract_i8);
+         bld.MOV(result, subscript(op[0], type, byte->u32[0]));
+      }
       break;
    }
 
@@ -1410,9 +1463,17 @@ fs_visitor::nir_emit_load_const(const fs_builder &bld,
       break;
 
    case 64:
-      for (unsigned i = 0; i < instr->def.num_components; i++)
-         bld.MOV(offset(reg, bld, i),
-                 setup_imm_df(bld, instr->value.f64[i]));
+      assert(devinfo->gen >= 7);
+      if (devinfo->gen == 7) {
+         /* We don't get 64-bit integer types until gen8 */
+         for (unsigned i = 0; i < instr->def.num_components; i++) {
+            bld.MOV(retype(offset(reg, bld, i), BRW_REGISTER_TYPE_DF),
+                    setup_imm_df(bld, instr->value.f64[i]));
+         }
+      } else {
+         for (unsigned i = 0; i < instr->def.num_components; i++)
+            bld.MOV(offset(reg, bld, i), brw_imm_q(instr->value.i64[i]));
+      }
       break;
 
    default:
@@ -3409,6 +3470,10 @@ fs_visitor::nir_emit_cs_intrinsic(const fs_builder &bld,
       cs_prog_data->uses_barrier = true;
       break;
 
+   case nir_intrinsic_load_subgroup_id:
+      bld.MOV(retype(dest, BRW_REGISTER_TYPE_UD), subgroup_id);
+      break;
+
    case nir_intrinsic_load_local_invocation_id:
    case nir_intrinsic_load_work_group_id: {
       gl_system_value sv = nir_system_value_from_intrinsic(instr->intrinsic);
@@ -4149,24 +4214,10 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
       break;
    }
 
-   case nir_intrinsic_load_subgroup_size:
-      bld.MOV(retype(dest, BRW_REGISTER_TYPE_D), brw_imm_d(dispatch_width));
-      break;
-
-   case nir_intrinsic_load_subgroup_invocation: {
-      fs_reg tmp = bld.vgrf(BRW_REGISTER_TYPE_UW);
-      dest = retype(dest, BRW_REGISTER_TYPE_UD);
-      const fs_builder allbld8 = bld.group(8, 0).exec_all();
-      allbld8.MOV(tmp, brw_imm_v(0x76543210));
-      if (dispatch_width > 8)
-         allbld8.ADD(byte_offset(tmp, 16), tmp, brw_imm_uw(8u));
-      if (dispatch_width > 16) {
-         const fs_builder allbld16 = bld.group(16, 0).exec_all();
-         allbld16.ADD(byte_offset(tmp, 32), tmp, brw_imm_uw(16u));
-      }
-      bld.MOV(dest, tmp);
+   case nir_intrinsic_load_subgroup_invocation:
+      bld.MOV(retype(dest, BRW_REGISTER_TYPE_D),
+              nir_system_values[SYSTEM_VALUE_SUBGROUP_INVOCATION]);
       break;
-   }
 
    case nir_intrinsic_load_subgroup_eq_mask:
    case nir_intrinsic_load_subgroup_ge_mask:
@@ -4176,7 +4227,7 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
       unreachable("not reached");
 
    case nir_intrinsic_vote_any: {
-      const fs_builder ubld = bld.exec_all();
+      const fs_builder ubld = bld.exec_all().group(1, 0);
 
       /* The any/all predicates do not consider channel enables. To prevent
        * dead channels from affecting the result, we initialize the flag with
@@ -4208,7 +4259,7 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
       break;
    }
    case nir_intrinsic_vote_all: {
-      const fs_builder ubld = bld.exec_all();
+      const fs_builder ubld = bld.exec_all().group(1, 0);
 
       /* The any/all predicates do not consider channel enables. To prevent
        * dead channels from affecting the result, we initialize the flag with
@@ -4242,7 +4293,7 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
    case nir_intrinsic_vote_eq: {
       fs_reg value = get_nir_src(instr->src[0]);
       fs_reg uniformized = bld.emit_uniformize(value);
-      const fs_builder ubld = bld.exec_all();
+      const fs_builder ubld = bld.exec_all().group(1, 0);
 
       /* The any/all predicates do not consider channel enables. To prevent
        * dead channels from affecting the result, we initialize the flag with
@@ -4287,7 +4338,7 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
       if (dispatch_width == 32)
          flag.type = BRW_REGISTER_TYPE_UD;
 
-      bld.exec_all().MOV(flag, brw_imm_ud(0u));
+      bld.exec_all().group(1, 0).MOV(flag, brw_imm_ud(0u));
       bld.CMP(bld.null_reg_ud(), value, brw_imm_ud(0u), BRW_CONDITIONAL_NZ);
 
       if (instr->dest.ssa.bit_size > 32) {
@@ -4307,15 +4358,13 @@ fs_visitor::nir_emit_intrinsic(const fs_builder &bld, nir_intrinsic_instr *instr
       bld.exec_all().emit(SHADER_OPCODE_BROADCAST, tmp, value,
                           bld.emit_uniformize(invocation));
 
-      bld.MOV(retype(dest, BRW_REGISTER_TYPE_D),
-              fs_reg(component(tmp, 0)));
+      bld.MOV(retype(dest, value.type), fs_reg(component(tmp, 0)));
       break;
    }
 
    case nir_intrinsic_read_first_invocation: {
       const fs_reg value = get_nir_src(instr->src[0]);
-      bld.MOV(retype(dest, BRW_REGISTER_TYPE_D),
-              bld.emit_uniformize(value));
+      bld.MOV(retype(dest, value.type), bld.emit_uniformize(value));
       break;
    }