I came across this post by @stuartarchibald that describes how to determine where compile time is spent in an `njit`-compiled function:
Timings (per pass; the three columns appear to be the init, run, and finalize times in seconds, with the middle run column dominating):

```
0_translate_bytecode :0.000001 0.027676 0.000001
1_fixup_args :0.000001 0.000001 0.000000
2_ir_processing :0.000001 0.001911 0.000001
3_with_lifting :0.000001 0.002063 0.000001
4_inline_closure_likes :0.000001 0.005272 0.000001
5_rewrite_semantic_constants :0.000001 0.000131 0.000000
6_dead_branch_prune :0.000000 0.000118 0.000000
7_generic_rewrites :0.000000 0.009968 0.000001
8_make_function_op_code_to_jit_function :0.000001 0.000052 0.000000
9_inline_inlinables :0.000001 0.000161 0.000001
10_dead_branch_prune :0.000000 0.000108 0.000000
11_find_literally :0.000000 0.000098 0.000000
12_literal_unroll :0.000000 0.000063 0.000000
13_reconstruct_ssa :0.000000 0.005832 0.000001
14_nopython_type_inference :0.000001 1.144870 0.000001
15_annotate_types :0.000001 0.003121 0.000001
16_strip_phis :0.000001 0.002152 0.000001
17_inline_overloads :0.000001 0.001972 0.000001
18_pre_parfor_pass :0.000001 0.000687 0.000001
19_nopython_rewrites :0.000000 0.007209 0.000001
20_parfor_pass :0.000001 0.501806 0.000002
21_ir_legalization :0.000001 0.005024 0.000001
22_nopython_backend :0.000000 2.010890 0.000001
23_dump_parfor_diagnostics :0.000001 0.000005 0.000000
```
It looks like `14_nopython_type_inference` and `22_nopython_backend` are where most of the time goes, which explains the 3-4 second compile time. As a blind attempt, I tried declaring an explicit function signature to see whether it had any effect on `nopython_type_inference`, but, unsurprisingly, it did not help. Unfortunately, the post above does not offer any concrete suggestions for reducing the time spent in these hotspots.