Optimizing Generators to have the same performance as for loops

Consider the following code, which parses a PCAP file.

Using for loops

import numpy as np

from pcap_parser_util import *


@conditional_njit
def numba_loop(total_buffer, compile_only):
    """Count the IEX messages found in a PCAP buffer using nested loops.

    The buffer is walked with explicit start/end offsets at three levels:
    PCAP packet records, the per-packet protocol headers (Ethernet, IPv4,
    UDP, IEX-TP), and the individual messages inside the IEX-TP payload.
    Every read slices the buffer and compares the slice length with the
    requested length; a short slice ends the loop at that level.

    Args:
        total_buffer: Sliceable buffer that is passed to ``np.frombuffer``.
            Presumably the raw bytes of a whole PCAP file -- confirm against
            the caller.
        compile_only: When truthy the function returns immediately with
            ``None``. Presumably used to trigger JIT compilation without
            doing any work -- confirm against the caller.

    Returns:
        The number of messages whose record was decoded, or ``None`` when
        ``compile_only`` is truthy.

    Raises:
        AssertionError: If a packet's ``original_length`` differs from its
            ``included_length``, the ethertype field is not ``0x0008``, the
            IEX-TP ``version`` is not 1, or a message type byte is not one
            of the handled values.
    """
    if compile_only:
        return

    data_count = 0

    PCAP_GLOBAL_HEADER_SIZE = 24  # Bytes

    # Move past the global header
    offset_start_pcap = PCAP_GLOBAL_HEADER_SIZE
    offset_end_pcap = PCAP_GLOBAL_HEADER_SIZE

    # Level 1: one iteration per PCAP packet record.
    while True:
        # Read the fixed-size per-packet header.
        offset_start_pcap = offset_end_pcap
        offset_end_pcap = offset_start_pcap + pcap_packet_header_itemsize
        offset_buffer = total_buffer[offset_start_pcap:offset_end_pcap]
        # A short slice means the buffer is exhausted: stop parsing.
        if len(offset_buffer) != offset_end_pcap - offset_start_pcap:
            break

        pcap_hdr = np.frombuffer(offset_buffer, dtype=pcap_packet_header_dtype)

        # The two length fields of the packet header must agree.
        if pcap_hdr[0]["original_length"] != pcap_hdr[0]["included_length"]:
            assert False

        # Read the packet body, whose size is given by the packet header.
        offset_start_pcap = offset_end_pcap
        offset_end_pcap = offset_start_pcap + pcap_hdr[0]["original_length"]
        offset_buffer = total_buffer[offset_start_pcap:offset_end_pcap]
        if len(offset_buffer) != offset_end_pcap - offset_start_pcap:
            break

        # Same range as offset_buffer above, sliced a second time.
        pcap_buffer = total_buffer[offset_start_pcap:offset_end_pcap]

        offset_start_network = 0
        offset_end_network = 0

        # Level 2: decode the header stack inside this packet. After a payload
        # has been processed the loop resumes from the end of that payload and
        # leaves through the first short slice; any break here only abandons
        # the current packet, the outer loop then moves on to the next one.
        while True:
            offset_start_network = offset_end_network
            offset_end_network = offset_start_network + ethernet_header_itemsize
            offset_buffer = pcap_buffer[offset_start_network:offset_end_network]
            if len(offset_buffer) != offset_end_network - offset_start_network:
                break

            ethernet_hdr = np.frombuffer(offset_buffer, dtype=ethernet_header_dtype)
            # NOTE(review): presumably the IPv4 ethertype (0x0800) as it reads
            # without byte-swapping -- confirm against ethernet_header_dtype.
            assert ethernet_hdr[0]["ethertype"] == 0x0008

            offset_start_network = offset_end_network
            offset_end_network = offset_start_network + ipv4_header_itemsize
            offset_buffer = pcap_buffer[offset_start_network:offset_end_network]
            if len(offset_buffer) != offset_end_network - offset_start_network:
                break

            # Decoded but not referenced again in this function.
            ipv4_hdr = np.frombuffer(offset_buffer, dtype=ipv4_header_dtype)

            offset_start_network = offset_end_network
            offset_end_network = offset_start_network + udp_header_itemsize
            offset_buffer = pcap_buffer[offset_start_network:offset_end_network]
            if len(offset_buffer) != offset_end_network - offset_start_network:
                break

            # Decoded but not referenced again in this function.
            udp_hdr = np.frombuffer(offset_buffer, dtype=udp_header_dtype)

            offset_start_network = offset_end_network
            offset_end_network = offset_start_network + iex_tp_header_itemsize
            offset_buffer = pcap_buffer[offset_start_network:offset_end_network]
            if len(offset_buffer) != offset_end_network - offset_start_network:
                break

            iex_tp_hdr = np.frombuffer(offset_buffer, dtype=iex_tp_header_dtype)

            assert iex_tp_hdr[0]["version"] == 1

            # The IEX-TP header states how many payload bytes follow it.
            offset_start_network = offset_end_network
            offset_end_network = offset_start_network + iex_tp_hdr[0]["payload_length"]
            offset_buffer = pcap_buffer[offset_start_network:offset_end_network]
            if len(offset_buffer) != offset_end_network - offset_start_network:
                break

            iex_buffer = np.frombuffer(offset_buffer, dtype=np.uint8)
            offset_start_iex = 0
            offset_end_iex = 0

            # Level 3: one iteration per message inside the payload.
            while True:
                # Two-byte prefix in front of every message. It is decoded but
                # never used below (presumably the message length -- confirm
                # against the IEX-TP specification).
                offset_start_iex = offset_end_iex
                offset_end_iex = offset_start_iex + 2
                offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                if len(offset_buffer) != offset_end_iex - offset_start_iex:
                    break

                message_hdr = np.frombuffer(offset_buffer, dtype=np.uint16)

                # Peek at the single byte that selects the message layout.
                offset_start_iex = offset_end_iex
                offset_end_iex = offset_start_iex + 1
                offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                if len(offset_buffer) != offset_end_iex - offset_start_iex:
                    break

                message_type = np.frombuffer(offset_buffer, dtype=np.uint8)[0]
                # Step back over the type byte so that the record slice taken
                # in the branches below starts at that byte.
                offset_end_iex -= 1

                # Each branch slices one fixed-size record for the matching
                # message dtype; the decoded record itself is not used further.
                if message_type == 0x41:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + auction_information_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(offset_buffer, dtype=auction_information_message_dtype)
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x44:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + security_directory_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(
                        offset_buffer, dtype=security_directory_message_dtype
                    )
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x48:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + trading_status_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(
                        offset_buffer, dtype=trading_status_message_dtype
                    )
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x49:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + retail_liquidity_indicator_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(
                        offset_buffer, dtype=retail_liquidity_indicator_message_dtype
                    )
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x4F:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + operational_halt_status_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(
                        offset_buffer, dtype=operational_halt_status_message_dtype
                    )
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x50:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + short_sale_price_test_status_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(
                        offset_buffer, dtype=short_sale_price_test_status_message_dtype
                    )
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x51:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + quote_update_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(offset_buffer, dtype=quote_update_message_dtype)
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x53:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + system_event_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(offset_buffer, dtype=system_event_message_dtype)
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x54:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + trade_report_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(offset_buffer, dtype=trade_report_message_dtype)
                    # print("data_buffer")
                    # print(data_buffer)
                elif message_type == 0x58:
                    offset_start_iex = offset_end_iex
                    offset_end_iex = offset_start_iex + official_price_message_itemsize
                    offset_buffer = iex_buffer[offset_start_iex:offset_end_iex]
                    if len(offset_buffer) != offset_end_iex - offset_start_iex:
                        break

                    data_buffer = np.frombuffer(offset_buffer, dtype=official_price_message_dtype)
                    # print("data_buffer")
                    # print(data_buffer)
                else:
                    # Unhandled type byte: fail, reporting it as number and character.
                    assert False, f"{message_type} {chr(message_type)}"

                # Only reached when a complete record was sliced for the message.
                data_count += 1

    return data_count


# Script entry-point placeholder: nothing is executed under this guard.
if __name__ == "__main__":
    pass

And using a generator

import numpy as np

from pcap_parser_util import *


@conditional_njit
def gen__pcap_block(total_buffer):
    """Yield every packet record stored in a PCAP buffer.

    Iteration starts right after the 24-byte global header and stops at the
    first packet header or packet body that cannot be read in full.

    Args:
        total_buffer: Sliceable buffer accepted by ``np.frombuffer``.
            Presumably the raw bytes of a whole PCAP file -- confirm against
            the caller.

    Yields:
        ``(pcap_hdr, packet_bytes)``: the packet header decoded with
        ``pcap_packet_header_dtype`` and the slice of ``total_buffer`` that
        holds the ``original_length`` bytes following that header.

    Raises:
        AssertionError: If a packet's ``original_length`` and
            ``included_length`` fields differ.
    """
    GLOBAL_HEADER_BYTES = 24

    # Position of the next unread byte; records begin after the global header.
    cursor = GLOBAL_HEADER_BYTES

    while True:
        # Fixed-size per-packet header.
        header_stop = cursor + pcap_packet_header_itemsize
        header_bytes = total_buffer[cursor:header_stop]
        if len(header_bytes) != header_stop - cursor:
            break

        pcap_hdr = np.frombuffer(header_bytes, dtype=pcap_packet_header_dtype)

        # The two length fields of the packet header must agree.
        assert pcap_hdr[0]["original_length"] == pcap_hdr[0]["included_length"]

        # Packet body, sized by the header that was just decoded.
        packet_stop = header_stop + pcap_hdr[0]["original_length"]
        packet_bytes = total_buffer[header_stop:packet_stop]
        if len(packet_bytes) != packet_stop - header_stop:
            break

        yield pcap_hdr, packet_bytes

        # The next record starts where this packet body ended.
        cursor = packet_stop


@conditional_njit
def gen__iex_block(total_buffer):
    """Yield the decoded headers and IEX-TP payload of each packet in a PCAP buffer.

    Every packet produced by ``gen__pcap_block`` is decoded front to back as
    Ethernet, IPv4, UDP and IEX-TP headers followed by ``payload_length``
    payload bytes. A packet that is too short for any of these reads is
    skipped and iteration continues with the next packet, which is how
    ``numba_loop`` treats the same situation.

    NOTE(review): ``numba_loop`` additionally asserts the ethertype and the
    IEX-TP version; this generator performs no such validation.

    Args:
        total_buffer: Buffer forwarded unchanged to ``gen__pcap_block``.

    Yields:
        ``(pcap_hdr, ethernet_hdr, ipv4_hdr, udp_hdr, iex_tp_hdr,
        message_buffer)`` where each header is the ``np.frombuffer`` result
        for its dtype and ``message_buffer`` is the payload viewed as
        ``np.uint8``.
    """
    for pcap_hdr, pcap_buffer in gen__pcap_block(total_buffer):
        # The Ethernet header sits at the very start of the packet bytes.
        offset_start = 0
        offset_end = offset_start + ethernet_header_itemsize
        offset_buffer = pcap_buffer[offset_start:offset_end]
        if len(offset_buffer) != offset_end - offset_start:
            # Short read: give up on this packet only, not on the whole file.
            continue

        ethernet_hdr = np.frombuffer(offset_buffer, dtype=ethernet_header_dtype)

        offset_start = offset_end
        offset_end = offset_start + ipv4_header_itemsize
        offset_buffer = pcap_buffer[offset_start:offset_end]
        if len(offset_buffer) != offset_end - offset_start:
            continue

        ipv4_hdr = np.frombuffer(offset_buffer, dtype=ipv4_header_dtype)

        offset_start = offset_end
        offset_end = offset_start + udp_header_itemsize
        offset_buffer = pcap_buffer[offset_start:offset_end]
        if len(offset_buffer) != offset_end - offset_start:
            continue

        udp_hdr = np.frombuffer(offset_buffer, dtype=udp_header_dtype)

        offset_start = offset_end
        offset_end = offset_start + iex_tp_header_itemsize
        offset_buffer = pcap_buffer[offset_start:offset_end]
        if len(offset_buffer) != offset_end - offset_start:
            continue

        iex_tp_hdr = np.frombuffer(offset_buffer, dtype=iex_tp_header_dtype)

        # The IEX-TP header states how many payload bytes follow it.
        offset_start = offset_end
        offset_end = offset_start + iex_tp_hdr[0]["payload_length"]
        offset_buffer = pcap_buffer[offset_start:offset_end]
        if len(offset_buffer) != offset_end - offset_start:
            continue

        message_buffer = np.frombuffer(offset_buffer, dtype=np.uint8)

        yield pcap_hdr, ethernet_hdr, ipv4_hdr, udp_hdr, iex_tp_hdr, message_buffer


@conditional_njit
def numba_generator(total_buffer, compile_only):
    """Count the IEX messages found in a PCAP buffer using generators.

    The per-packet payloads are pulled from ``gen__iex_block``; this function
    only walks the messages inside each payload with explicit start/end
    offsets. Every read slices the payload and compares the slice length with
    the requested length; a short slice ends the walk of that payload.

    Args:
        total_buffer: Buffer forwarded unchanged to ``gen__iex_block``.
        compile_only: When truthy the function returns immediately with
            ``None``. Presumably used to trigger JIT compilation without
            doing any work -- confirm against the caller.

    Returns:
        The number of messages whose record was decoded, or ``None`` when
        ``compile_only`` is truthy.

    Raises:
        AssertionError: If a message type byte is not one of the handled
            values.
    """
    if compile_only:
        return

    data_count = 0
    for t in gen__iex_block(total_buffer):
        # Only message_buffer is used below; the headers are unpacked but unused.
        pcap_hdr, ethernet_hdr, ipv4_hdr, udp_hdr, iex_tp_hdr, message_buffer = t

        offset_start = 0
        offset_end = 0

        # One iteration per message inside the payload.
        while True:
            # Two-byte prefix in front of every message. It is decoded but
            # never used below (presumably the message length -- confirm
            # against the IEX-TP specification).
            offset_start = offset_end
            offset_end = offset_start + 2
            offset_buffer = message_buffer[offset_start:offset_end]
            if len(offset_buffer) != offset_end - offset_start:
                break

            message_hdr = np.frombuffer(offset_buffer, dtype=np.uint16)

            # Peek at the single byte that selects the message layout.
            offset_start = offset_end
            offset_end = offset_start + 1
            offset_buffer = message_buffer[offset_start:offset_end]
            if len(offset_buffer) != offset_end - offset_start:
                break

            message_type = np.frombuffer(offset_buffer, dtype=np.uint8)[0]
            # print("message_type")
            # print(message_type, chr(message_type))

            # Step back over the type byte so that the record slice taken in
            # the branches below starts at that byte.
            offset_end = offset_end - 1

            # Each branch slices one fixed-size record for the matching
            # message dtype; the decoded record itself is not used further.
            if message_type == 0x41:
                offset_start = offset_end
                offset_end = offset_start + auction_information_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(offset_buffer, dtype=auction_information_message_dtype)
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x44:
                offset_start = offset_end
                offset_end = offset_start + security_directory_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(
                    offset_buffer, dtype=security_directory_message_dtype
                )
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x48:
                offset_start = offset_end
                offset_end = offset_start + trading_status_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(
                    offset_buffer, dtype=trading_status_message_dtype
                )
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x49:
                offset_start = offset_end
                offset_end = offset_start + retail_liquidity_indicator_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(
                    offset_buffer, dtype=retail_liquidity_indicator_message_dtype
                )
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x4F:
                offset_start = offset_end
                offset_end = offset_start + operational_halt_status_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(
                    offset_buffer, dtype=operational_halt_status_message_dtype
                )
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x50:
                offset_start = offset_end
                offset_end = offset_start + short_sale_price_test_status_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(
                    offset_buffer, dtype=short_sale_price_test_status_message_dtype
                )
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x51:
                offset_start = offset_end
                offset_end = offset_start + quote_update_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(offset_buffer, dtype=quote_update_message_dtype)
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x53:
                offset_start = offset_end
                offset_end = offset_start + system_event_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(offset_buffer, dtype=system_event_message_dtype)
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x54:
                offset_start = offset_end
                offset_end = offset_start + trade_report_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(offset_buffer, dtype=trade_report_message_dtype)
                # print("data_buffer")
                # print(data_buffer)
            elif message_type == 0x58:
                offset_start = offset_end
                offset_end = offset_start + official_price_message_itemsize
                offset_buffer = message_buffer[offset_start:offset_end]
                if len(offset_buffer) != offset_end - offset_start:
                    break

                data_buffer = np.frombuffer(offset_buffer, dtype=official_price_message_dtype)
                # print("data_buffer")
                # print(data_buffer)
            else:
                # Unhandled type byte: fail, reporting it as number and character.
                assert False, f"{message_type} {chr(message_type)}"

            # Only reached when a complete record was sliced for the message.
            data_count += 1

    return data_count


# Script entry-point placeholder: nothing is executed under this guard.
if __name__ == "__main__":
    pass

Testing the performance of the code to parse some big PCAP of around 5GB

hbina085@akarinpc ~/P/NumbaPcapParser (master)> du -shc /home/hbina085/Downloads/20220324_IEXTP1_TOPS1.6.pcap                                                                                                                (base) 
5.1G    /home/hbina085/Downloads/20220324_IEXTP1_TOPS1.6.pcap
5.1G    total

Running this reveals a wildly large performance gap:

/home/hbina085/miniconda3/envs/numbadev/bin/python /home/hbina085/PycharmProjects/NumbaPcapParser/pcap_parser.py 
file_size 13445714003
GENERATOR
time_end - time_start 39.621919003 s
LOOP
time_end - time_start 11.530971387 s
gen_count: 96434286
loop_count: 96434286

I really like generators, since the code written with them is much, much cleaner (as you can see above).
Is it possible to optimize generators further in the case where you don’t carry the generator around as a stateful object? (See “Notes on generators” in the Numba 0.52.0.dev0+274.g626b40e-py3.7-linux-x86_64.egg documentation.)
I was thinking about it and realized that a generator could be desugared into a plain loop whenever it is consumed in full, which is implied when it is driven directly by a for loop.
In that case everything could be inlined.

Has anyone attempted this before?
Is there something not possible with this?
My guess is that, since Numba only sees the bytecode, maybe it's not so obvious that a generator will be consumed fully.
The pattern matching for this optimization might be a bit complex.
I think the optimization mapping itself should be trivial??

Maybe things like this should happen in RustPython instead… since it has the actual Python AST.