ble-reticulum/migration/tests/test_fragmentation_cpp_equivalence.py

309 lines
11 KiB
Python

import os
import sys
import time
import pytest
# Repository root: two directory levels above the directory holding this file.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Pure-Python package sources.
SRC_DIR = os.path.join(REPO_ROOT, "src")
# Where `setup.py build_ext --inplace` leaves the compiled extension.
CPP_BUILD_DIR = os.path.join(REPO_ROOT, "migration", "protocol_core")
# Prepend both so they shadow any installed copies; the build dir ends up first,
# followed by the source dir (same order as inserting SRC_DIR, then CPP_BUILD_DIR, at 0).
sys.path[:0] = [CPP_BUILD_DIR, SRC_DIR]
from ble_reticulum.BLEFragmentation import ( # noqa: E402
BLEFragmenter as PyBLEFragmenter,
)
from ble_reticulum.BLEFragmentation import ( # noqa: E402
BLEReassembler as PyBLEReassembler,
)
from ble_reticulum.BLEFragmentation import HDLCFramer as PyHDLCFramer # noqa: E402
# Import the compiled pybind11 extension; if it has not been built, pytest skips
# this whole module rather than erroring at collection time.
cpp = pytest.importorskip(
    "ble_protocol_core_cpp",
    reason="compiled pybind11 module missing; build with "
    "`python3 migration/protocol_core/setup.py build_ext --inplace`",
)

# Short aliases mirroring the Py* names imported from ble_reticulum above.
CppBLEFragmenter, CppBLEReassembler, CppHDLCFramer = (
    cpp.BLEFragmenter,
    cpp.BLEReassembler,
    cpp.HDLCFramer,
)
def assert_same_exception(py_callable, cpp_callable):
    """Assert that both zero-argument callables raise equivalent exceptions.

    Each callable must raise some ``Exception`` (the test fails otherwise).
    The Python callable is invoked first, then the C++ one. The two raised
    exceptions must have exactly the same type (not merely a subclass) and
    the same ``str()`` message.
    """
    raised = []
    for thunk in (py_callable, cpp_callable):
        with pytest.raises(Exception) as excinfo:
            thunk()
        raised.append(excinfo.value)

    py_error, cpp_error = raised
    assert type(cpp_error) is type(py_error)
    assert str(cpp_error) == str(py_error)
def compare_fragmenter(mtu, packet):
    """Fragment ``packet`` with both implementations and assert they agree.

    Compares the reported ``mtu``, ``payload_size``, the produced fragment
    list, and ``get_fragment_overhead(len(packet))``.

    Returns:
        ``(py_fragments, cpp_fragments)`` for callers that go on to
        reassemble them.
    """
    reference = PyBLEFragmenter(mtu=mtu)
    candidate = CppBLEFragmenter(mtu=mtu)

    reference_pieces = reference.fragment_packet(packet)
    candidate_pieces = candidate.fragment_packet(packet)

    # Configuration derived from the MTU must match.
    assert candidate.mtu == reference.mtu
    assert candidate.payload_size == reference.payload_size

    # Output and overhead accounting must match byte-for-byte / value-for-value.
    assert candidate_pieces == reference_pieces
    packet_length = len(packet)
    candidate_overhead = candidate.get_fragment_overhead(packet_length)
    reference_overhead = reference.get_fragment_overhead(packet_length)
    assert candidate_overhead == reference_overhead

    return reference_pieces, candidate_pieces
def compare_reassembly(mtu, packet, order=None, sender_id="device1"):
    """Fragment, then reassemble ``packet`` with both implementations in lockstep.

    Args:
        mtu: MTU passed to both fragmenters.
        packet: Payload to round-trip.
        order: Fragment indices in delivery order; defaults to 0..N-1.
        sender_id: Sender identifier passed to ``receive_fragment``.

    After every delivery the two reassemblers must return equal results. The
    final result of both must be ``packet``, and their statistics must match.
    """
    py_fragments, cpp_fragments = compare_fragmenter(mtu, packet)
    py_reassembler = PyBLEReassembler()
    cpp_reassembler = CppBLEReassembler()

    delivery = range(len(py_fragments)) if order is None else order

    py_result = cpp_result = None
    for py_piece, cpp_piece in ((py_fragments[i], cpp_fragments[i]) for i in delivery):
        py_result = py_reassembler.receive_fragment(py_piece, sender_id)
        cpp_result = cpp_reassembler.receive_fragment(cpp_piece, sender_id)
        assert cpp_result == py_result

    assert py_result == packet
    assert cpp_result == packet
    assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()
class TestBLEFragmenterCppEquivalence:
    """Compare the Python and C++ ``BLEFragmenter`` implementations."""

    def test_single_fragment_packets(self):
        """Packets that fit in one payload yield exactly one identical fragment."""
        for mtu, packet in [
            (185, b"Hello, Reticulum!"),
            (20, b"A"),
            (50, bytes(range(10))),
        ]:
            py_fragments, cpp_fragments = compare_fragmenter(mtu, packet)
            assert len(py_fragments) == 1
            assert len(cpp_fragments) == 1

    def test_multi_fragment_packets(self):
        """Oversized packets split identically and reassemble to the original."""
        for mtu, packet in [
            (185, b"A" * 500),
            (100, b"B" * 300),
            (20, bytes(range(256))),
        ]:
            py_fragments, cpp_fragments = compare_fragmenter(mtu, packet)
            assert len(py_fragments) > 1
            assert len(cpp_fragments) > 1
            compare_reassembly(mtu, packet)

    @pytest.mark.parametrize("mtu", [20, 23, 50, 185])
    def test_mtu_boundary_sizes(self, mtu):
        """Packet sizes just around one and two payloads round-trip identically."""
        # Take the payload size from the reference fragmenter rather than
        # re-encoding its MTU clamping / header arithmetic here, so the boundary
        # sizes track the implementation. compare_fragmenter() separately
        # asserts that the C++ payload_size matches this value.
        payload_size = PyBLEFragmenter(mtu=mtu).payload_size
        sizes = [
            1,
            payload_size - 1,
            payload_size,
            payload_size + 1,
            payload_size * 2,
            payload_size * 2 + 1,
        ]
        for size in sizes:
            if size <= 0:
                continue
            # Varying byte pattern (period 251) rather than a constant fill.
            packet = bytes(i % 251 for i in range(size))
            compare_reassembly(mtu, packet)

    def test_empty_and_non_bytes_packet_errors(self):
        """Empty input fails identically; non-bytes input is a TypeError in both."""
        assert_same_exception(
            lambda: PyBLEFragmenter().fragment_packet(b""),
            lambda: CppBLEFragmenter().fragment_packet(b""),
        )
        # Only the type is compared here: the message text is not checked
        # for non-bytes input.
        with pytest.raises(TypeError):
            CppBLEFragmenter().fragment_packet("not bytes")
        with pytest.raises(TypeError):
            PyBLEFragmenter().fragment_packet("not bytes")
class TestBLEReassemblerCppEquivalence:
    """Compare the Python and C++ ``BLEReassembler`` implementations."""

    def test_single_fragment_reassembly(self):
        compare_reassembly(185, b"Short message")

    def test_multi_fragment_reassembly(self):
        compare_reassembly(100, b"E" * 300)

    def test_out_of_order_fragments_with_start_first(self):
        """Middle fragments may arrive swapped when fragment 0 is delivered first."""
        packet = b"F" * 150
        py_fragments, _ = compare_fragmenter(50, packet)
        assert len(py_fragments) == 4
        compare_reassembly(50, packet, order=[0, 2, 1, 3])

    def test_malformed_fragments(self):
        """Hand-built invalid fragments raise the same exception in both."""
        # NOTE(review): these appear to be a truncated header, an unknown type
        # byte, and headers with inconsistent sequence/total fields — confirm
        # against BLEFragmentation's header layout.
        cases = [
            b"\x01\x00",
            b"\xff\x00\x00\x00\x01payload",
            b"\x01\x00\x01\x00\x01payload",
            b"\x01\x00\x00\x00\x00payload",
        ]
        for fragment in cases:
            # Fresh reassemblers per case so earlier failures cannot leak state.
            py_reassembler = PyBLEReassembler()
            cpp_reassembler = CppBLEReassembler()
            assert_same_exception(
                lambda fragment=fragment: py_reassembler.receive_fragment(fragment, "device1"),
                lambda fragment=fragment: cpp_reassembler.receive_fragment(fragment, "device1"),
            )

    def test_duplicate_fragments_same_data(self):
        """Re-delivering an identical fragment does not break reassembly."""
        packet = b"D" * 160
        py_fragments, cpp_fragments = compare_fragmenter(50, packet)
        py_reassembler = PyBLEReassembler()
        cpp_reassembler = CppBLEReassembler()

        # Fragment 1 is delivered twice with identical bytes; none of these
        # deliveries completes the packet.
        for index in (0, 1, 1):
            assert py_reassembler.receive_fragment(py_fragments[index], "device1") is None
            assert cpp_reassembler.receive_fragment(cpp_fragments[index], "device1") is None

        py_result = None
        cpp_result = None
        for index in range(2, len(py_fragments)):
            py_result = py_reassembler.receive_fragment(py_fragments[index], "device1")
            cpp_result = cpp_reassembler.receive_fragment(cpp_fragments[index], "device1")
            assert cpp_result == py_result
        assert py_result == packet
        assert cpp_result == packet
        # Both implementations must account for the duplicate the same way.
        assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()

    def test_duplicate_fragments_different_data(self):
        """A conflicting re-delivery of a fragment fails identically in both."""
        packet = b"Q" * 160
        py_fragments, cpp_fragments = compare_fragmenter(50, packet)
        py_reassembler = PyBLEReassembler()
        cpp_reassembler = CppBLEReassembler()

        for index in (0, 1):
            assert py_reassembler.receive_fragment(py_fragments[index], "device1") is None
            assert cpp_reassembler.receive_fragment(cpp_fragments[index], "device1") is None

        # Same fragment as index 1 but with its final byte flipped.
        py_bad = bytearray(py_fragments[1])
        cpp_bad = bytearray(cpp_fragments[1])
        py_bad[-1] ^= 0x01
        cpp_bad[-1] ^= 0x01
        assert_same_exception(
            lambda: py_reassembler.receive_fragment(bytes(py_bad), "device1"),
            lambda: cpp_reassembler.receive_fragment(bytes(cpp_bad), "device1"),
        )
        # Post-error state must match as well.
        assert len(cpp_reassembler.reassembly_buffers) == len(py_reassembler.reassembly_buffers)
        assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()

    def test_stale_buffer_cleanup(self):
        """Incomplete buffers older than ``timeout`` are dropped identically."""
        packet = b"G" * 300
        py_fragments, cpp_fragments = compare_fragmenter(100, packet)
        py_reassembler = PyBLEReassembler(timeout=0.1)
        cpp_reassembler = CppBLEReassembler(timeout=0.1)
        assert py_reassembler.receive_fragment(py_fragments[0], "device1") is None
        assert cpp_reassembler.receive_fragment(cpp_fragments[0], "device1") is None
        assert len(cpp_reassembler.reassembly_buffers) == len(py_reassembler.reassembly_buffers)
        # Sleep past the 0.1 s timeout so the partial buffer becomes stale.
        time.sleep(0.2)
        assert cpp_reassembler.cleanup_stale_buffers() == py_reassembler.cleanup_stale_buffers()
        assert len(cpp_reassembler.reassembly_buffers) == len(py_reassembler.reassembly_buffers)
        assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()

    def test_statistics_reset(self):
        """Statistics match after a full reassembly and again after reset."""
        packet = b"H" * 300
        py_fragments, cpp_fragments = compare_fragmenter(100, packet)
        py_reassembler = PyBLEReassembler()
        cpp_reassembler = CppBLEReassembler()
        for py_fragment, cpp_fragment in zip(py_fragments, cpp_fragments):
            assert cpp_reassembler.receive_fragment(cpp_fragment, "device1") == py_reassembler.receive_fragment(
                py_fragment, "device1"
            )
        assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()
        py_reassembler.reset_statistics()
        cpp_reassembler.reset_statistics()
        assert cpp_reassembler.get_statistics() == py_reassembler.get_statistics()

    def test_internal_reassemble_method_matches_python(self):
        """The private ``_reassemble`` helper agrees on good and gapped buffers."""
        py_reassembler = PyBLEReassembler()
        cpp_reassembler = CppBLEReassembler()
        buffer = {"fragments": {0: b"abc", 1: b"def", 2: b"ghi"}, "total": 3}
        assert cpp_reassembler._reassemble(buffer) == py_reassembler._reassemble(buffer)
        # Fragment 1 missing while total says 3.
        malformed = {"fragments": {0: b"abc", 2: b"ghi"}, "total": 3}
        assert_same_exception(
            lambda: py_reassembler._reassemble(malformed),
            lambda: cpp_reassembler._reassemble(malformed),
        )
class TestHDLCFramerCppEquivalence:
    """Compare the Python and C++ ``HDLCFramer`` implementations."""

    @pytest.mark.parametrize(
        "packet",
        [
            b"",
            b"Hello, World!",
            bytes([0x7E, 0x01, 0x7E]),
            bytes([0x7D, 0x02, 0x7D]),
            bytes(range(256)),
        ],
    )
    def test_frame_deframe_round_trips(self, packet):
        """Framing is byte-identical and both deframers recover the packet."""
        py_framed = PyHDLCFramer.frame_packet(packet)
        cpp_framed = CppHDLCFramer.frame_packet(packet)
        assert cpp_framed == py_framed

        py_deframed = PyHDLCFramer.deframe_packet(py_framed)
        cpp_deframed = CppHDLCFramer.deframe_packet(cpp_framed)
        assert cpp_deframed == py_deframed
        assert cpp_deframed == packet

    def test_many_hdlc_round_trips(self):
        """Every single-byte value (repeated) frames and deframes identically."""
        for value in range(256):
            packet = bytes([value] * 10)
            py_framed = PyHDLCFramer.frame_packet(packet)
            cpp_framed = CppHDLCFramer.frame_packet(packet)
            assert cpp_framed == py_framed
            # Check both deframers, not only the C++ one.
            assert CppHDLCFramer.deframe_packet(cpp_framed) == packet
            assert PyHDLCFramer.deframe_packet(py_framed) == packet

    @pytest.mark.parametrize(
        "frame",
        [
            b"",
            b"\x7e",
            b"missing-flags",
            bytes([PyHDLCFramer.FLAG, 0x01, PyHDLCFramer.FLAG, PyHDLCFramer.FLAG]),
            bytes([PyHDLCFramer.FLAG, PyHDLCFramer.ESCAPE, PyHDLCFramer.FLAG]),
        ],
    )
    def test_invalid_hdlc_escape_sequences_and_frames(self, frame):
        """Invalid frames raise the same exception type and message."""
        assert_same_exception(
            lambda frame=frame: PyHDLCFramer.deframe_packet(frame),
            lambda frame=frame: CppHDLCFramer.deframe_packet(frame),
        )

    def test_non_bytes_errors(self):
        """Passing ``str`` to either framer method raises TypeError in both."""
        with pytest.raises(TypeError):
            PyHDLCFramer.frame_packet("not bytes")
        with pytest.raises(TypeError):
            CppHDLCFramer.frame_packet("not bytes")
        with pytest.raises(TypeError):
            PyHDLCFramer.deframe_packet("not bytes")
        with pytest.raises(TypeError):
            CppHDLCFramer.deframe_packet("not bytes")