lp_build_interp_soa_update_pos_dyn(interp, gallivm, loop_state.counter, NULL);
    }
 
+   LLVMValueRef out_sample_mask_storage = NULL;
+   if (shader->info.base.writes_samplemask) {
+      out_sample_mask_storage = lp_build_alloca(gallivm, int_vec_type, "write_mask");
+   }
    system_values.sample_pos = sample_pos_array;
 
    lp_build_interp_soa_update_inputs_dyn(interp, gallivm, loop_state.counter, NULL, NULL);
       }
    }
    if (shader->info.base.writes_samplemask) {
+      LLVMValueRef output_smask = NULL;
       int smaski = find_output_by_semantic(&shader->info.base,
                                            TGSI_SEMANTIC_SAMPLEMASK,
                                            0);
-      LLVMValueRef smask;
       struct lp_build_context smask_bld;
       lp_build_context_init(&smask_bld, gallivm, int_type);
 
       assert(smaski >= 0);
-      smask = LLVMBuildLoad(builder, outputs[smaski][0], "smask");
+      output_smask = LLVMBuildLoad(builder, outputs[smaski][0], "smask");
       /*
        * Pixel is alive according to the first sample in the mask.
        */
-      smask = LLVMBuildBitCast(builder, smask, smask_bld.vec_type, "");
-      smask = lp_build_and(&smask_bld, smask, smask_bld.one);
-      smask = lp_build_cmp(&smask_bld, PIPE_FUNC_NOTEQUAL, smask, smask_bld.zero);
-      lp_build_mask_update(&mask, smask);
+      output_smask = LLVMBuildBitCast(builder, output_smask, smask_bld.vec_type, "");
+      if (!key->multisample) {
+         output_smask = lp_build_and(&smask_bld, output_smask, smask_bld.one);
+         output_smask = lp_build_cmp(&smask_bld, PIPE_FUNC_NOTEQUAL, output_smask, smask_bld.zero);
+         lp_build_mask_update(&mask, output_smask);
+      }
+      LLVMBuildStore(builder, output_smask, out_sample_mask_storage);
    }
 
    if (key->multisample) {
       /* combine the execution mask post fragment shader with the coverage mask. */
       s_mask = LLVMBuildLoad(builder, s_mask_ptr, "");
       s_mask = LLVMBuildAnd(builder, s_mask, lp_build_mask_value(&mask), "");
+
+      /* if the shader writes sample mask use that */
+      if (shader->info.base.writes_samplemask) {
+         LLVMValueRef out_smask_idx = LLVMBuildShl(builder, lp_build_const_int32(gallivm, 1), sample_loop_state.counter, "");
+         out_smask_idx = lp_build_broadcast(gallivm, int_vec_type, out_smask_idx);
+         LLVMValueRef output_smask = LLVMBuildLoad(builder, out_sample_mask_storage, "");
+         LLVMValueRef smask_bit = LLVMBuildAnd(builder, output_smask, out_smask_idx, "");
+         LLVMValueRef cmp = LLVMBuildICmp(builder, LLVMIntNE, smask_bit, lp_build_const_int_vec(gallivm, int_type, 0), "");
+         smask_bit = LLVMBuildSExt(builder, cmp, int_vec_type, "");
+
+         s_mask = LLVMBuildAnd(builder, s_mask, smask_bit, "");
+      }
    }
 
    depth_ptr = depth_base_ptr;