intel/vec4: Try to emit a VF source in try_immediate_source
[mesa.git] / src / intel / compiler / brw_fs_visitor.cpp
index 481d9c51e7a700c9dce84442d985c29107694115..eb23b4087d06af3a597f53b9450ae712cd50a1c4 100644 (file)
@@ -35,14 +35,16 @@ using namespace brw;
 /* Sample from the MCS surface attached to this multisample texture. */
 fs_reg
 fs_visitor::emit_mcs_fetch(const fs_reg &coordinate, unsigned components,
-                           const fs_reg &texture)
+                           const fs_reg &texture,
+                           const fs_reg &texture_handle)
 {
    const fs_reg dest = vgrf(glsl_type::uvec4_type);
 
    fs_reg srcs[TEX_LOGICAL_NUM_SRCS];
    srcs[TEX_LOGICAL_SRC_COORDINATE] = coordinate;
    srcs[TEX_LOGICAL_SRC_SURFACE] = texture;
-   srcs[TEX_LOGICAL_SRC_SAMPLER] = texture;
+   srcs[TEX_LOGICAL_SRC_SAMPLER] = brw_imm_ud(0);
+   srcs[TEX_LOGICAL_SRC_SURFACE_HANDLE] = texture_handle;
    srcs[TEX_LOGICAL_SRC_COORD_COMPONENTS] = brw_imm_d(components);
    srcs[TEX_LOGICAL_SRC_GRAD_COMPONENTS] = brw_imm_d(0);
 
@@ -103,6 +105,7 @@ fs_visitor::emit_dummy_fs()
    fs_inst *write;
    write = bld.emit(FS_OPCODE_FB_WRITE);
    write->eot = true;
+   write->last_rt = true;
    if (devinfo->gen >= 6) {
       write->base_mrf = 2;
       write->mlen = 4 * reg_width;
@@ -125,7 +128,8 @@ fs_visitor::emit_dummy_fs()
    stage_prog_data->nr_pull_params = 0;
    stage_prog_data->curb_read_length = 0;
    stage_prog_data->dispatch_grf_start_reg = 2;
-   wm_prog_data->dispatch_grf_start_reg_2 = 2;
+   wm_prog_data->dispatch_grf_start_reg_16 = 2;
+   wm_prog_data->dispatch_grf_start_reg_32 = 2;
    grf_used = 1; /* Gen4-5 don't allow zero GRF blocks */
 
    calculate_cfg();
@@ -135,17 +139,15 @@ fs_visitor::emit_dummy_fs()
  * data.  It will get adjusted to be a real location before
  * generate_code() time.
  */
-struct brw_reg
+fs_reg
 fs_visitor::interp_reg(int location, int channel)
 {
    assert(stage == MESA_SHADER_FRAGMENT);
    struct brw_wm_prog_data *prog_data = brw_wm_prog_data(this->prog_data);
-   int regnr = prog_data->urb_setup[location] * 2 + channel / 2;
-   int stride = (channel & 1) * 4;
-
+   int regnr = prog_data->urb_setup[location] * 4 + channel;
    assert(prog_data->urb_setup[location] != -1);
 
-   return brw_vec1_grf(regnr, stride);
+   return fs_reg(ATTR, regnr, BRW_REGISTER_TYPE_F);
 }
 
 /** Emits the interpolation for the varying inputs. */
@@ -202,101 +204,106 @@ fs_visitor::emit_interpolation_setup_gen4()
 void
 fs_visitor::emit_interpolation_setup_gen6()
 {
-   struct brw_reg g1_uw = retype(brw_vec1_grf(1, 0), BRW_REGISTER_TYPE_UW);
-
    fs_builder abld = bld.annotate("compute pixel centers");
-   if (devinfo->gen >= 8 || dispatch_width == 8) {
-      /* The "Register Region Restrictions" page says for BDW (and newer,
-       * presumably):
-       *
-       *     "When destination spans two registers, the source may be one or
-       *      two registers. The destination elements must be evenly split
-       *      between the two registers."
-       *
-       * Thus we can do a single add(16) in SIMD8 or an add(32) in SIMD16 to
-       * compute our pixel centers.
-       */
-      fs_reg int_pixel_xy(VGRF, alloc.allocate(dispatch_width / 8),
-                          BRW_REGISTER_TYPE_UW);
-
-      const fs_builder dbld = abld.exec_all().group(dispatch_width * 2, 0);
-      dbld.ADD(int_pixel_xy,
-               fs_reg(stride(suboffset(g1_uw, 4), 1, 4, 0)),
-               fs_reg(brw_imm_v(0x11001010)));
-
-      this->pixel_x = vgrf(glsl_type::float_type);
-      this->pixel_y = vgrf(glsl_type::float_type);
-      abld.emit(FS_OPCODE_PIXEL_X, this->pixel_x, int_pixel_xy);
-      abld.emit(FS_OPCODE_PIXEL_Y, this->pixel_y, int_pixel_xy);
-   } else {
-      /* The "Register Region Restrictions" page says for SNB, IVB, HSW:
-       *
-       *     "When destination spans two registers, the source MUST span two
-       *      registers."
-       *
-       * Since the GRF source of the ADD will only read a single register, we
-       * must do two separate ADDs in SIMD16.
-       */
-      fs_reg int_pixel_x = vgrf(glsl_type::uint_type);
-      fs_reg int_pixel_y = vgrf(glsl_type::uint_type);
-      int_pixel_x.type = BRW_REGISTER_TYPE_UW;
-      int_pixel_y.type = BRW_REGISTER_TYPE_UW;
-      abld.ADD(int_pixel_x,
-               fs_reg(stride(suboffset(g1_uw, 4), 2, 4, 0)),
-               fs_reg(brw_imm_v(0x10101010)));
-      abld.ADD(int_pixel_y,
-               fs_reg(stride(suboffset(g1_uw, 5), 2, 4, 0)),
-               fs_reg(brw_imm_v(0x11001100)));
-
-      /* As of gen6, we can no longer mix float and int sources.  We have
-       * to turn the integer pixel centers into floats for their actual
-       * use.
-       */
-      this->pixel_x = vgrf(glsl_type::float_type);
-      this->pixel_y = vgrf(glsl_type::float_type);
-      abld.MOV(this->pixel_x, int_pixel_x);
-      abld.MOV(this->pixel_y, int_pixel_y);
+
+   this->pixel_x = vgrf(glsl_type::float_type);
+   this->pixel_y = vgrf(glsl_type::float_type);
+
+   for (unsigned i = 0; i < DIV_ROUND_UP(dispatch_width, 16); i++) {
+      const fs_builder hbld = abld.group(MIN2(16, dispatch_width), i);
+      struct brw_reg gi_uw = retype(brw_vec1_grf(1 + i, 0), BRW_REGISTER_TYPE_UW);
+
+      if (devinfo->gen >= 8 || dispatch_width == 8) {
+         /* The "Register Region Restrictions" page says for BDW (and newer,
+          * presumably):
+          *
+          *     "When destination spans two registers, the source may be one or
+          *      two registers. The destination elements must be evenly split
+          *      between the two registers."
+          *
+          * Thus we can do a single add(16) in SIMD8 or an add(32) in SIMD16
+          * to compute our pixel centers.
+          */
+         const fs_builder dbld =
+            abld.exec_all().group(hbld.dispatch_width() * 2, 0);
+         fs_reg int_pixel_xy = dbld.vgrf(BRW_REGISTER_TYPE_UW);
+
+         dbld.ADD(int_pixel_xy,
+                  fs_reg(stride(suboffset(gi_uw, 4), 1, 4, 0)),
+                  fs_reg(brw_imm_v(0x11001010)));
+
+         hbld.emit(FS_OPCODE_PIXEL_X, offset(pixel_x, hbld, i), int_pixel_xy);
+         hbld.emit(FS_OPCODE_PIXEL_Y, offset(pixel_y, hbld, i), int_pixel_xy);
+      } else {
+         /* The "Register Region Restrictions" page says for SNB, IVB, HSW:
+          *
+          *     "When destination spans two registers, the source MUST span
+          *      two registers."
+          *
+          * Since the GRF source of the ADD will only read a single register,
+          * we must do two separate ADDs in SIMD16.
+          */
+         const fs_reg int_pixel_x = hbld.vgrf(BRW_REGISTER_TYPE_UW);
+         const fs_reg int_pixel_y = hbld.vgrf(BRW_REGISTER_TYPE_UW);
+
+         hbld.ADD(int_pixel_x,
+                  fs_reg(stride(suboffset(gi_uw, 4), 2, 4, 0)),
+                  fs_reg(brw_imm_v(0x10101010)));
+         hbld.ADD(int_pixel_y,
+                  fs_reg(stride(suboffset(gi_uw, 5), 2, 4, 0)),
+                  fs_reg(brw_imm_v(0x11001100)));
+
+         /* As of gen6, we can no longer mix float and int sources.  We have
+          * to turn the integer pixel centers into floats for their actual
+          * use.
+          */
+         hbld.MOV(offset(pixel_x, hbld, i), int_pixel_x);
+         hbld.MOV(offset(pixel_y, hbld, i), int_pixel_y);
+      }
    }
 
    abld = bld.annotate("compute pos.w");
-   this->pixel_w = fs_reg(brw_vec8_grf(payload.source_w_reg, 0));
+   this->pixel_w = fetch_payload_reg(abld, payload.source_w_reg);
    this->wpos_w = vgrf(glsl_type::float_type);
    abld.emit(SHADER_OPCODE_RCP, this->wpos_w, this->pixel_w);
 
    struct brw_wm_prog_data *wm_prog_data = brw_wm_prog_data(prog_data);
+
+   for (int i = 0; i < BRW_BARYCENTRIC_MODE_COUNT; ++i) {
+      this->delta_xy[i] = fetch_payload_reg(
+         bld, payload.barycentric_coord_reg[i], BRW_REGISTER_TYPE_F, 2);
+   }
+
    uint32_t centroid_modes = wm_prog_data->barycentric_interp_modes &
       (1 << BRW_BARYCENTRIC_PERSPECTIVE_CENTROID |
        1 << BRW_BARYCENTRIC_NONPERSPECTIVE_CENTROID);
 
-   for (int i = 0; i < BRW_BARYCENTRIC_MODE_COUNT; ++i) {
-      uint8_t reg = payload.barycentric_coord_reg[i];
-      this->delta_xy[i] = fs_reg(brw_vec16_grf(reg, 0));
-
-      if (devinfo->needs_unlit_centroid_workaround &&
-          (centroid_modes & (1 << i))) {
-         /* Get the pixel/sample mask into f0 so that we know which
-          * pixels are lit.  Then, for each channel that is unlit,
-          * replace the centroid data with non-centroid data.
-          */
-         bld.emit(FS_OPCODE_MOV_DISPATCH_TO_FLAGS);
-
-         uint8_t pixel_reg = payload.barycentric_coord_reg[i - 1];
-
-         set_predicate_inv(BRW_PREDICATE_NORMAL, true,
-                           bld.half(0).MOV(brw_vec8_grf(reg, 0),
-                                           brw_vec8_grf(pixel_reg, 0)));
-         set_predicate_inv(BRW_PREDICATE_NORMAL, true,
-                           bld.half(0).MOV(brw_vec8_grf(reg + 1, 0),
-                                           brw_vec8_grf(pixel_reg + 1, 0)));
-         if (dispatch_width == 16) {
-            set_predicate_inv(BRW_PREDICATE_NORMAL, true,
-                              bld.half(1).MOV(brw_vec8_grf(reg + 2, 0),
-                                              brw_vec8_grf(pixel_reg + 2, 0)));
-            set_predicate_inv(BRW_PREDICATE_NORMAL, true,
-                              bld.half(1).MOV(brw_vec8_grf(reg + 3, 0),
-                                              brw_vec8_grf(pixel_reg + 3, 0)));
+   if (devinfo->needs_unlit_centroid_workaround && centroid_modes) {
+      /* Get the pixel/sample mask into f0 so that we know which
+       * pixels are lit.  Then, for each channel that is unlit,
+       * replace the centroid data with non-centroid data.
+       */
+      for (unsigned i = 0; i < DIV_ROUND_UP(dispatch_width, 16); i++) {
+         bld.exec_all().group(1, 0)
+            .MOV(retype(brw_flag_reg(0, i), BRW_REGISTER_TYPE_UW),
+                 retype(brw_vec1_grf(1 + i, 7), BRW_REGISTER_TYPE_UW));
+      }
+
+      for (int i = 0; i < BRW_BARYCENTRIC_MODE_COUNT; ++i) {
+         if (!(centroid_modes & (1 << i)))
+            continue;
+
+         const fs_reg &pixel_delta_xy = delta_xy[i - 1];
+
+         for (unsigned q = 0; q < dispatch_width / 8; q++) {
+            for (unsigned c = 0; c < 2; c++) {
+               const unsigned idx = c + (q & 2) + (q & 1) * dispatch_width / 8;
+               set_predicate_inv(
+                  BRW_PREDICATE_NORMAL, true,
+                  bld.half(q).MOV(horiz_offset(delta_xy[i], idx * 8),
+                                  horiz_offset(pixel_delta_xy, idx * 8)));
+            }
          }
-         assert(dispatch_width != 32); /* not implemented yet */
       }
    }
 }
@@ -364,16 +371,14 @@ fs_visitor::emit_single_fb_write(const fs_builder &bld,
    struct brw_wm_prog_data *prog_data = brw_wm_prog_data(this->prog_data);
 
    /* Hand over gl_FragDepth or the payload depth. */
-   const fs_reg dst_depth = (payload.dest_depth_reg ?
-                             fs_reg(brw_vec8_grf(payload.dest_depth_reg, 0)) :
-                             fs_reg());
+   const fs_reg dst_depth = fetch_payload_reg(bld, payload.dest_depth_reg);
    fs_reg src_depth, src_stencil;
 
    if (source_depth_to_render_target) {
       if (nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
          src_depth = frag_depth;
       else
-         src_depth = fs_reg(brw_vec8_grf(payload.source_depth_reg, 0));
+         src_depth = fetch_payload_reg(bld, payload.source_depth_reg);
    }
 
    if (nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_STENCIL))
@@ -396,6 +401,82 @@ fs_visitor::emit_single_fb_write(const fs_builder &bld,
    return write;
 }
 
+void
+fs_visitor::emit_alpha_to_coverage_workaround(const fs_reg &src0_alpha)
+{
+   /* We need to compute alpha to coverage dithering manually in shader
+    * and replace sample mask store with the bitwise-AND of sample mask and
+    * alpha to coverage dithering.
+    *
+    * The following formula is used to compute final sample mask:
+    *  m = int(16.0 * clamp(src0_alpha, 0.0, 1.0))
+    *  dither_mask = 0x1111 * ((0xfea80 >> (m & ~3)) & 0xf) |
+    *     0x0808 * (m & 2) | 0x0100 * (m & 1)
+    *  sample_mask = sample_mask & dither_mask
+    *
+    * It gives a number of ones proportional to the alpha for 2, 4, 8 or 16
+    * least significant bits of the result:
+    *  0.0000 0000000000000000
+    *  0.0625 0000000100000000
+    *  0.1250 0001000000010000
+    *  0.1875 0001000100010000
+    *  0.2500 1000100010001000
+    *  0.3125 1000100110001000
+    *  0.3750 1001100010011000
+    *  0.4375 1001100110011000
+    *  0.5000 1010101010101010
+    *  0.5625 1010101110101010
+    *  0.6250 1011101010111010
+    *  0.6875 1011101110111010
+    *  0.7500 1110111011101110
+    *  0.8125 1110111111101110
+    *  0.8750 1111111011111110
+    *  0.9375 1111111111111110
+    *  1.0000 1111111111111111
+    */
+   const fs_builder abld = bld.annotate("compute alpha_to_coverage & "
+      "sample_mask");
+
+   /* clamp(src0_alpha, 0.f, 1.f) */
+   const fs_reg float_tmp = abld.vgrf(BRW_REGISTER_TYPE_F);
+   set_saturate(true, abld.MOV(float_tmp, src0_alpha));
+
+   /* 16.0 * clamp(src0_alpha, 0.0, 1.0) */
+   abld.MUL(float_tmp, float_tmp, brw_imm_f(16.0));
+
+   /* m = int(16.0 * clamp(src0_alpha, 0.0, 1.0)) */
+   const fs_reg m = abld.vgrf(BRW_REGISTER_TYPE_UW);
+   abld.MOV(m, float_tmp);
+
+   /* 0x1111 * ((0xfea80 >> (m & ~3)) & 0xf) */
+   const fs_reg int_tmp_1 = abld.vgrf(BRW_REGISTER_TYPE_UW);
+   const fs_reg shift_const = abld.vgrf(BRW_REGISTER_TYPE_UD);
+   abld.MOV(shift_const, brw_imm_d(0xfea80));
+   abld.AND(int_tmp_1, m, brw_imm_uw(~3));
+   abld.SHR(int_tmp_1, shift_const, int_tmp_1);
+   abld.AND(int_tmp_1, int_tmp_1, brw_imm_uw(0xf));
+   abld.MUL(int_tmp_1, int_tmp_1, brw_imm_uw(0x1111));
+
+   /* 0x0808 * (m & 2) */
+   const fs_reg int_tmp_2 = abld.vgrf(BRW_REGISTER_TYPE_UW);
+   abld.AND(int_tmp_2, m, brw_imm_uw(2));
+   abld.MUL(int_tmp_2, int_tmp_2, brw_imm_uw(0x0808));
+
+   abld.OR(int_tmp_1, int_tmp_1, int_tmp_2);
+
+   /* 0x0100 * (m & 1) */
+   const fs_reg int_tmp_3 = abld.vgrf(BRW_REGISTER_TYPE_UW);
+   abld.AND(int_tmp_3, m, brw_imm_uw(1));
+   abld.MUL(int_tmp_3, int_tmp_3, brw_imm_uw(0x0100));
+
+   abld.OR(int_tmp_1, int_tmp_1, int_tmp_3);
+
+   /* sample_mask = sample_mask & dither_mask */
+   const fs_reg mask = abld.vgrf(BRW_REGISTER_TYPE_UD);
+   abld.AND(mask, sample_mask, int_tmp_1);
+   sample_mask = mask;
+}
+
 void
 fs_visitor::emit_fb_writes()
 {
@@ -424,6 +505,22 @@ fs_visitor::emit_fb_writes()
                            "in SIMD16+ mode.\n");
    }
 
+   /* ANV doesn't know about sample mask output during the wm key creation
+    * so we compute if we need replicate alpha and emit alpha to coverage
+    * workaround here.
+    */
+   prog_data->replicate_alpha = key->alpha_test_replicate_alpha ||
+      (key->nr_color_regions > 1 && key->alpha_to_coverage &&
+       (sample_mask.file == BAD_FILE || devinfo->gen == 6));
+
+   /* From the SKL PRM, Volume 7, "Alpha Coverage":
+    *  "If Pixel Shader outputs oMask, AlphaToCoverage is disabled in
+    *   hardware, regardless of the state setting for this feature."
+    */
+   if (devinfo->gen > 6 && key->alpha_to_coverage &&
+       sample_mask.file != BAD_FILE && this->outputs[0].file != BAD_FILE)
+      emit_alpha_to_coverage_workaround(offset(this->outputs[0], bld, 3));
+
    for (int target = 0; target < key->nr_color_regions; target++) {
       /* Skip over outputs that weren't written. */
       if (this->outputs[target].file == BAD_FILE)
@@ -433,7 +530,7 @@ fs_visitor::emit_fb_writes()
          ralloc_asprintf(this->mem_ctx, "FB write target %d", target));
 
       fs_reg src0_alpha;
-      if (devinfo->gen >= 6 && key->replicate_alpha && target != 0)
+      if (devinfo->gen >= 6 && prog_data->replicate_alpha && target != 0)
          src0_alpha = offset(outputs[0], bld, 3);
 
       inst = emit_single_fb_write(abld, this->outputs[target],
@@ -441,7 +538,8 @@ fs_visitor::emit_fb_writes()
       inst->target = target;
    }
 
-   prog_data->dual_src_blend = (this->dual_src_output.file != BAD_FILE);
+   prog_data->dual_src_blend = (this->dual_src_output.file != BAD_FILE &&
+                                this->outputs[0].file != BAD_FILE);
    assert(!prog_data->dual_src_blend || key->nr_color_regions == 1);
 
    if (inst == NULL) {
@@ -461,6 +559,7 @@ fs_visitor::emit_fb_writes()
       inst->target = 0;
    }
 
+   inst->last_rt = true;
    inst->eot = true;
 }
 
@@ -597,7 +696,7 @@ fs_visitor::emit_urb_writes(const fs_reg &gs_vertex_count)
          per_slot_offsets = brw_imm_ud(output_vertex_size_owords *
                                        gs_vertex_count.ud);
       } else {
-         per_slot_offsets = vgrf(glsl_type::int_type);
+         per_slot_offsets = vgrf(glsl_type::uint_type);
          bld.MUL(per_slot_offsets, gs_vertex_count,
                  brw_imm_ud(output_vertex_size_owords));
       }
@@ -722,7 +821,13 @@ fs_visitor::emit_urb_writes(const fs_reg &gs_vertex_count)
                            header_size);
 
          fs_inst *inst = abld.emit(opcode, reg_undef, payload);
-         inst->eot = slot == last_slot && stage != MESA_SHADER_GEOMETRY;
+
+         /* For ICL WA 1805992985 one needs additional write in the end. */
+         if (devinfo->gen == 11 && stage == MESA_SHADER_TESS_EVAL)
+            inst->eot = false;
+         else
+            inst->eot = slot == last_slot && stage != MESA_SHADER_GEOMETRY;
+
          inst->mlen = length + header_size;
          inst->offset = urb_offset;
          urb_offset = starting_urb_offset + slot + 1;
@@ -758,6 +863,49 @@ fs_visitor::emit_urb_writes(const fs_reg &gs_vertex_count)
       inst->mlen = 2;
       inst->offset = 1;
       return;
+   } 
+   /* ICL WA 1805992985:
+    *
+    * ICLLP GPU hangs on one of tessellation vkcts tests with DS not done. The
+    * send cycle, which is a urb write with an eot must be 4 phases long and
+    * all 8 lanes must valid.
+    */
+   if (devinfo->gen == 11 && stage == MESA_SHADER_TESS_EVAL) {
+      fs_reg payload = fs_reg(VGRF, alloc.allocate(6), BRW_REGISTER_TYPE_UD);
+
+      /* Workaround requires all 8 channels (lanes) to be valid. This is
+       * understood to mean they all need to be alive. First trick is to find
+       * a live channel and copy its urb handle for all the other channels to
+       * make sure all handles are valid.
+       */
+      bld.exec_all().MOV(payload, bld.emit_uniformize(urb_handle));
+
+      /* Second trick is to use masked URB write where one can tell the HW to
+       * actually write data only for selected channels even though all are
+       * active.
+       * Third trick is to take advantage of the must-be-zero (MBZ) area in
+       * the very beginning of the URB.
+       *
+       * One masks data to be written only for the first channel and uses
+       * offset zero explicitly to land data to the MBZ area avoiding trashing
+       * any other part of the URB.
+       *
+       * Since the WA says that the write needs to be 4 phases long one uses
+       * 4 slots data. All are explicitly zeros in order to to keep the MBZ
+       * area written as zeros.
+       */
+      bld.exec_all().MOV(offset(payload, bld, 1), brw_imm_ud(0x10000u));
+      bld.exec_all().MOV(offset(payload, bld, 2), brw_imm_ud(0u));
+      bld.exec_all().MOV(offset(payload, bld, 3), brw_imm_ud(0u));
+      bld.exec_all().MOV(offset(payload, bld, 4), brw_imm_ud(0u));
+      bld.exec_all().MOV(offset(payload, bld, 5), brw_imm_ud(0u));
+
+      fs_inst *inst = bld.exec_all().emit(SHADER_OPCODE_URB_WRITE_SIMD8_MASKED,
+                                          reg_undef, payload);
+      inst->eot = true;
+      inst->mlen = 6;
+      inst->offset = 0;
    }
 }
 
@@ -786,9 +934,19 @@ fs_visitor::emit_cs_terminate()
 void
 fs_visitor::emit_barrier()
 {
-   assert(devinfo->gen >= 7);
-   const uint32_t barrier_id_mask =
-      devinfo->gen >= 9 ? 0x8f000000u : 0x0f000000u;
+   uint32_t barrier_id_mask;
+   switch (devinfo->gen) {
+   case 7:
+   case 8:
+      barrier_id_mask = 0x0f000000u; break;
+   case 9:
+   case 10:
+      barrier_id_mask = 0x8f000000u; break;
+   case 11:
+      barrier_id_mask = 0x7f000000u; break;
+   default:
+      unreachable("barrier is only available on gen >= 7");
+   }
 
    /* We are getting the barrier ID from the compute shader header */
    assert(stage == MESA_SHADER_COMPUTE);
@@ -811,7 +969,7 @@ fs_visitor::emit_barrier()
 
 fs_visitor::fs_visitor(const struct brw_compiler *compiler, void *log_data,
                        void *mem_ctx,
-                       const void *key,
+                       const brw_base_prog_key *key,
                        struct brw_stage_prog_data *prog_data,
                        struct gl_program *prog,
                        const nir_shader *shader,
@@ -836,7 +994,7 @@ fs_visitor::fs_visitor(const struct brw_compiler *compiler, void *log_data,
                        int shader_time_index)
    : backend_shader(compiler, log_data, mem_ctx, shader,
                     &prog_data->base.base),
-     key(&c->key), gs_compile(c),
+     key(&c->key.base), gs_compile(c),
      prog_data(&prog_data->base.base), prog(NULL),
      dispatch_width(8),
      shader_time_index(shader_time_index),
@@ -849,28 +1007,7 @@ fs_visitor::fs_visitor(const struct brw_compiler *compiler, void *log_data,
 void
 fs_visitor::init()
 {
-   switch (stage) {
-   case MESA_SHADER_FRAGMENT:
-      key_tex = &((const brw_wm_prog_key *) key)->tex;
-      break;
-   case MESA_SHADER_VERTEX:
-      key_tex = &((const brw_vs_prog_key *) key)->tex;
-      break;
-   case MESA_SHADER_TESS_CTRL:
-      key_tex = &((const brw_tcs_prog_key *) key)->tex;
-      break;
-   case MESA_SHADER_TESS_EVAL:
-      key_tex = &((const brw_tes_prog_key *) key)->tex;
-      break;
-   case MESA_SHADER_GEOMETRY:
-      key_tex = &((const brw_gs_prog_key *) key)->tex;
-      break;
-   case MESA_SHADER_COMPUTE:
-      key_tex = &((const brw_cs_prog_key*) key)->tex;
-      break;
-   default:
-      unreachable("unhandled shader stage");
-   }
+   this->key_tex = &key->tex;
 
    this->max_dispatch_width = 32;
    this->prog_data = this->stage_prog_data;
@@ -898,6 +1035,7 @@ fs_visitor::init()
 
    this->promoted_constants = 0,
 
+   this->grf_used = 0;
    this->spilled_any_registers = false;
 }