gallivm: implement switch opcode
authorRoland Scheidegger <sroland@vmware.com>
Fri, 19 Apr 2013 23:46:04 +0000 (01:46 +0200)
committerRoland Scheidegger <sroland@vmware.com>
Sat, 20 Apr 2013 00:27:53 +0000 (02:27 +0200)
Should be able to handle all things which make this tricky to implement.
Fallthroughs, including most notably into/out of default, should be handled
correctly but are quite a mess.
If we see largely unoptimized switches in the wild should probably think
about some "real" switch optimization pass, e.g. things like this:

switch
case1
someinst
brk
case2
default
case3
someinst
brk
case4
someinst
endswitch

are legal, but the pointless case2/case3 statements not only cause condition
evaluation but will turn this into a "fake" fallthrough case (because
mask and defaultmask are already updated for case2 when default is
encountered) requiring executing code twice.
If default is at the end though, there's never any code re-execution, and
if that's not the case if there's no fallthrough in (not even a fake one)
and out of default there's no code re-execution neither.

v2: add comments, and use enum for break type instead of magic boolean.

Reviewed-by: Jose Fonseca <jfonseca@vmware.com>
src/gallium/auxiliary/gallivm/lp_bld_tgsi.h
src/gallium/auxiliary/gallivm/lp_bld_tgsi_action.c
src/gallium/auxiliary/gallivm/lp_bld_tgsi_soa.c

index 175b6a9f0f65c8481979e0e047a8d7914a69cbd5..9d9c742e9a3fd424f2018f237e7f11a09a286b85 100644 (file)
@@ -241,6 +241,12 @@ lp_build_tgsi_aos(struct gallivm_state *gallivm,
                   const struct tgsi_shader_info *info);
 
 
+enum lp_exec_mask_break_type {
+   LP_EXEC_MASK_BREAK_TYPE_LOOP,
+   LP_EXEC_MASK_BREAK_TYPE_SWITCH
+};
+
+
 struct lp_exec_mask {
    struct lp_build_context *bld;
 
@@ -253,6 +259,24 @@ struct lp_exec_mask {
    int cond_stack_size;
    LLVMValueRef cond_mask;
 
+   /* keep track if break belongs to switch or loop */
+   enum lp_exec_mask_break_type break_type_stack[LP_MAX_TGSI_NESTING];
+   enum lp_exec_mask_break_type break_type;
+
+   struct {
+      LLVMValueRef switch_val;
+      LLVMValueRef switch_mask;
+      LLVMValueRef switch_mask_default;
+      boolean switch_in_default;
+      unsigned switch_pc;
+   } switch_stack[LP_MAX_TGSI_NESTING];
+   int switch_stack_size;
+   LLVMValueRef switch_val;
+   LLVMValueRef switch_mask;         /* current switch exec mask */
+   LLVMValueRef switch_mask_default; /* reverse of switch mask used for default */
+   boolean switch_in_default;        /* if switch exec is currently in default */
+   unsigned switch_pc;               /* when used points to default or endswitch-1 */
+
    LLVMBasicBlockRef loop_block;
    LLVMValueRef cont_mask;
    LLVMValueRef break_mask;
index e1c362be4a851862fa242b332b55dadd03236f64..dc7c090ac7cd54afdbcafe13ea427cb08a6cc311 100644 (file)
@@ -865,6 +865,8 @@ lp_set_default_actions(struct lp_build_tgsi_context * bld_base)
    bld_base->op_actions[TGSI_OPCODE_XPD] = xpd_action;
 
    bld_base->op_actions[TGSI_OPCODE_BREAKC].fetch_args = scalar_unary_fetch_args;
+   bld_base->op_actions[TGSI_OPCODE_SWITCH].fetch_args = scalar_unary_fetch_args;
+   bld_base->op_actions[TGSI_OPCODE_CASE].fetch_args = scalar_unary_fetch_args;
    bld_base->op_actions[TGSI_OPCODE_COS].fetch_args = scalar_unary_fetch_args;
    bld_base->op_actions[TGSI_OPCODE_EX2].fetch_args = scalar_unary_fetch_args;
    bld_base->op_actions[TGSI_OPCODE_IF].fetch_args = scalar_unary_fetch_args;
index 680a9c1f3f750e743ff529eaaee3ea03e0724150..c48c6e9d92ab5e70e5b051f1e5a2206c08e8a9c5 100644 (file)
@@ -78,9 +78,11 @@ static void lp_exec_mask_init(struct lp_exec_mask *mask, struct lp_build_context
    mask->cond_stack_size = 0;
    mask->loop_stack_size = 0;
    mask->call_stack_size = 0;
+   mask->switch_stack_size = 0;
 
    mask->int_vec_type = lp_build_int_vec_type(bld->gallivm, mask->bld->type);
-   mask->exec_mask = mask->ret_mask = mask->break_mask = mask->cont_mask = mask->cond_mask =
+   mask->exec_mask = mask->ret_mask = mask->break_mask = mask->cont_mask =
+         mask->cond_mask = mask->switch_mask =
          LLVMConstAllOnes(mask->int_vec_type);
 
    mask->loop_limiter = lp_build_alloca(bld->gallivm, int_type, "looplimiter");
@@ -110,6 +112,13 @@ static void lp_exec_mask_update(struct lp_exec_mask *mask)
    } else
       mask->exec_mask = mask->cond_mask;
 
+   if (mask->switch_stack_size) {
+      mask->exec_mask = LLVMBuildAnd(builder,
+                                     mask->exec_mask,
+                                     mask->switch_mask,
+                                     "switchmask");
+   }
+
    if (mask->call_stack_size || mask->ret_in_main) {
       mask->exec_mask = LLVMBuildAnd(builder,
                                      mask->exec_mask,
@@ -120,6 +129,7 @@ static void lp_exec_mask_update(struct lp_exec_mask *mask)
    mask->has_mask = (mask->cond_stack_size > 0 ||
                      mask->loop_stack_size > 0 ||
                      mask->call_stack_size > 0 ||
+                     mask->switch_stack_size > 0 ||
                      mask->ret_in_main);
 }
 
@@ -181,6 +191,10 @@ static void lp_exec_bgnloop(struct lp_exec_mask *mask)
 
    assert(mask->loop_stack_size < LP_MAX_TGSI_NESTING);
 
+   mask->break_type_stack[mask->loop_stack_size + mask->switch_stack_size] =
+      mask->break_type;
+   mask->break_type = LP_EXEC_MASK_BREAK_TYPE_LOOP;
+
    mask->loop_stack[mask->loop_stack_size].loop_block = mask->loop_block;
    mask->loop_stack[mask->loop_stack_size].cont_mask = mask->cont_mask;
    mask->loop_stack[mask->loop_stack_size].break_mask = mask->break_mask;
@@ -200,16 +214,51 @@ static void lp_exec_bgnloop(struct lp_exec_mask *mask)
    lp_exec_mask_update(mask);
 }
 
-static void lp_exec_break(struct lp_exec_mask *mask)
+static void lp_exec_break(struct lp_exec_mask *mask,
+                          struct lp_build_tgsi_context * bld_base)
 {
    LLVMBuilderRef builder = mask->bld->gallivm->builder;
-   LLVMValueRef exec_mask = LLVMBuildNot(builder,
-                                         mask->exec_mask,
-                                         "break");
 
-   mask->break_mask = LLVMBuildAnd(builder,
-                                   mask->break_mask,
-                                   exec_mask, "break_full");
+   if (mask->break_type == LP_EXEC_MASK_BREAK_TYPE_LOOP) {
+      LLVMValueRef exec_mask = LLVMBuildNot(builder,
+                                            mask->exec_mask,
+                                            "break");
+
+      mask->break_mask = LLVMBuildAnd(builder,
+                                      mask->break_mask,
+                                      exec_mask, "break_full");
+   }
+   else {
+      unsigned opcode = bld_base->instructions[bld_base->pc + 1].Instruction.Opcode;
+      boolean break_always = (opcode == TGSI_OPCODE_ENDSWITCH ||
+                              opcode == TGSI_OPCODE_CASE);
+
+
+      if (mask->switch_in_default) {
+         /*
+          * stop default execution but only if this is an unconditional switch.
+          * (The condition here is not perfect since dead code after break is
+          * allowed but should be sufficient since false negatives are just
+          * unoptimized - so we don't have to pre-evaluate that).
+          */
+         if(break_always && mask->switch_pc) {
+            bld_base->pc = mask->switch_pc;
+            return;
+         }
+      }
+
+      if (break_always) {
+         mask->switch_mask = LLVMConstNull(mask->bld->int_vec_type);
+      }
+      else {
+         LLVMValueRef exec_mask = LLVMBuildNot(builder,
+                                               mask->exec_mask,
+                                               "break");
+         mask->switch_mask = LLVMBuildAnd(builder,
+                                          mask->switch_mask,
+                                          exec_mask, "break_switch");
+      }
+   }
 
    lp_exec_mask_update(mask);
 }
@@ -223,9 +272,16 @@ static void lp_exec_break_condition(struct lp_exec_mask *mask,
                                          cond, "cond_mask");
    cond_mask = LLVMBuildNot(builder, cond_mask, "break_cond");
 
-   mask->break_mask = LLVMBuildAnd(builder,
-                                   mask->break_mask,
-                                   cond_mask, "breakc_full");
+   if (mask->break_type == LP_EXEC_MASK_BREAK_TYPE_LOOP) {
+      mask->break_mask = LLVMBuildAnd(builder,
+                                      mask->break_mask,
+                                      cond_mask, "breakc_full");
+   }
+   else {
+      mask->switch_mask = LLVMBuildAnd(builder,
+                                       mask->switch_mask,
+                                       cond_mask, "breakc_switch");
+   }
 
    lp_exec_mask_update(mask);
 }
@@ -312,10 +368,207 @@ static void lp_exec_endloop(struct gallivm_state *gallivm,
    mask->cont_mask = mask->loop_stack[mask->loop_stack_size].cont_mask;
    mask->break_mask = mask->loop_stack[mask->loop_stack_size].break_mask;
    mask->break_var = mask->loop_stack[mask->loop_stack_size].break_var;
+   mask->break_type = mask->break_type_stack[mask->loop_stack_size + mask->switch_stack_size];
+
+   lp_exec_mask_update(mask);
+}
+
+static void lp_exec_switch(struct lp_exec_mask *mask,
+                           LLVMValueRef switchval)
+{
+   mask->break_type_stack[mask->loop_stack_size + mask->switch_stack_size] =
+      mask->break_type;
+   mask->break_type = LP_EXEC_MASK_BREAK_TYPE_SWITCH;
+
+   mask->switch_stack[mask->switch_stack_size].switch_val = mask->switch_val;
+   mask->switch_stack[mask->switch_stack_size].switch_mask = mask->switch_mask;
+   mask->switch_stack[mask->switch_stack_size].switch_mask_default = mask->switch_mask_default;
+   mask->switch_stack[mask->switch_stack_size].switch_in_default = mask->switch_in_default;
+   mask->switch_stack[mask->switch_stack_size].switch_pc = mask->switch_pc;
+   mask->switch_stack_size++;
+
+   mask->switch_val = switchval;
+   mask->switch_mask = LLVMConstNull(mask->int_vec_type);
+   mask->switch_mask_default = LLVMConstNull(mask->int_vec_type);
+   mask->switch_in_default = false;
+   mask->switch_pc = 0;
 
    lp_exec_mask_update(mask);
 }
 
+static void lp_exec_endswitch(struct lp_exec_mask *mask,
+                              struct lp_build_tgsi_context * bld_base)
+{
+   LLVMBuilderRef builder = mask->bld->gallivm->builder;
+
+   /* check if there's deferred default if so do it now */
+   if (mask->switch_pc && !mask->switch_in_default) {
+      LLVMValueRef prevmask, defaultmask;
+      unsigned tmp_pc;
+      prevmask = mask->switch_stack[mask->switch_stack_size - 1].switch_mask;
+      defaultmask = LLVMBuildNot(builder, mask->switch_mask_default, "sw_default_mask");
+      mask->switch_mask = LLVMBuildAnd(builder, prevmask, defaultmask, "sw_mask");
+      mask->switch_in_default = true;
+
+      lp_exec_mask_update(mask);
+
+      assert(bld_base->instructions[mask->switch_pc - 1].Instruction.Opcode ==
+             TGSI_OPCODE_DEFAULT);
+
+      tmp_pc = bld_base->pc;
+      bld_base->pc = mask->switch_pc;
+      /*
+       * re-purpose switch_pc to point to here again, since we stop execution of
+       * the deferred default after next break.
+       */
+      mask->switch_pc = tmp_pc - 1;
+
+      return;
+   }
+
+   else if (mask->switch_pc && mask->switch_in_default) {
+      assert(bld_base->pc == mask->switch_pc + 1);
+   }
+
+   mask->switch_stack_size--;
+   mask->switch_val = mask->switch_stack[mask->switch_stack_size].switch_val;
+   mask->switch_mask = mask->switch_stack[mask->switch_stack_size].switch_mask;
+   mask->switch_mask_default = mask->switch_stack[mask->switch_stack_size].switch_mask_default;
+   mask->switch_in_default = mask->switch_stack[mask->switch_stack_size].switch_in_default;
+   mask->switch_pc = mask->switch_stack[mask->switch_stack_size].switch_pc;
+
+   mask->break_type = mask->break_type_stack[mask->loop_stack_size + mask->switch_stack_size];
+
+   lp_exec_mask_update(mask);
+}
+
+static void lp_exec_case(struct lp_exec_mask *mask,
+                         LLVMValueRef caseval)
+{
+   LLVMBuilderRef builder = mask->bld->gallivm->builder;
+
+   LLVMValueRef casemask, prevmask;
+
+   /* skipping case mask evaluation here is NOT optional (not in all cases anyway). */
+   if (!mask->switch_in_default) {
+      prevmask = mask->switch_stack[mask->switch_stack_size - 1].switch_mask;
+      casemask = lp_build_cmp(mask->bld, PIPE_FUNC_EQUAL, caseval, mask->switch_val);
+      mask->switch_mask_default = LLVMBuildOr(builder, casemask,
+                                              mask->switch_mask_default, "sw_default_mask");
+      casemask = LLVMBuildOr(builder, casemask, mask->switch_mask, "");
+      mask->switch_mask = LLVMBuildAnd(builder, casemask, prevmask, "sw_mask");
+
+      lp_exec_mask_update(mask);
+   }
+}
+
+/*
+ * Analyse default statement in a switch.
+ * \return true if default is last statement, false otherwise
+ * \param default_pc_start contains pc of instruction to jump to
+ *                         if default wasn't last but there's no
+ *                         fallthrough into default.
+ */
+static boolean default_analyse_is_last(struct lp_exec_mask *mask,
+                                       struct lp_build_tgsi_context * bld_base,
+                                       int *default_pc_start)
+{
+   unsigned pc = bld_base->pc;
+   unsigned curr_switch_stack = mask->switch_stack_size;
+
+   /* skip over case statements which are together with default */
+   while (bld_base->instructions[pc].Instruction.Opcode == TGSI_OPCODE_CASE) {
+      pc++;
+   }
+
+   while (pc != -1 && pc < bld_base->num_instructions) {
+      unsigned opcode = bld_base->instructions[pc].Instruction.Opcode;
+      switch (opcode) {
+      case TGSI_OPCODE_CASE:
+         if (curr_switch_stack == mask->switch_stack_size) {
+            *default_pc_start = pc - 1;
+            return false;
+         }
+         break;
+      case TGSI_OPCODE_SWITCH:
+         curr_switch_stack++;
+         break;
+      case TGSI_OPCODE_ENDSWITCH:
+         if (curr_switch_stack == mask->switch_stack_size) {
+            *default_pc_start = pc - 1;
+            return true;
+         }
+         curr_switch_stack--;
+         break;
+      }
+      pc++;
+   }
+   /* should never arrive here */
+   assert(0);
+   return true;
+}
+
+static void lp_exec_default(struct lp_exec_mask *mask,
+                            struct lp_build_tgsi_context * bld_base)
+{
+   LLVMBuilderRef builder = mask->bld->gallivm->builder;
+
+   int default_exec_pc;
+   boolean default_is_last;
+
+   /*
+    * This is a messy opcode, because it may not be always at the end and
+    * there can be fallthrough in and out of it.
+    */
+
+   default_is_last = default_analyse_is_last(mask, bld_base, &default_exec_pc);
+   /*
+    * If it is last statement in switch (note that case statements appearing
+    * "at the same time" as default don't change that) everything is just fine,
+    * update switch mask and go on. This means we can handle default with
+    * fallthrough INTO it without overhead, if it is last.
+    */
+   if (default_is_last) {
+      LLVMValueRef prevmask, defaultmask;
+      prevmask = mask->switch_stack[mask->switch_stack_size - 1].switch_mask;
+      defaultmask = LLVMBuildNot(builder, mask->switch_mask_default, "sw_default_mask");
+      defaultmask = LLVMBuildOr(builder, defaultmask, mask->switch_mask, "");
+      mask->switch_mask = LLVMBuildAnd(builder, prevmask, defaultmask, "sw_mask");
+      mask->switch_in_default = true;
+
+      lp_exec_mask_update(mask);
+   }
+   else {
+      /*
+       * Technically, "case" immediately before default isn't really a
+       * fallthrough, however we still have to count them as such as we
+       * already have updated the masks.
+       * If that happens in practice could add a switch optimizer pass
+       * which just gets rid of all case statements appearing together with
+       * default (or could do switch analysis at switch start time instead).
+       */
+      unsigned opcode = bld_base->instructions[bld_base->pc - 1].Instruction.Opcode;
+      boolean ft_into = (opcode != TGSI_OPCODE_BRK ||
+                         opcode != TGSI_OPCODE_SWITCH);
+      /*
+       * If it is not last statement and there was no fallthrough into it,
+       * we record the PC and continue execution at next case (again, those
+       * case encountered at the same time don't count). At endswitch
+       * time, we update switchmask, and go back executing the code we skipped
+       * until the next break (possibly re-executing some code with changed mask
+       * if there was a fallthrough out of default).
+       * Finally, if it is not last statement and there was a fallthrough into it,
+       * do the same as with the former case, except instead of skipping the code
+       * just execute it without updating the mask, then go back and re-execute.
+       */
+      mask->switch_pc = bld_base->pc;
+      if (!ft_into) {
+         bld_base->pc = default_exec_pc;
+      }
+   }
+}
+
+
 /* stores val into an address pointed to by dst.
  * mask->exec_mask is used to figure out which bits of val
  * should be stored into the address
@@ -369,6 +622,7 @@ static void lp_exec_mask_ret(struct lp_exec_mask *mask, int *pc)
 
    if (mask->cond_stack_size == 0 &&
        mask->loop_stack_size == 0 &&
+       mask->switch_stack_size == 0 &&
        mask->call_stack_size == 0) {
       /* returning from main() */
       *pc = -1;
@@ -2407,7 +2661,7 @@ brk_emit(
 {
    struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
 
-   lp_exec_break(&bld->exec_mask);
+   lp_exec_break(&bld->exec_mask, bld_base);
 }
 
 static void
@@ -2457,6 +2711,50 @@ uif_emit(
    lp_exec_mask_cond_push(&bld->exec_mask, tmp);
 }
 
+static void
+case_emit(
+   const struct lp_build_tgsi_action * action,
+   struct lp_build_tgsi_context * bld_base,
+   struct lp_build_emit_data * emit_data)
+{
+   struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
+
+   lp_exec_case(&bld->exec_mask, emit_data->args[0]);
+}
+
+static void
+default_emit(
+   const struct lp_build_tgsi_action * action,
+   struct lp_build_tgsi_context * bld_base,
+   struct lp_build_emit_data * emit_data)
+{
+   struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
+
+   lp_exec_default(&bld->exec_mask, bld_base);
+}
+
+static void
+switch_emit(
+   const struct lp_build_tgsi_action * action,
+   struct lp_build_tgsi_context * bld_base,
+   struct lp_build_emit_data * emit_data)
+{
+   struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
+
+   lp_exec_switch(&bld->exec_mask, emit_data->args[0]);
+}
+
+static void
+endswitch_emit(
+   const struct lp_build_tgsi_action * action,
+   struct lp_build_tgsi_context * bld_base,
+   struct lp_build_emit_data * emit_data)
+{
+   struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
+
+   lp_exec_endswitch(&bld->exec_mask, bld_base);
+}
+
 static void
 bgnloop_emit(
    const struct lp_build_tgsi_action * action,
@@ -2794,13 +3092,16 @@ lp_build_tgsi_soa(struct gallivm_state *gallivm,
    bld.bld_base.op_actions[TGSI_OPCODE_BRK].emit = brk_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_BREAKC].emit = breakc_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_CAL].emit = cal_emit;
+   bld.bld_base.op_actions[TGSI_OPCODE_CASE].emit = case_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_CONT].emit = cont_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_DDX].emit = ddx_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_DDY].emit = ddy_emit;
+   bld.bld_base.op_actions[TGSI_OPCODE_DEFAULT].emit = default_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_ELSE].emit = else_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_ENDIF].emit = endif_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_ENDLOOP].emit = endloop_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_ENDSUB].emit = endsub_emit;
+   bld.bld_base.op_actions[TGSI_OPCODE_ENDSWITCH].emit = endswitch_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_IF].emit = if_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_UIF].emit = uif_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_KIL].emit = kil_emit;
@@ -2808,6 +3109,7 @@ lp_build_tgsi_soa(struct gallivm_state *gallivm,
    bld.bld_base.op_actions[TGSI_OPCODE_NRM].emit = nrm_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_NRM4].emit = nrm_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_RET].emit = ret_emit;
+   bld.bld_base.op_actions[TGSI_OPCODE_SWITCH].emit = switch_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_TEX].emit = tex_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_TXB].emit = txb_emit;
    bld.bld_base.op_actions[TGSI_OPCODE_TXD].emit = txd_emit;