tgsi/ureg: make the dst register match the src indirection
[mesa.git] / src / gallium / auxiliary / draw / draw_llvm.c
index 9ab3a9f874857d815bce1f9ffe642f093a281d20..d2821a1463389674e0daa57ac4a7239b110d1251 100644 (file)
@@ -64,6 +64,19 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *var,
                    boolean elts);
 
 
+struct draw_gs_llvm_iface {
+   struct lp_build_tgsi_gs_iface base;
+
+   struct draw_gs_llvm_variant *variant;
+   LLVMValueRef input;
+};
+
+static INLINE const struct draw_gs_llvm_iface *
+draw_gs_llvm_iface(const struct lp_build_tgsi_gs_iface *iface)
+{
+   return (const struct draw_gs_llvm_iface *)iface;
+}
+
 /**
  * Create LLVM type for struct draw_jit_texture
  */
@@ -190,7 +203,7 @@ create_jit_context_type(struct gallivm_state *gallivm,
 {
    LLVMTargetDataRef target = gallivm->target;
    LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context);
-   LLVMTypeRef elem_types[5];
+   LLVMTypeRef elem_types[DRAW_JIT_CTX_NUM_FIELDS];
    LLVMTypeRef context_type;
 
    elem_types[0] = LLVMArrayType(LLVMPointerType(float_type, 0), /* vs_constants */
@@ -211,11 +224,11 @@ create_jit_context_type(struct gallivm_state *gallivm,
 #endif
 
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_context, vs_constants,
-                          target, context_type, 0);
+                          target, context_type, DRAW_JIT_CTX_CONSTANTS);
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_context, planes,
-                          target, context_type, 1);
+                          target, context_type, DRAW_JIT_CTX_PLANES);
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_context, viewport,
-                          target, context_type, 2);
+                          target, context_type, DRAW_JIT_CTX_VIEWPORT);
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_context, textures,
                           target, context_type,
                           DRAW_JIT_CTX_TEXTURES);
@@ -241,7 +254,7 @@ create_gs_jit_context_type(struct gallivm_state *gallivm,
    LLVMTargetDataRef target = gallivm->target;
    LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context);
    LLVMTypeRef int_type = LLVMInt32TypeInContext(gallivm->context);
-   LLVMTypeRef elem_types[8];
+   LLVMTypeRef elem_types[DRAW_GS_JIT_CTX_NUM_FIELDS];
    LLVMTypeRef context_type;
 
    elem_types[0] = LLVMArrayType(LLVMPointerType(float_type, 0), /* constants */
@@ -249,17 +262,17 @@ create_gs_jit_context_type(struct gallivm_state *gallivm,
    elem_types[1] = LLVMPointerType(LLVMArrayType(LLVMArrayType(float_type, 4),
                                                  DRAW_TOTAL_CLIP_PLANES), 0);
    elem_types[2] = LLVMPointerType(float_type, 0); /* viewport */
-   
-   elem_types[3] = LLVMPointerType(LLVMPointerType(int_type, 0), 0);
-   elem_types[4] = LLVMPointerType(LLVMVectorType(int_type,
-                                                  vector_length), 0);
-   elem_types[5] = LLVMPointerType(LLVMVectorType(int_type,
-                                                  vector_length), 0);
 
-   elem_types[6] = LLVMArrayType(texture_type,
+   elem_types[3] = LLVMArrayType(texture_type,
                                  PIPE_MAX_SHADER_SAMPLER_VIEWS); /* textures */
-   elem_types[7] = LLVMArrayType(sampler_type,
+   elem_types[4] = LLVMArrayType(sampler_type,
                                  PIPE_MAX_SAMPLERS); /* samplers */
+   
+   elem_types[5] = LLVMPointerType(LLVMPointerType(int_type, 0), 0);
+   elem_types[6] = LLVMPointerType(LLVMVectorType(int_type,
+                                                  vector_length), 0);
+   elem_types[7] = LLVMPointerType(LLVMVectorType(int_type,
+                                                  vector_length), 0);
 
    context_type = LLVMStructTypeInContext(gallivm->context, elem_types,
                                           Elements(elem_types), 0);
@@ -270,23 +283,26 @@ create_gs_jit_context_type(struct gallivm_state *gallivm,
 #endif
 
    LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, constants,
-                          target, context_type, 0);
+                          target, context_type, DRAW_GS_JIT_CTX_CONSTANTS);
    LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, planes,
-                          target, context_type, 1);
+                          target, context_type, DRAW_GS_JIT_CTX_PLANES);
    LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, viewport,
-                          target, context_type, 2);
-   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, prim_lengths,
-                          target, context_type, 3);
-   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, emitted_vertices,
-                          target, context_type, 4);
-   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, emitted_prims,
-                          target, context_type, 5);
+                          target, context_type, DRAW_GS_JIT_CTX_VIEWPORT);
    LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, textures,
                           target, context_type,
                           DRAW_GS_JIT_CTX_TEXTURES);
    LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, samplers,
                           target, context_type,
                           DRAW_GS_JIT_CTX_SAMPLERS);
+   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, prim_lengths,
+                          target, context_type,
+                          DRAW_GS_JIT_CTX_PRIM_LENGTHS);
+   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, emitted_vertices,
+                          target, context_type,
+                          DRAW_GS_JIT_CTX_EMITTED_VERTICES);
+   LP_CHECK_MEMBER_OFFSET(struct draw_gs_jit_context, emitted_prims,
+                          target, context_type,
+                          DRAW_GS_JIT_CTX_EMITTED_PRIMS);
    LP_CHECK_STRUCT_SIZE(struct draw_gs_jit_context,
                         target, context_type);
 
@@ -577,7 +593,6 @@ generate_vs(struct draw_llvm_variant *variant,
                      NULL /*struct lp_build_mask_context *mask*/,
                      consts_ptr,
                      system_values,
-                     NULL /*pos*/,
                      inputs,
                      outputs,
                      sampler,
@@ -981,18 +996,19 @@ generate_viewport(struct draw_llvm_variant *variant,
    int i;
    struct gallivm_state *gallivm = variant->gallivm;
    struct lp_type f32_type = vs_type;
+   const unsigned pos = draw_current_shader_position_output(variant->llvm->draw);
    LLVMTypeRef vs_type_llvm = lp_build_vec_type(gallivm, vs_type);
-   LLVMValueRef out3 = LLVMBuildLoad(builder, outputs[0][3], ""); /*w0 w1 .. wn*/
+   LLVMValueRef out3 = LLVMBuildLoad(builder, outputs[pos][3], ""); /*w0 w1 .. wn*/
    LLVMValueRef const1 = lp_build_const_vec(gallivm, f32_type, 1.0);       /*1.0 1.0 1.0 1.0*/
    LLVMValueRef vp_ptr = draw_jit_context_viewport(gallivm, context_ptr);
 
    /* for 1/w convention*/
    out3 = LLVMBuildFDiv(builder, const1, out3, "");
-   LLVMBuildStore(builder, out3, outputs[0][3]);
+   LLVMBuildStore(builder, out3, outputs[pos][3]);
 
    /* Viewport Mapping */
    for (i=0; i<3; i++) {
-      LLVMValueRef out = LLVMBuildLoad(builder, outputs[0][i], ""); /*x0 x1 .. xn*/
+      LLVMValueRef out = LLVMBuildLoad(builder, outputs[pos][i], ""); /*x0 x1 .. xn*/
       LLVMValueRef scale;
       LLVMValueRef trans;
       LLVMValueRef scale_i;
@@ -1018,7 +1034,7 @@ generate_viewport(struct draw_llvm_variant *variant,
       out = LLVMBuildFAdd(builder, out, trans, "");
 
       /* store transformed outputs */
-      LLVMBuildStore(builder, out, outputs[0][i]);
+      LLVMBuildStore(builder, out, outputs[pos][i]);
    }
 
 }
@@ -1234,22 +1250,67 @@ clipmask_booli32(struct gallivm_state *gallivm,
    return ret;
 }
 
+static LLVMValueRef
+draw_gs_llvm_fetch_input(const struct lp_build_tgsi_gs_iface *gs_iface,
+                         struct lp_build_tgsi_context * bld_base,
+                         boolean is_indirect,
+                         LLVMValueRef vertex_index,
+                         LLVMValueRef attrib_index,
+                         LLVMValueRef swizzle_index)
+{
+   const struct draw_gs_llvm_iface *gs = draw_gs_llvm_iface(gs_iface);
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+   LLVMValueRef indices[3];
+   LLVMValueRef res;
+   struct lp_type type = bld_base->base.type;
+
+   if (is_indirect) {
+      int i;
+      res = bld_base->base.zero;
+      for (i = 0; i < type.length; ++i) {
+         LLVMValueRef idx = lp_build_const_int32(gallivm, i);
+         LLVMValueRef vert_chan_index = LLVMBuildExtractElement(builder,
+                                                                vertex_index, idx, "");
+         LLVMValueRef channel_vec, value;
+         indices[0] = vert_chan_index;
+         indices[1] = attrib_index;
+         indices[2] = swizzle_index;
+         
+         channel_vec = LLVMBuildGEP(builder, gs->input, indices, 3, "");
+         channel_vec = LLVMBuildLoad(builder, channel_vec, "");
+         value = LLVMBuildExtractElement(builder, channel_vec, idx, "");
+
+         res = LLVMBuildInsertElement(builder, res, value, idx, "");
+      }
+   } else {
+      indices[0] = vertex_index;
+      indices[1] = attrib_index;
+      indices[2] = swizzle_index;
+
+      res = LLVMBuildGEP(builder, gs->input, indices, 3, "");
+      res = LLVMBuildLoad(builder, res, "");
+   }
+
+   return res;
+}
+
 static void
-draw_gs_llvm_emit_vertex(struct lp_build_tgsi_context * bld_base,
+draw_gs_llvm_emit_vertex(const struct lp_build_tgsi_gs_iface *gs_base,
+                         struct lp_build_tgsi_context * bld_base,
                          LLVMValueRef (*outputs)[4],
-                         LLVMValueRef emitted_vertices_vec,
-                         void *user_data)
+                         LLVMValueRef emitted_vertices_vec)
 {
-   struct draw_gs_llvm_variant *variant =
-      (struct draw_gs_llvm_variant *)user_data;
+   const struct draw_gs_llvm_iface *gs_iface = draw_gs_llvm_iface(gs_base);
+   struct draw_gs_llvm_variant *variant = gs_iface->variant;
    struct gallivm_state *gallivm = variant->gallivm;
    LLVMBuilderRef builder = gallivm->builder;
    struct lp_type gs_type = bld_base->base.type;
    LLVMValueRef clipmask = lp_build_const_int_vec(gallivm,
                                                   lp_int_type(gs_type), 0);
    LLVMValueRef indices[LP_MAX_VECTOR_LENGTH];
-   LLVMValueRef max_output_vertices =
-      lp_build_const_int32(gallivm, variant->shader->base.max_output_vertices);
+   LLVMValueRef next_prim_offset =
+      lp_build_const_int32(gallivm, variant->shader->base.primitive_boundary);
    LLVMValueRef io = variant->io_ptr;
    unsigned i;
    const struct tgsi_shader_info *gs_info = &variant->shader->base.info;
@@ -1258,7 +1319,7 @@ draw_gs_llvm_emit_vertex(struct lp_build_tgsi_context * bld_base,
       LLVMValueRef ind = lp_build_const_int32(gallivm, i);
       LLVMValueRef currently_emitted =
          LLVMBuildExtractElement(builder, emitted_vertices_vec, ind, "");
-      indices[i] = LLVMBuildMul(builder, ind, max_output_vertices, "");
+      indices[i] = LLVMBuildMul(builder, ind, next_prim_offset, "");
       indices[i] = LLVMBuildAdd(builder, indices[i], currently_emitted, "");
    }
 
@@ -1269,13 +1330,13 @@ draw_gs_llvm_emit_vertex(struct lp_build_tgsi_context * bld_base,
 }
 
 static void
-draw_gs_llvm_end_primitive(struct lp_build_tgsi_context * bld_base,
+draw_gs_llvm_end_primitive(const struct lp_build_tgsi_gs_iface *gs_base,
+                           struct lp_build_tgsi_context * bld_base,
                            LLVMValueRef verts_per_prim_vec,
-                           LLVMValueRef emitted_prims_vec,
-                           void *user_data)
+                           LLVMValueRef emitted_prims_vec)
 {
-   struct draw_gs_llvm_variant *variant =
-      (struct draw_gs_llvm_variant *)user_data;
+   const struct draw_gs_llvm_iface *gs_iface = draw_gs_llvm_iface(gs_base);
+   struct draw_gs_llvm_variant *variant = gs_iface->variant;
    struct gallivm_state *gallivm = variant->gallivm;
    LLVMBuilderRef builder = gallivm->builder;
    LLVMValueRef prim_lengts_ptr =
@@ -1298,13 +1359,13 @@ draw_gs_llvm_end_primitive(struct lp_build_tgsi_context * bld_base,
 }
 
 static void
-draw_gs_llvm_epilogue(struct lp_build_tgsi_context * bld_base,
+draw_gs_llvm_epilogue(const struct lp_build_tgsi_gs_iface *gs_base,
+                      struct lp_build_tgsi_context * bld_base,
                       LLVMValueRef total_emitted_vertices_vec,
-                      LLVMValueRef emitted_prims_vec,
-                      void *user_data)
+                      LLVMValueRef emitted_prims_vec)
 {
-   struct draw_gs_llvm_variant *variant =
-      (struct draw_gs_llvm_variant *)user_data;
+   const struct draw_gs_llvm_iface *gs_iface = draw_gs_llvm_iface(gs_base);
+   struct draw_gs_llvm_variant *variant = gs_iface->variant;
    struct gallivm_state *gallivm = variant->gallivm;
    LLVMBuilderRef builder = gallivm->builder;
    LLVMValueRef emitted_verts_ptr =
@@ -1350,14 +1411,15 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant,
    struct lp_build_sampler_soa *sampler = 0;
    LLVMValueRef ret, clipmask_bool_ptr;
    const struct draw_geometry_shader *gs = draw->gs.geometry_shader;
+   struct draw_llvm_variant_key *key = &variant->key;
    /* If geometry shader is present we need to skip both the viewport
     * transformation and clipping otherwise the inputs to the geometry
     * shader will be incorrect.
     */
-   const boolean bypass_viewport = gs || variant->key.bypass_viewport;
-   const boolean enable_cliptest = !gs && (variant->key.clip_xy ||
-                                           variant->key.clip_z  ||
-                                           variant->key.clip_user);
+   const boolean bypass_viewport = gs || key->bypass_viewport;
+   const boolean enable_cliptest = !gs && (key->clip_xy ||
+                                           key->clip_z  ||
+                                           key->clip_user);
    LLVMValueRef variant_func;
    const unsigned pos = draw_current_shader_position_output(llvm->draw);
    const unsigned cv = draw_current_shader_clipvertex_output(llvm->draw);
@@ -1447,7 +1509,7 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant,
 
    /* code generated texture sampling */
    sampler = draw_llvm_sampler_soa_create(
-      draw_llvm_variant_key_samplers(&variant->key),
+      draw_llvm_variant_key_samplers(key),
       context_ptr);
 
    if (elts) {
@@ -1524,7 +1586,7 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant,
                   &system_values,
                   context_ptr,
                   sampler,
-                  variant->key.clamp_vertex_color);
+                  key->clamp_vertex_color);
 
       if (pos != -1 && cv != -1) {
          /* store original positions in clip before further manipulation */
@@ -1539,11 +1601,11 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant,
                                          gallivm,
                                          vs_type,
                                          outputs,
-                                         variant->key.clip_xy,
-                                         variant->key.clip_z,
-                                         variant->key.clip_user,
-                                         variant->key.clip_halfz,
-                                         variant->key.ucp_enable,
+                                         key->clip_xy,
+                                         key->clip_z,
+                                         key->clip_user,
+                                         key->clip_halfz,
+                                         key->ucp_enable,
                                          context_ptr, &have_clipdist);
             temp = LLVMBuildOr(builder, clipmask, temp, "");
             /* store temporary clipping boolean value */
@@ -1609,7 +1671,7 @@ draw_llvm_make_variant_key(struct draw_llvm *llvm, char *store)
    key->clip_z = llvm->draw->clip_z;
    key->clip_user = llvm->draw->clip_user;
    key->bypass_viewport = llvm->draw->identity_viewport;
-   key->clip_halfz = !llvm->draw->rasterizer->gl_rasterization_rules;
+   key->clip_halfz = llvm->draw->rasterizer->clip_halfz;
    key->need_edgeflags = (llvm->draw->vs.edgeflag_output ? TRUE : FALSE);
    key->ucp_enable = llvm->draw->rasterizer->clip_plane_enable;
    key->has_gs = llvm->draw->gs.geometry_shader != NULL;
@@ -1703,6 +1765,9 @@ draw_llvm_set_mapped_texture(struct draw_context *draw,
       assert(sview_idx < Elements(draw->llvm->gs_jit_context.textures));
 
       jit_tex = &draw->llvm->gs_jit_context.textures[sview_idx];
+   } else {
+      assert(0);
+      return;
    }
 
    jit_tex->width = width;
@@ -1721,33 +1786,36 @@ draw_llvm_set_mapped_texture(struct draw_context *draw,
 
 
 void
-draw_llvm_set_sampler_state(struct draw_context *draw)
+draw_llvm_set_sampler_state(struct draw_context *draw, 
+                            unsigned shader_type)
 {
    unsigned i;
 
-   for (i = 0; i < draw->num_samplers[PIPE_SHADER_VERTEX]; i++) {
-      struct draw_jit_sampler *jit_sam = &draw->llvm->jit_context.samplers[i];
-
-      if (draw->samplers[i]) {
-         const struct pipe_sampler_state *s
-            = draw->samplers[PIPE_SHADER_VERTEX][i];
-         jit_sam->min_lod = s->min_lod;
-         jit_sam->max_lod = s->max_lod;
-         jit_sam->lod_bias = s->lod_bias;
-         COPY_4V(jit_sam->border_color, s->border_color.f);
+   if (shader_type == PIPE_SHADER_VERTEX) {
+      for (i = 0; i < draw->num_samplers[PIPE_SHADER_VERTEX]; i++) {
+         struct draw_jit_sampler *jit_sam = &draw->llvm->jit_context.samplers[i];
+
+         if (draw->samplers[i]) {
+            const struct pipe_sampler_state *s
+               = draw->samplers[PIPE_SHADER_VERTEX][i];
+            jit_sam->min_lod = s->min_lod;
+            jit_sam->max_lod = s->max_lod;
+            jit_sam->lod_bias = s->lod_bias;
+            COPY_4V(jit_sam->border_color, s->border_color.f);
+         }
       }
-   }
-
-   for (i = 0; i < draw->num_samplers[PIPE_SHADER_GEOMETRY]; i++) {
-      struct draw_jit_sampler *jit_sam = &draw->llvm->gs_jit_context.samplers[i];
-
-      if (draw->samplers[i]) {
-         const struct pipe_sampler_state *s
-            = draw->samplers[PIPE_SHADER_GEOMETRY][i];
-         jit_sam->min_lod = s->min_lod;
-         jit_sam->max_lod = s->max_lod;
-         jit_sam->lod_bias = s->lod_bias;
-         COPY_4V(jit_sam->border_color, s->border_color.f);
+   } else if (shader_type == PIPE_SHADER_GEOMETRY) {
+      for (i = 0; i < draw->num_samplers[PIPE_SHADER_GEOMETRY]; i++) {
+         struct draw_jit_sampler *jit_sam = &draw->llvm->gs_jit_context.samplers[i];
+
+         if (draw->samplers[i]) {
+            const struct pipe_sampler_state *s
+               = draw->samplers[PIPE_SHADER_GEOMETRY][i];
+            jit_sam->min_lod = s->min_lod;
+            jit_sam->max_lod = s->max_lod;
+            jit_sam->lod_bias = s->lod_bias;
+            COPY_4V(jit_sam->border_color, s->border_color.f);
+         }
       }
    }
 }
@@ -1849,10 +1917,11 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    struct gallivm_state *gallivm = variant->gallivm;
    LLVMContextRef context = gallivm->context;
    LLVMTypeRef int32_type = LLVMInt32TypeInContext(context);
-   LLVMTypeRef arg_types[5];
+   LLVMTypeRef arg_types[6];
    LLVMTypeRef func_type;
    LLVMValueRef variant_func;
    LLVMValueRef context_ptr;
+   LLVMValueRef prim_id_ptr;
    LLVMBasicBlockRef block;
    LLVMBuilderRef builder;
    LLVMValueRef io_ptr, input_array, num_prims, mask_val;
@@ -1861,11 +1930,13 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    struct lp_bld_tgsi_system_values system_values;
    struct lp_type gs_type;
    unsigned i;
-   struct lp_build_tgsi_gs_iface gs_iface;
+   struct draw_gs_llvm_iface gs_iface;
    const struct tgsi_token *tokens = variant->shader->base.state.tokens;
    LLVMValueRef consts_ptr;
    LLVMValueRef outputs[PIPE_MAX_SHADER_OUTPUTS][TGSI_NUM_CHANNELS];
    struct lp_build_mask_context mask;
+   const struct tgsi_shader_info *gs_info = &variant->shader->base.info;
+   unsigned vector_length = variant->shader->base.vector_length;
 
    memset(&system_values, 0, sizeof(system_values));
 
@@ -1876,6 +1947,8 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    arg_types[2] = variant->vertex_header_ptr_type;     /* vertex_header */
    arg_types[3] = int32_type;                          /* num_prims */
    arg_types[4] = int32_type;                          /* instance_id */
+   arg_types[5] = LLVMPointerType(
+      LLVMVectorType(int32_type, vector_length), 0);   /* prim_id_ptr */
 
    func_type = LLVMFunctionType(int32_type, arg_types, Elements(arg_types), 0);
 
@@ -1895,22 +1968,25 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    io_ptr                    = LLVMGetParam(variant_func, 2);
    num_prims                 = LLVMGetParam(variant_func, 3);
    system_values.instance_id = LLVMGetParam(variant_func, 4);
+   prim_id_ptr               = LLVMGetParam(variant_func, 5);
 
    lp_build_name(context_ptr, "context");
    lp_build_name(input_array, "input");
    lp_build_name(io_ptr, "io");
    lp_build_name(io_ptr, "num_prims");
    lp_build_name(system_values.instance_id, "instance_id");
+   lp_build_name(prim_id_ptr, "prim_id_ptr");
 
    variant->context_ptr = context_ptr;
    variant->io_ptr = io_ptr;
    variant->num_prims = num_prims;
 
+   gs_iface.base.fetch_input = draw_gs_llvm_fetch_input;
+   gs_iface.base.emit_vertex = draw_gs_llvm_emit_vertex;
+   gs_iface.base.end_primitive = draw_gs_llvm_end_primitive;
+   gs_iface.base.gs_epilogue = draw_gs_llvm_epilogue;
    gs_iface.input = input_array;
-   gs_iface.emit_vertex = draw_gs_llvm_emit_vertex;
-   gs_iface.end_primitive = draw_gs_llvm_end_primitive;
-   gs_iface.gs_epilogue = draw_gs_llvm_epilogue;
-   gs_iface.user_data = variant;
+   gs_iface.variant = variant;
 
    /*
     * Function body
@@ -1927,7 +2003,7 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    gs_type.sign = TRUE;     /* values are signed */
    gs_type.norm = FALSE;    /* values are not limited to [0,1] or [-1,1] */
    gs_type.width = 32;      /* 32-bit float */
-   gs_type.length = variant->shader->base.vector_length;
+   gs_type.length = vector_length;
 
    consts_ptr = draw_gs_jit_context_constants(variant->gallivm, context_ptr);
 
@@ -1938,18 +2014,23 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    mask_val = generate_mask_value(variant, gs_type);
    lp_build_mask_begin(&mask, gallivm, gs_type, mask_val);
 
+   if (gs_info->uses_primid) {
+      system_values.prim_id = LLVMBuildLoad(builder, prim_id_ptr, "prim_id");;
+   }
+
    lp_build_tgsi_soa(variant->gallivm,
                      tokens,
                      gs_type,
                      &mask,
                      consts_ptr,
                      &system_values,
-                     NULL /*pos*/,
                      NULL,
                      outputs,
                      sampler,
                      &llvm->draw->gs.geometry_shader->info,
-                     &gs_iface);
+                     (const struct lp_build_tgsi_gs_iface *)&gs_iface);
+
+   sampler->destroy(sampler);
 
    lp_build_mask_end(&mask);