PHI node error when creating a while loop for an atomic spinlock

I’m trying to implement a spinlock mutex using llvmlite atomics. The reason for this is that I don’t think a prange reduction will work, as my input axes differ from my output axes.

This has involved implementing a while loop. Here’s my attempt, but the PHI node seems to be set up incorrectly.

import ctypes

import numba
import numpy as np
from llvmlite import ir
from numba.core import types
from numba.core.errors import TypingError
from numba.extending import intrinsic

try:
  libc = ctypes.CDLL("libc.so.6")
except OSError:
  libc = ctypes.CDLL("libc.so")

os_yield_fn = libc.sched_yield
os_yield_fn.argtypes = []
os_yield_fn.restype = ctypes.c_int


@intrinsic
def lock(typingctx, lock: types.Array, idx: types.UniTuple):
  if not isinstance(lock, types.Array) or not isinstance(lock.dtype, types.Integer):
    raise TypingError(f"lock {lock} must be an Array of integers")

  if (
    not isinstance(idx, types.UniTuple)
    or not isinstance(idx.dtype, types.Integer)
    or len(idx) != lock.ndim
  ):
    raise TypingError(f"idx {idx} must be a Tuple of length {lock.ndim} integers")

  sig = types.bool(lock, idx)

  def yield_idx_wrapper(i):
    print("Yielding", i)
    os_yield_fn()

  def codegen(context, builder, signature, args):
    lock, idx = args
    lock_type, idx_type = signature.args
    llvm_lock_type = context.get_value_type(lock_type.dtype)
    lock_array = context.make_array(lock_type)(context, builder, lock)

    native_idx = [builder.extract_value(idx, i) for i in range(len(idx_type))]
    out_ptr = builder.gep(lock_array.data, native_idx)

    # Store this block and create loop blocks
    entry_block = builder.block
    loop_cond = builder.append_basic_block(name="lock.loop.cond")
    loop_body = builder.append_basic_block(name="lock.loop.body")
    loop_end = builder.append_basic_block(name="lock.loop.end")

    # Initialize counter and branch to loop cond
    initial_count = ir.Constant(ir.IntType(64), 0)
    builder.branch(loop_cond)

    with builder.goto_block(loop_cond):
      # Create PHI node for counter with incoming values from entry and loop body
      count_phi = builder.phi(ir.IntType(64), name="lock.loop.count")
      # Attempt atomic compare-and-exchange
      xchng_result = builder.cmpxchg(
        out_ptr,
        ir.Constant(llvm_lock_type, 0),
        ir.Constant(llvm_lock_type, 1),
        ordering="acquire",
        failordering="monotonic",
      )
      success = builder.extract_value(xchng_result, 1)
      pred = builder.icmp_signed("==", success, success.type(1))
      builder.cbranch(pred, loop_end, loop_body)

    with builder.goto_block(loop_body):
      # Loop body block
      next_count = builder.add(count_phi, count_phi.type(1))
      context.compile_internal(builder, yield_idx_wrapper, types.none(types.int64), [next_count])
      builder.branch(loop_cond)

    # Add incoming value from loop body to the PHI node
    count_phi.add_incoming(initial_count, entry_block)
    count_phi.add_incoming(next_count, loop_body)

    # End block
    builder.position_at_end(loop_end)
    return ir.Constant(ir.IntType(1), 1)

  return sig, codegen


if __name__ == "__main__":

  @numba.njit(nogil=True)
  def f(a, i):
    return lock(a, i)

  print(f(np.full(10, 1, np.int32), (0,)))

This produces the following exception:

RuntimeError: PHI node entries do not match predecessors!
  %lock.loop.count = phi i64 [ 0, %B0 ], [ %.32, %lock.loop.body ]
label %lock.loop.body
label %lock.loop.body.endif

I’m not quite sure what’s incorrect, as this generally follows a standard loop implementation.

If we dump the LLVM IR with NUMBA_DUMP_LLVM=1 set, we get:

; ModuleID = "f$1"
target triple = "x86_64-conda-linux-gnu"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"

@"_ZN08NumbaEnv8__main__1fB2v1B40c8tJTIeFIjxB2IKSgI4CrvQClQZ6FcpCShpgEU0AE5ArrayIiLi1E1C7mutable7alignedE8UniTupleIxLi1EE" = common global i8* null
define i32 @"_ZN8__main__1fB2v1B40c8tJTIeFIjxB2IKSgI4CrvQClQZ6FcpCShpgEU0AE5ArrayIiLi1E1C7mutable7alignedE8UniTupleIxLi1EE"(i8* noalias nocapture %"retptr", {i8*, i32, i8*, i8*, i32}** noalias nocapture %"excinfo", i8* %"arg.a.0", i8* %"arg.a.1", i64 %"arg.a.2", i64 %"arg.a.3", i32* %"arg.a.4", i64 %"arg.a.5.0", i64 %"arg.a.6.0", i64 %"arg.i.0")
{
entry:
  %"inserted.meminfo" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} undef, i8* %"arg.a.0", 0
  %"inserted.parent" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.meminfo", i8* %"arg.a.1", 1
  %"inserted.nitems" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.parent", i64 %"arg.a.2", 2
  %"inserted.itemsize" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.nitems", i64 %"arg.a.3", 3
  %"inserted.data" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.itemsize", i32* %"arg.a.4", 4
  %".12" = insertvalue [1 x i64] undef, i64 %"arg.a.5.0", 0
  %"inserted.shape" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.data", [1 x i64] %".12", 5
  %".13" = insertvalue [1 x i64] undef, i64 %"arg.a.6.0", 0
  %"inserted.strides" = insertvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.shape", [1 x i64] %".13", 6
  %".14" = insertvalue [1 x i64] undef, i64 %"arg.i.0", 0
  %".19" = alloca {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}
  store {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} zeroinitializer, {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}* %".19"
  %".33" = alloca i8*
  store i8* null, i8** %".33"
  %"excinfo.1" = alloca {i8*, i32, i8*, i8*, i32}*
  store {i8*, i32, i8*, i8*, i32}* null, {i8*, i32, i8*, i8*, i32}** %"excinfo.1"
  %"try_state" = alloca i64
  store i64 0, i64* %"try_state"
  br label %"B0"
B0:
  %"extracted.meminfo" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 0
  %"extracted.parent" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 1
  %"extracted.nitems" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 2
  %"extracted.itemsize" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 3
  %"extracted.data" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 4
  %"extracted.shape" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 5
  %".15" = extractvalue [1 x i64] %"extracted.shape", 0
  %"extracted.strides" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 6
  %".16" = extractvalue [1 x i64] %"extracted.strides", 0
  call void @"NRT_incref"(i8* %"extracted.meminfo")
  %".18" = extractvalue [1 x i64] %".14", 0
  store {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} zeroinitializer, {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}* %".19"
  store {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}* %".19"
  %".23" = extractvalue [1 x i64] %".14", 0
  %".24" = getelementptr inbounds {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}, {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}* %".19", i32 0, i32 4
  %".25" = load i32*, i32** %".24"
  %".26" = getelementptr i32, i32* %".25", i64 %".23"
  br label %"lock.loop.cond"
lock.loop.cond:
  %"lock.loop.count" = phi  i64 [0, %"B0"], [%".32", %"lock.loop.body"]
  %".28" = cmpxchg i32* %".26", i32 0, i32 1 acquire monotonic
  %".29" = extractvalue {i32, i1} %".28", 1
  %".30" = icmp eq i1 %".29", 1
  br i1 %".30", label %"lock.loop.end", label %"lock.loop.body"
lock.loop.body:
  %".32" = add i64 %"lock.loop.count", 1
  store i8* null, i8** %".33"
  %".37" = call i32 @"_ZN8__main__4lock12_3clocals_3e17yield_idx_wrapperB2v2B42c8tJTC_2fWQA93W1AaAIYBPIqRBFCjDSZRVAJmaQIAEx"(i8** %".33", {i8*, i32, i8*, i8*, i32}** %"excinfo.1", i64 %".32")
  %".38" = load {i8*, i32, i8*, i8*, i32}*, {i8*, i32, i8*, i8*, i32}** %"excinfo.1"
  %".39" = icmp eq i32 %".37", 0
  %".40" = icmp eq i32 %".37", -2
  %".41" = icmp eq i32 %".37", -1
  %".42" = icmp eq i32 %".37", -3
  %".43" = or i1 %".39", %".40"
  %".44" = xor i1 %".43", -1
  %".45" = icmp sge i32 %".37", 1
  %".46" = select  i1 %".45", {i8*, i32, i8*, i8*, i32}* %".38", {i8*, i32, i8*, i8*, i32}* undef
  %".47" = load i8*, i8** %".33"
  br i1 %".44", label %"lock.loop.body.if", label %"lock.loop.body.endif", !prof !0
lock.loop.end:
  %".60" = extractvalue [1 x i64] %".14", 0
  %"extracted.meminfo.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 0
  %"extracted.parent.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 1
  %"extracted.nitems.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 2
  %"extracted.itemsize.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 3
  %"extracted.data.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 4
  %"extracted.shape.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 5
  %".61" = extractvalue [1 x i64] %"extracted.shape.1", 0
  %"extracted.strides.1" = extractvalue {i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]} %"inserted.strides", 6
  %".62" = extractvalue [1 x i64] %"extracted.strides.1", 0
  call void @"NRT_decref"(i8* %"extracted.meminfo.1")
  %".64" = zext i1 1 to i8
  store i8 %".64", i8* %"retptr"
  ret i32 0
lock.loop.body.if:
  store i64 0, i64* %"try_state"
  %".51" = load i64, i64* %"try_state"
  %".52" = icmp ugt i64 %".51", 0
  %".53" = load {i8*, i32, i8*, i8*, i32}*, {i8*, i32, i8*, i8*, i32}** %"excinfo"
  store {i8*, i32, i8*, i8*, i32}* %".46", {i8*, i32, i8*, i8*, i32}** %"excinfo"
  %".55" = xor i1 %".52", -1
  br i1 %".55", label %"lock.loop.body.if.if", label %"lock.loop.body.if.endif"
lock.loop.body.endif:
  br label %"lock.loop.cond"
lock.loop.body.if.if:
  ret i32 %".37"
lock.loop.body.if.endif:
  br label %"lock.loop.body.endif"
}

declare void @"NRT_incref"(i8* noalias nocapture %".1")

declare i32 @"_ZN8__main__4lock12_3clocals_3e17yield_idx_wrapperB2v2B42c8tJTC_2fWQA93W1AaAIYBPIqRBFCjDSZRVAJmaQIAEx"(i8** noalias nocapture %"retptr", {i8*, i32, i8*, i8*, i32}** noalias nocapture %"excinfo", i64 %"arg.i")

declare void @"NRT_decref"(i8* noalias nocapture %".1")

!0 = !{ !"branch_weights", i32 1, i32 99 }

The PHI node is in lock.loop.cond:

lock.loop.cond:
  %"lock.loop.count" = phi  i64 [0, %"B0"], [%".32", %"lock.loop.body"]

Branches to it are in B0 and lock.loop.body.endif:

B0:
  ; ...
  br label %"lock.loop.cond"
lock.loop.body.endif:
  br label %"lock.loop.cond"

The actual predecessor, lock.loop.body.endif, does not match the lock.loop.body label listed in the PHI node.
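To see the failure in isolation, here’s a minimal llvmlite-only sketch (hypothetical, independent of Numba and the code above) that builds the same CFG shape and triggers the same verifier error. The extra loop.body.endif block stands in for the control flow that compile_internal splices in, so loop.body no longer branches to loop.cond:

import llvmlite.binding as llvm
from llvmlite import ir

mod = ir.Module(name="phi_demo")
fn = ir.Function(mod, ir.FunctionType(ir.VoidType(), []), name="demo")
entry = fn.append_basic_block("entry")
loop_cond = fn.append_basic_block("loop.cond")
loop_body = fn.append_basic_block("loop.body")
body_endif = fn.append_basic_block("loop.body.endif")
loop_end = fn.append_basic_block("loop.end")

builder = ir.IRBuilder(entry)
builder.branch(loop_cond)

builder.position_at_end(loop_cond)
count = builder.phi(ir.IntType(64), name="count")
builder.cbranch(ir.Constant(ir.IntType(1), 1), loop_end, loop_body)

builder.position_at_end(loop_body)
next_count = builder.add(count, count.type(1))
builder.branch(body_endif)  # extra control flow between the body and the back edge

builder.position_at_end(body_endif)
builder.branch(loop_cond)  # the real predecessor of loop.cond is loop.body.endif

builder.position_at_end(loop_end)
builder.ret_void()

count.add_incoming(ir.Constant(ir.IntType(64), 0), entry)
count.add_incoming(next_count, loop_body)  # wrong: loop.body is not a predecessor

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
llvm.parse_assembly(str(mod)).verify()
# RuntimeError: PHI node entries do not match predecessors!

The fix here is the same as the one suggested below: pass body_endif, the block the builder is actually in when the back edge is emitted, to add_incoming.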

(This is just a quick diagnosis of the symptom for now; I may follow up with a suggestion for fixing the issue in the original code if I can find it quickly.)

I think the solution is to keep track of which block the branch is actually inserted into, rather than assuming it’s the same block as before the call to compile_internal. That call inserts some control flow and leaves the builder positioned in a different block:

diff --git a/repro.py b/repro.py
index 4562ac0..90eff20 100644
--- a/repro.py
+++ b/repro.py
@@ -72,11 +72,12 @@ def lock(typingctx, lock: types.Array, idx: types.UniTuple):
       # Loop body block
       next_count = builder.add(count_phi, count_phi.type(1))
       context.compile_internal(builder, yield_idx_wrapper, types.none(types.int64), [next_count])
+      branch_block = builder.block
       builder.branch(loop_cond)
 
     # Add incoming value from loop body to the PHI node
     count_phi.add_incoming(initial_count, entry_block)
-    count_phi.add_incoming(next_count, loop_body)
+    count_phi.add_incoming(next_count, branch_block)
 
     # End block
     builder.position_at_end(loop_end)

With this change, the PHI node is correct:

lock.loop.cond:
  %"lock.loop.count" = phi  i64 [0, %"B0"], [%".32", %"lock.loop.body.endif"]

Thanks for the clear explanation @gmarkall. While I’ve written quite a few basic intrinsics, I’ve never done anything more complicated than some branching blocks, so I’ve never had to dive into the LLVM IR itself. Your strategy for debugging this kind of problem will be useful in future.
