Merge branch 'main' of github-public:torlando-tech/ble-reticulum
This commit is contained in:
commit
bbeb6c43e8
10 changed files with 971 additions and 27 deletions
|
|
@ -445,7 +445,8 @@ if [[ "$ARCH" == "armhf" ]] || [[ "$(uname -m)" =~ ^(armv6l|armv7l)$ ]]; then
|
|||
print_info "This saves ~20 minutes of compilation time on Pi Zero W"
|
||||
|
||||
WHEEL_URL="https://github.com/torlando-tech/ble-reticulum/releases/download/armv6l-wheels-v1/dbus_fast-2.44.5-cp313-cp313-linux_armv6l.whl"
|
||||
WHEEL_FILE="/tmp/dbus_fast-armv6l-$$.whl"
|
||||
# Use proper wheel filename - pip extracts metadata from the filename
|
||||
WHEEL_FILE="/tmp/dbus_fast-2.44.5-cp313-cp313-linux_armv6l.whl"
|
||||
|
||||
if curl -sL "$WHEEL_URL" -o "$WHEEL_FILE" 2>/dev/null; then
|
||||
if [ -f "$WHEEL_FILE" ] && [ -s "$WHEEL_FILE" ]; then
|
||||
|
|
|
|||
|
|
@ -1036,9 +1036,11 @@ class BLEInterface(Interface):
|
|||
old_interface.detach()
|
||||
RNS.log(f"{self} detached stale interface for {identity_hash[:8]}", RNS.LOG_DEBUG)
|
||||
|
||||
# Clean up address mappings
|
||||
# Clean up address mappings (both directions)
|
||||
if identity_hash in self.identity_to_address:
|
||||
del self.identity_to_address[identity_hash]
|
||||
if old_address in self.address_to_identity:
|
||||
del self.address_to_identity[old_address]
|
||||
|
||||
# Clean up fragmenter/reassembler for old address
|
||||
if peer_identity:
|
||||
|
|
@ -1448,15 +1450,17 @@ class BLEInterface(Interface):
|
|||
|
||||
def _compute_identity_hash(self, peer_identity):
|
||||
"""
|
||||
Compute 16-character hex identity hash for interface tracking.
|
||||
Convert 16-byte identity to 16-character hex string for interface tracking.
|
||||
|
||||
Args:
|
||||
peer_identity: 16-byte peer identity
|
||||
peer_identity: 16-byte peer identity (already a hash from BLE handshake)
|
||||
|
||||
Returns:
|
||||
str: Identity hash (16 hex chars)
|
||||
str: First 16 hex chars of identity (8 bytes)
|
||||
"""
|
||||
return RNS.Identity.full_hash(peer_identity)[:16].hex()[:16]
|
||||
# peer_identity is already the identity hash from BLE handshake
|
||||
# Just convert to hex, don't re-hash (that would corrupt the identity!)
|
||||
return peer_identity.hex()[:16]
|
||||
|
||||
def _spawn_peer_interface(self, address, name, peer_identity, client=None, mtu=None, connection_type="central"):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1612,6 +1612,10 @@ class BluezeroGATTServer:
|
|||
self.stop_event = threading.Event()
|
||||
self.started_event = threading.Event()
|
||||
|
||||
# Event-driven shutdown for D-Bus monitor thread
|
||||
self._monitor_loop = None # Reference to asyncio loop in monitor thread
|
||||
self._async_stop_event = None # asyncio.Event for clean shutdown
|
||||
|
||||
# Connected centrals (address -> info dict)
|
||||
self.connected_centrals: Dict[str, dict] = {}
|
||||
self.centrals_lock = threading.RLock()
|
||||
|
|
@ -1748,6 +1752,12 @@ class BluezeroGATTServer:
|
|||
device_proxies = {} # Track proxy objects for each device
|
||||
|
||||
try:
|
||||
# Set up event-driven shutdown mechanism
|
||||
loop = asyncio.get_running_loop()
|
||||
async_stop = asyncio.Event()
|
||||
self._monitor_loop = loop
|
||||
self._async_stop_event = async_stop
|
||||
|
||||
# Connect to system bus
|
||||
if RNS:
|
||||
RNS.log(f"{self.log_prefix} [GATT-MONITOR] Connecting to D-Bus...", RNS.LOG_EXTREME)
|
||||
|
|
@ -1891,13 +1901,14 @@ class BluezeroGATTServer:
|
|||
if RNS:
|
||||
RNS.log(f"{self.log_prefix} [GATT-MONITOR] Entering wait loop...", RNS.LOG_EXTREME)
|
||||
|
||||
# Poll stop_event and yield to event loop to process D-Bus messages
|
||||
while not self.stop_event.is_set():
|
||||
await asyncio.sleep(0.5)
|
||||
# Wait for stop signal (event-driven, no polling)
|
||||
# D-Bus signals are processed automatically by the event loop
|
||||
# The async_stop event is set via call_soon_threadsafe from stop()
|
||||
await async_stop.wait()
|
||||
|
||||
if RNS:
|
||||
RNS.log(f"{self.log_prefix} [GATT-MONITOR] Stop event set, exiting loop", RNS.LOG_EXTREME)
|
||||
self._log("D-Bus monitoring loop exiting", "DEBUG")
|
||||
RNS.log(f"{self.log_prefix} [GATT-MONITOR] Stop event received, exiting loop", RNS.LOG_EXTREME)
|
||||
self._log("D-Bus monitoring loop exiting (stop signal received)", "DEBUG")
|
||||
|
||||
except Exception as e:
|
||||
if RNS:
|
||||
|
|
@ -1964,14 +1975,11 @@ class BluezeroGATTServer:
|
|||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# Wait for 30 seconds (check stop_event frequently)
|
||||
for _ in range(60): # 60 * 0.5s = 30s
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
time.sleep(0.5)
|
||||
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
# Wait for 300 seconds (5 min), wake immediately on stop signal
|
||||
# This is a fallback safety net for missed D-Bus signals
|
||||
# Using threading.Event.wait() instead of busy-loop for clean shutdown
|
||||
if self.stop_event.wait(timeout=300.0):
|
||||
break # Stop was signaled
|
||||
|
||||
# Check all connected centrals
|
||||
with self.centrals_lock:
|
||||
|
|
@ -2141,6 +2149,13 @@ class BluezeroGATTServer:
|
|||
self.stop_event.set()
|
||||
self.running = False
|
||||
|
||||
# Wake the async D-Bus monitor loop immediately (event-driven shutdown)
|
||||
if self._monitor_loop and self._async_stop_event:
|
||||
try:
|
||||
self._monitor_loop.call_soon_threadsafe(self._async_stop_event.set)
|
||||
except RuntimeError:
|
||||
pass # Loop already stopped
|
||||
|
||||
# Wait for server thread to exit
|
||||
if self.server_thread and self.server_thread.is_alive():
|
||||
self.server_thread.join(timeout=5.0)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ if src_dir not in sys.path:
|
|||
|
||||
import pytest
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from types import ModuleType
|
||||
|
|
@ -309,8 +310,8 @@ def create_mock_ble_interface(owner=None, config=None):
|
|||
interface.connection_blacklist = {}
|
||||
interface.fragmenters = {}
|
||||
interface.reassemblers = {}
|
||||
interface.peer_lock = asyncio.Lock()
|
||||
interface.frag_lock = asyncio.Lock()
|
||||
interface.peer_lock = threading.RLock() # Use threading lock for mock
|
||||
interface.frag_lock = threading.RLock() # Use threading lock for mock
|
||||
interface.loop = asyncio.get_event_loop()
|
||||
interface.max_peers = config.get('max_connections', 7) if config else 7
|
||||
interface.min_rssi = config.get('min_rssi', -80) if config else -80
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ including data flow, fragmentation, and both central/peripheral modes.
|
|||
|
||||
import pytest
|
||||
import asyncio
|
||||
import threading
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
|
||||
# Import fragmentation for testing
|
||||
|
|
@ -36,8 +37,8 @@ def create_mock_peer_interface(peer_address="AA:BB:CC:DD:EE:FF", peer_name="Test
|
|||
parent.peers = {peer_address: (Mock(is_connected=True), 0, 185)}
|
||||
parent.fragmenters = {peer_address: BLEFragmenter(mtu=185) if BLEFragmenter else Mock()}
|
||||
parent.reassemblers = {peer_address: BLEReassembler() if BLEReassembler else Mock()}
|
||||
parent.frag_lock = asyncio.Lock()
|
||||
parent.peer_lock = asyncio.Lock()
|
||||
parent.frag_lock = threading.RLock() # Use threading lock for mock
|
||||
parent.peer_lock = threading.RLock() # Use threading lock for mock
|
||||
parent.loop = asyncio.get_event_loop()
|
||||
parent.gatt_server = Mock()
|
||||
parent.gatt_server.send_notification = AsyncMock(return_value=True)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import pytest
|
|||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
from unittest.mock import Mock, MagicMock, AsyncMock, patch, call
|
||||
|
||||
# Add src to path
|
||||
|
|
@ -54,7 +55,7 @@ class TestBlueZStateCleanup:
|
|||
driver = Mock()
|
||||
driver.loop = asyncio.new_event_loop()
|
||||
driver._connecting_peers = set()
|
||||
driver._connecting_lock = asyncio.Lock()
|
||||
driver._connecting_lock = threading.RLock() # Use threading lock for mock
|
||||
driver._remove_bluez_device = AsyncMock(return_value=True)
|
||||
driver._log = Mock()
|
||||
return driver
|
||||
|
|
|
|||
791
tests/test_hci_error_fixes.py
Normal file
791
tests/test_hci_error_fixes.py
Normal file
|
|
@ -0,0 +1,791 @@
|
|||
"""
|
||||
Tests for HCI Error Fixes (Event-Driven D-Bus Monitoring)
|
||||
|
||||
Tests the fixes for HCI errors on BCM43xx single-radio Bluetooth chips.
|
||||
The root cause was D-Bus monitoring threads polling every 0.5s, causing
|
||||
radio contention with advertising/scanning operations.
|
||||
|
||||
Fixes tested:
|
||||
1. Event-driven D-Bus monitor: Uses asyncio.Event instead of polling
|
||||
2. Stale poll improvements: Uses threading.Event.wait() instead of busy-wait
|
||||
3. Stop() shutdown behavior: Uses call_soon_threadsafe for immediate stop
|
||||
|
||||
Reference: /tmp/hci_error_analysis.md
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock, MagicMock, AsyncMock, patch, PropertyMock
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))
|
||||
|
||||
# Mock RNS module before importing
|
||||
import RNS
|
||||
if not hasattr(RNS, 'LOG_INFO'):
|
||||
RNS.LOG_CRITICAL = 0
|
||||
RNS.LOG_ERROR = 1
|
||||
RNS.LOG_WARNING = 2
|
||||
RNS.LOG_NOTICE = 3
|
||||
RNS.LOG_INFO = 4
|
||||
RNS.LOG_VERBOSE = 5
|
||||
RNS.LOG_DEBUG = 6
|
||||
RNS.LOG_EXTREME = 7
|
||||
|
||||
RNS.log = Mock()
|
||||
|
||||
|
||||
class TestEventDrivenDBusMonitor:
|
||||
"""Test event-driven D-Bus monitoring (replaces polling)."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_driver(self):
|
||||
"""Create mock driver with required attributes."""
|
||||
driver = Mock()
|
||||
driver._peers = {}
|
||||
driver._peers_lock = threading.RLock()
|
||||
driver._log = Mock()
|
||||
driver._handle_peripheral_disconnected = Mock()
|
||||
return driver
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gatt_server(self, mock_driver):
|
||||
"""Create mock GATT server with event-driven monitoring setup."""
|
||||
from RNS.Interfaces.linux_bluetooth_driver import BluezeroGATTServer
|
||||
|
||||
server = Mock(spec=BluezeroGATTServer)
|
||||
server.driver = mock_driver
|
||||
server.stop_event = threading.Event()
|
||||
server.connected_centrals = {}
|
||||
server.centrals_lock = threading.RLock()
|
||||
server._log = Mock()
|
||||
server._handle_central_disconnected = Mock()
|
||||
server.log_prefix = "[TEST]"
|
||||
|
||||
# Event-driven shutdown attributes
|
||||
server._monitor_loop = None
|
||||
server._async_stop_event = None
|
||||
|
||||
return server
|
||||
|
||||
def test_async_stop_event_initialized(self, mock_gatt_server):
|
||||
"""Test that _async_stop_event is initialized to None before monitoring starts."""
|
||||
assert mock_gatt_server._async_stop_event is None
|
||||
assert mock_gatt_server._monitor_loop is None
|
||||
|
||||
def test_async_event_setup_in_monitor_loop(self):
|
||||
"""Test that asyncio.Event and loop reference are set up correctly in monitor_loop."""
|
||||
async def async_test():
|
||||
loop = asyncio.get_running_loop()
|
||||
async_stop = asyncio.Event()
|
||||
|
||||
# These would be set on self in the real implementation
|
||||
_monitor_loop = loop
|
||||
_async_stop_event = async_stop
|
||||
|
||||
# Verify they're set correctly
|
||||
assert _monitor_loop is loop
|
||||
assert _async_stop_event is async_stop
|
||||
assert isinstance(_async_stop_event, asyncio.Event)
|
||||
assert not _async_stop_event.is_set()
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
def test_async_event_wait_blocks_until_set(self):
|
||||
"""Test that await async_stop.wait() blocks until event is set."""
|
||||
async def async_test():
|
||||
async_stop = asyncio.Event()
|
||||
wait_completed = False
|
||||
|
||||
async def wait_for_stop():
|
||||
nonlocal wait_completed
|
||||
await async_stop.wait()
|
||||
wait_completed = True
|
||||
|
||||
# Start waiting
|
||||
wait_task = asyncio.create_task(wait_for_stop())
|
||||
|
||||
# Should still be waiting
|
||||
await asyncio.sleep(0.1)
|
||||
assert not wait_completed
|
||||
|
||||
# Set the event
|
||||
async_stop.set()
|
||||
|
||||
# Should complete quickly
|
||||
await asyncio.wait_for(wait_task, timeout=1.0)
|
||||
assert wait_completed
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
def test_async_event_wakes_immediately_when_set(self):
|
||||
"""Test that async event wakes immediately (no 5s delay like polling)."""
|
||||
async def async_test():
|
||||
async_stop = asyncio.Event()
|
||||
wake_time = None
|
||||
|
||||
async def wait_for_stop():
|
||||
nonlocal wake_time
|
||||
await async_stop.wait()
|
||||
wake_time = time.time()
|
||||
|
||||
# Start waiting
|
||||
wait_task = asyncio.create_task(wait_for_stop())
|
||||
await asyncio.sleep(0.01) # Let task start
|
||||
|
||||
# Set event and measure response time
|
||||
set_time = time.time()
|
||||
async_stop.set()
|
||||
|
||||
await asyncio.wait_for(wait_task, timeout=1.0)
|
||||
|
||||
# Response should be immediate (< 100ms vs 5000ms polling)
|
||||
response_time = wake_time - set_time
|
||||
assert response_time < 0.1, f"Response time {response_time}s should be < 0.1s"
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
def test_call_soon_threadsafe_sets_event(self):
|
||||
"""Test that call_soon_threadsafe can set async event from another thread."""
|
||||
async def async_test():
|
||||
async_stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
event_was_set = False
|
||||
|
||||
async def wait_for_stop():
|
||||
nonlocal event_was_set
|
||||
await async_stop.wait()
|
||||
event_was_set = True
|
||||
|
||||
# Start waiting in async context
|
||||
wait_task = asyncio.create_task(wait_for_stop())
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Set event from sync code (simulating stop() call)
|
||||
# In real implementation this would be from another thread
|
||||
loop.call_soon_threadsafe(async_stop.set)
|
||||
|
||||
await asyncio.wait_for(wait_task, timeout=1.0)
|
||||
assert event_was_set
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
def test_call_soon_threadsafe_from_thread(self):
|
||||
"""Test that call_soon_threadsafe works from a separate thread."""
|
||||
event_set_in_loop = threading.Event()
|
||||
loop_started = threading.Event()
|
||||
|
||||
stored_loop = None
|
||||
stored_event = None
|
||||
|
||||
async def async_main():
|
||||
nonlocal stored_loop, stored_event
|
||||
async_stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Store for cross-thread access
|
||||
stored_loop = loop
|
||||
stored_event = async_stop
|
||||
|
||||
loop_started.set()
|
||||
|
||||
# Wait for signal
|
||||
await async_stop.wait()
|
||||
event_set_in_loop.set()
|
||||
|
||||
# Run async code in thread
|
||||
async_thread = threading.Thread(
|
||||
target=lambda: asyncio.run(async_main()),
|
||||
daemon=True
|
||||
)
|
||||
async_thread.start()
|
||||
|
||||
# Wait for loop to start
|
||||
loop_started.wait(timeout=2.0)
|
||||
assert stored_loop is not None
|
||||
assert stored_event is not None
|
||||
|
||||
# Signal from main thread
|
||||
stored_loop.call_soon_threadsafe(stored_event.set)
|
||||
|
||||
# Verify event was set
|
||||
event_set_in_loop.wait(timeout=2.0)
|
||||
assert event_set_in_loop.is_set()
|
||||
|
||||
async_thread.join(timeout=1.0)
|
||||
|
||||
def test_no_polling_loop_in_monitor(self):
|
||||
"""Test that there's no periodic polling - just event wait."""
|
||||
async def async_test():
|
||||
# The implementation should NOT have this pattern:
|
||||
# while not self.stop_event.is_set():
|
||||
# await asyncio.sleep(0.5) # BAD - polling
|
||||
|
||||
# Instead it should use:
|
||||
# await async_stop.wait() # GOOD - event-driven
|
||||
|
||||
async_stop = asyncio.Event()
|
||||
iterations = 0
|
||||
|
||||
async def event_driven_wait():
|
||||
nonlocal iterations
|
||||
# This is the correct pattern - single wait, no loop iterations
|
||||
await async_stop.wait()
|
||||
iterations = 1 # Only one "iteration" - the wait itself
|
||||
|
||||
# Test event-driven pattern
|
||||
async_stop.set() # Set immediately for test
|
||||
await event_driven_wait()
|
||||
|
||||
# Event-driven should only have 1 "iteration"
|
||||
assert iterations == 1
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
|
||||
class TestStalePollImprovements:
|
||||
"""Test stale connection polling improvements."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gatt_server(self):
|
||||
"""Create mock GATT server with polling setup."""
|
||||
server = Mock()
|
||||
server.stop_event = threading.Event()
|
||||
server.connected_centrals = {}
|
||||
server.centrals_lock = threading.RLock()
|
||||
server._log = Mock()
|
||||
server._handle_central_disconnected = Mock()
|
||||
server.log_prefix = "[TEST]"
|
||||
return server
|
||||
|
||||
def test_event_wait_used_instead_of_busy_loop(self):
|
||||
"""Test that threading.Event.wait(timeout) is used instead of busy-wait."""
|
||||
stop_event = threading.Event()
|
||||
wait_called = False
|
||||
|
||||
# The implementation should use:
|
||||
# if self.stop_event.wait(timeout=300.0):
|
||||
# break
|
||||
|
||||
# Not the old pattern:
|
||||
# for _ in range(240):
|
||||
# if self.stop_event.is_set():
|
||||
# break
|
||||
# time.sleep(0.5)
|
||||
|
||||
def proper_wait_pattern():
|
||||
nonlocal wait_called
|
||||
if stop_event.wait(timeout=0.1): # Short timeout for test
|
||||
wait_called = True
|
||||
return True
|
||||
return False
|
||||
|
||||
# Should return False when not set
|
||||
result = proper_wait_pattern()
|
||||
assert not result
|
||||
assert not wait_called
|
||||
|
||||
# Should return True when set
|
||||
stop_event.set()
|
||||
result = proper_wait_pattern()
|
||||
assert result
|
||||
assert wait_called
|
||||
|
||||
def test_immediate_stop_response(self):
|
||||
"""Test that stop signal is responded to immediately (not after 0.5s polls)."""
|
||||
stop_event = threading.Event()
|
||||
response_time = None
|
||||
thread_exited = threading.Event()
|
||||
|
||||
def wait_loop():
|
||||
nonlocal response_time
|
||||
start = time.time()
|
||||
# Using Event.wait() pattern
|
||||
stop_event.wait(timeout=300.0)
|
||||
response_time = time.time() - start
|
||||
thread_exited.set()
|
||||
|
||||
thread = threading.Thread(target=wait_loop, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Let thread start waiting
|
||||
time.sleep(0.05)
|
||||
|
||||
# Signal stop
|
||||
stop_event.set()
|
||||
|
||||
# Wait for thread to respond
|
||||
thread_exited.wait(timeout=1.0)
|
||||
|
||||
# Response should be immediate (< 100ms vs old 500ms polling interval)
|
||||
assert response_time is not None
|
||||
assert response_time < 0.15, f"Response time {response_time}s should be < 0.15s"
|
||||
|
||||
def test_poll_interval_is_300_seconds(self):
|
||||
"""Test that stale poll interval is 300 seconds (5 minutes)."""
|
||||
# The implementation uses:
|
||||
# if self.stop_event.wait(timeout=300.0):
|
||||
|
||||
# Verify the constant value (we can't easily test 5 min wait in unit test)
|
||||
EXPECTED_INTERVAL = 300.0
|
||||
|
||||
# This would be tested by reading the actual code value
|
||||
# For now, we simulate what the implementation should do
|
||||
stop_event = threading.Event()
|
||||
poll_count = 0
|
||||
|
||||
def poll_loop():
|
||||
nonlocal poll_count
|
||||
while not stop_event.is_set():
|
||||
# In real code, this is 300.0
|
||||
if stop_event.wait(timeout=0.01): # Short for test
|
||||
break
|
||||
poll_count += 1
|
||||
if poll_count >= 5: # Limit iterations for test
|
||||
break
|
||||
|
||||
thread = threading.Thread(target=poll_loop, daemon=True)
|
||||
thread.start()
|
||||
|
||||
time.sleep(0.1)
|
||||
stop_event.set()
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
# Thread should have exited cleanly
|
||||
assert not thread.is_alive()
|
||||
|
||||
def test_single_wait_call_per_interval(self):
|
||||
"""Test that each interval uses single wait() call, not 600 iterations."""
|
||||
stop_event = threading.Event()
|
||||
wait_call_count = 0
|
||||
original_wait = threading.Event.wait
|
||||
|
||||
def counting_wait(self, timeout=None):
|
||||
nonlocal wait_call_count
|
||||
wait_call_count += 1
|
||||
return original_wait(self, timeout=0.01 if timeout else timeout)
|
||||
|
||||
with patch.object(threading.Event, 'wait', counting_wait):
|
||||
# Simulate one "poll cycle"
|
||||
stop_event.wait(timeout=300.0) # This is the new pattern
|
||||
|
||||
# Should be just 1 wait call, not 600+ like the old busy-loop
|
||||
assert wait_call_count == 1
|
||||
|
||||
def test_no_busy_loop_iterations(self):
|
||||
"""Test that the old busy-loop pattern is not used."""
|
||||
stop_event = threading.Event()
|
||||
sleep_count = 0
|
||||
|
||||
# Old pattern would have 600 sleep calls per 5-minute interval:
|
||||
# for _ in range(600):
|
||||
# if self.stop_event.is_set():
|
||||
# break
|
||||
# time.sleep(0.5)
|
||||
|
||||
# New pattern has zero sleep calls:
|
||||
# if self.stop_event.wait(timeout=300.0):
|
||||
# break
|
||||
|
||||
def new_poll_pattern():
|
||||
nonlocal sleep_count
|
||||
# New pattern - no sleep calls
|
||||
if stop_event.wait(timeout=0.01):
|
||||
return True
|
||||
# No time.sleep() here!
|
||||
return False
|
||||
|
||||
new_poll_pattern()
|
||||
|
||||
# No explicit sleep calls in new pattern
|
||||
assert sleep_count == 0
|
||||
|
||||
|
||||
class TestStopShutdownBehavior:
|
||||
"""Test stop() method shutdown behavior."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gatt_server(self):
|
||||
"""Create mock GATT server for stop() testing."""
|
||||
server = Mock()
|
||||
server.stop_event = threading.Event()
|
||||
server.running = True
|
||||
server._log = Mock()
|
||||
server._monitor_loop = None
|
||||
server._async_stop_event = None
|
||||
server.server_thread = None
|
||||
server.disconnect_monitor_thread = None
|
||||
server.stale_poll_thread = None
|
||||
server.ble_agent = None
|
||||
server.connected_centrals = {}
|
||||
server.centrals_lock = threading.RLock()
|
||||
return server
|
||||
|
||||
def test_stop_sets_stop_event(self, mock_gatt_server):
|
||||
"""Test that stop() sets the stop_event."""
|
||||
assert not mock_gatt_server.stop_event.is_set()
|
||||
|
||||
# Simulate stop() behavior
|
||||
mock_gatt_server.stop_event.set()
|
||||
|
||||
assert mock_gatt_server.stop_event.is_set()
|
||||
|
||||
def test_stop_signals_async_event(self):
|
||||
"""Test that stop() signals async event via call_soon_threadsafe."""
|
||||
async def async_test():
|
||||
async_stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Simulate stop() calling call_soon_threadsafe
|
||||
loop.call_soon_threadsafe(async_stop.set)
|
||||
|
||||
# Give event loop a chance to process
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert async_stop.is_set()
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
def test_stop_handles_runtime_error_gracefully(self):
|
||||
"""Test that stop() handles RuntimeError when loop is already stopped."""
|
||||
# Create a mock loop that raises RuntimeError
|
||||
mock_loop = Mock()
|
||||
mock_loop.call_soon_threadsafe = Mock(side_effect=RuntimeError("Loop is closed"))
|
||||
|
||||
mock_async_stop = Mock()
|
||||
|
||||
# Simulate the stop() error handling code
|
||||
_monitor_loop = mock_loop
|
||||
_async_stop_event = mock_async_stop
|
||||
|
||||
try:
|
||||
_monitor_loop.call_soon_threadsafe(_async_stop_event.set)
|
||||
except RuntimeError:
|
||||
pass # Should be caught and ignored
|
||||
|
||||
# Test passes if no exception is raised
|
||||
|
||||
def test_stop_checks_for_none_references(self):
|
||||
"""Test that stop() checks for None before calling call_soon_threadsafe."""
|
||||
_monitor_loop = None
|
||||
_async_stop_event = None
|
||||
|
||||
# Simulate the stop() check
|
||||
if _monitor_loop and _async_stop_event:
|
||||
# This should NOT be reached
|
||||
pytest.fail("Should not call when references are None")
|
||||
|
||||
# Test passes - no error when refs are None
|
||||
|
||||
def test_shutdown_latency_improvement(self):
|
||||
"""Test that shutdown responds immediately (not up to 5s delay)."""
|
||||
stop_event = threading.Event()
|
||||
async_stop_triggered = threading.Event()
|
||||
thread_exited = threading.Event()
|
||||
|
||||
stored_loop = None
|
||||
stored_async_stop = None
|
||||
|
||||
async def mock_monitor_loop():
|
||||
nonlocal stored_loop, stored_async_stop
|
||||
async_stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Store refs for "stop()" to access
|
||||
stored_loop = loop
|
||||
stored_async_stop = async_stop
|
||||
|
||||
async_stop_triggered.set() # Signal refs are ready
|
||||
|
||||
# Wait for stop signal
|
||||
await async_stop.wait()
|
||||
|
||||
def run_async():
|
||||
try:
|
||||
asyncio.run(mock_monitor_loop())
|
||||
except Exception:
|
||||
pass
|
||||
thread_exited.set()
|
||||
|
||||
# Start monitor thread
|
||||
thread = threading.Thread(target=run_async, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for async loop to be ready
|
||||
async_stop_triggered.wait(timeout=2.0)
|
||||
assert stored_loop is not None
|
||||
|
||||
# Measure shutdown time
|
||||
start = time.time()
|
||||
|
||||
# Simulate stop() calling call_soon_threadsafe
|
||||
stop_event.set()
|
||||
stored_loop.call_soon_threadsafe(stored_async_stop.set)
|
||||
|
||||
# Wait for thread to exit
|
||||
thread_exited.wait(timeout=2.0)
|
||||
shutdown_time = time.time() - start
|
||||
|
||||
# Shutdown should be fast (< 500ms vs old 5000ms max)
|
||||
assert shutdown_time < 0.5, f"Shutdown took {shutdown_time}s, should be < 0.5s"
|
||||
|
||||
def test_stop_waits_for_threads_with_timeout(self, mock_gatt_server):
|
||||
"""Test that stop() waits for threads with reasonable timeouts."""
|
||||
# Create mock threads
|
||||
mock_server_thread = Mock()
|
||||
mock_server_thread.is_alive = Mock(return_value=True)
|
||||
mock_server_thread.join = Mock()
|
||||
|
||||
mock_monitor_thread = Mock()
|
||||
mock_monitor_thread.is_alive = Mock(return_value=True)
|
||||
mock_monitor_thread.join = Mock()
|
||||
|
||||
mock_poll_thread = Mock()
|
||||
mock_poll_thread.is_alive = Mock(return_value=True)
|
||||
mock_poll_thread.join = Mock()
|
||||
|
||||
# Simulate stop() thread joins
|
||||
mock_gatt_server.server_thread = mock_server_thread
|
||||
mock_gatt_server.disconnect_monitor_thread = mock_monitor_thread
|
||||
mock_gatt_server.stale_poll_thread = mock_poll_thread
|
||||
|
||||
# Verify join is called with appropriate timeouts
|
||||
if mock_gatt_server.server_thread and mock_gatt_server.server_thread.is_alive():
|
||||
mock_gatt_server.server_thread.join(timeout=5.0)
|
||||
|
||||
if mock_gatt_server.disconnect_monitor_thread and mock_gatt_server.disconnect_monitor_thread.is_alive():
|
||||
mock_gatt_server.disconnect_monitor_thread.join(timeout=2.0)
|
||||
|
||||
if mock_gatt_server.stale_poll_thread and mock_gatt_server.stale_poll_thread.is_alive():
|
||||
mock_gatt_server.stale_poll_thread.join(timeout=2.0)
|
||||
|
||||
# Verify joins were called with timeouts
|
||||
mock_server_thread.join.assert_called_once_with(timeout=5.0)
|
||||
mock_monitor_thread.join.assert_called_once_with(timeout=2.0)
|
||||
mock_poll_thread.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration tests for HCI error fix scenarios."""
|
||||
|
||||
def test_full_lifecycle_start_to_stop(self):
|
||||
"""Test complete lifecycle with event-driven monitoring."""
|
||||
stop_event = threading.Event()
|
||||
monitor_started = threading.Event()
|
||||
monitor_stopped = threading.Event()
|
||||
|
||||
stored_loop = None
|
||||
stored_async_stop = None
|
||||
|
||||
async def mock_monitor():
|
||||
nonlocal stored_loop, stored_async_stop
|
||||
async_stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
stored_loop = loop
|
||||
stored_async_stop = async_stop
|
||||
|
||||
monitor_started.set()
|
||||
|
||||
# Event-driven wait (not polling)
|
||||
await async_stop.wait()
|
||||
|
||||
def run_monitor():
|
||||
try:
|
||||
asyncio.run(mock_monitor())
|
||||
except Exception:
|
||||
pass
|
||||
monitor_stopped.set()
|
||||
|
||||
# Start monitoring
|
||||
thread = threading.Thread(target=run_monitor, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for start
|
||||
monitor_started.wait(timeout=2.0)
|
||||
assert stored_loop is not None
|
||||
assert stored_async_stop is not None
|
||||
|
||||
# Simulate stop()
|
||||
stop_event.set()
|
||||
stored_loop.call_soon_threadsafe(stored_async_stop.set)
|
||||
|
||||
# Wait for clean shutdown
|
||||
monitor_stopped.wait(timeout=2.0)
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
assert not thread.is_alive()
|
||||
assert monitor_stopped.is_set()
|
||||
|
||||
def test_multiple_stop_calls_safe(self):
|
||||
"""Test that multiple stop() calls don't cause issues."""
|
||||
stop_event = threading.Event()
|
||||
|
||||
# First stop
|
||||
stop_event.set()
|
||||
assert stop_event.is_set()
|
||||
|
||||
# Second stop (should be safe)
|
||||
stop_event.set()
|
||||
assert stop_event.is_set()
|
||||
|
||||
# Clear and set again (simulating restart + stop)
|
||||
stop_event.clear()
|
||||
assert not stop_event.is_set()
|
||||
stop_event.set()
|
||||
assert stop_event.is_set()
|
||||
|
||||
def test_dbus_signals_still_processed_during_wait(self):
|
||||
"""Test that D-Bus signals are processed while waiting for stop."""
|
||||
async def async_test():
|
||||
async_stop = asyncio.Event()
|
||||
signal_received = False
|
||||
|
||||
def handle_signal():
|
||||
nonlocal signal_received
|
||||
signal_received = True
|
||||
|
||||
# Start wait task
|
||||
wait_task = asyncio.create_task(async_stop.wait())
|
||||
|
||||
# Simulate D-Bus signal (would be scheduled via event loop)
|
||||
await asyncio.sleep(0.01)
|
||||
handle_signal() # Signal handler runs
|
||||
|
||||
# Verify signal was processed
|
||||
assert signal_received
|
||||
|
||||
# Stop wait
|
||||
async_stop.set()
|
||||
await wait_task
|
||||
|
||||
asyncio.run(async_test())
|
||||
|
||||
|
||||
class TestNoPollingVerification:
|
||||
"""Verify that no polling patterns exist in the fixes."""
|
||||
|
||||
def test_no_05_second_sleep_in_monitor(self):
|
||||
"""Verify the old 0.5s sleep pattern is not used in D-Bus monitor."""
|
||||
# The old pattern was:
|
||||
# await asyncio.sleep(0.5) # BAD
|
||||
|
||||
# New pattern:
|
||||
# await async_stop.wait() # GOOD
|
||||
|
||||
# This test verifies the concept
|
||||
async def no_polling_pattern():
|
||||
async_stop = asyncio.Event()
|
||||
# No periodic sleep!
|
||||
async_stop.set() # Immediately set for test
|
||||
await async_stop.wait()
|
||||
|
||||
# Should complete without any sleep delays
|
||||
start = time.time()
|
||||
asyncio.run(no_polling_pattern())
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Should complete in < 100ms (no 500ms sleeps)
|
||||
assert elapsed < 0.1
|
||||
|
||||
def test_no_busy_loop_in_stale_poll(self):
|
||||
"""Verify the old busy-loop pattern is not used in stale poll."""
|
||||
stop_event = threading.Event()
|
||||
iterations = 0
|
||||
|
||||
# Old pattern (BAD):
|
||||
# for _ in range(240): # 240 * 0.5s = 120s
|
||||
# if self.stop_event.is_set():
|
||||
# break
|
||||
# time.sleep(0.5)
|
||||
|
||||
# New pattern (GOOD):
|
||||
def new_pattern():
|
||||
nonlocal iterations
|
||||
if stop_event.wait(timeout=0.01): # Short timeout for test
|
||||
return True
|
||||
iterations += 1
|
||||
return False
|
||||
|
||||
# Run the new pattern
|
||||
new_pattern()
|
||||
|
||||
# Should only have 1 iteration (the wait call itself)
|
||||
assert iterations == 1
|
||||
|
||||
|
||||
class TestCodeVerification:
|
||||
"""Tests that verify the actual implementation has correct patterns."""
|
||||
|
||||
def test_verify_poll_interval_in_code(self):
|
||||
"""Verify that the actual code uses 300s poll interval."""
|
||||
import re
|
||||
|
||||
# Read the actual source file
|
||||
source_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../src/RNS/Interfaces/linux_bluetooth_driver.py'
|
||||
)
|
||||
|
||||
with open(source_path, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
# Find the poll_stale_connections method and verify it uses 300.0 timeout
|
||||
# Pattern: stop_event.wait(timeout=300.0)
|
||||
poll_pattern = r'self\.stop_event\.wait\(timeout=(\d+\.?\d*)\)'
|
||||
matches = re.findall(poll_pattern, source)
|
||||
|
||||
assert '300.0' in matches, f"Expected 300.0 second timeout, found: {matches}"
|
||||
|
||||
def test_verify_event_driven_wait_in_code(self):
|
||||
"""Verify that the actual code uses async_stop.wait() not polling."""
|
||||
import re
|
||||
|
||||
source_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../src/RNS/Interfaces/linux_bluetooth_driver.py'
|
||||
)
|
||||
|
||||
with open(source_path, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
# The code should have: await async_stop.wait()
|
||||
assert 'await async_stop.wait()' in source, "Should use event-driven wait"
|
||||
|
||||
# The code should NOT have the old polling pattern in the monitor loop
|
||||
# Look for the monitor_loop function and check it doesn't have asyncio.sleep polling
|
||||
monitor_section_match = re.search(
|
||||
r'async def monitor_loop\(\):(.*?)(?=\n def |\n async def |\Z)',
|
||||
source,
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
if monitor_section_match:
|
||||
monitor_section = monitor_section_match.group(1)
|
||||
# Should not have: while not stop_event... asyncio.sleep pattern
|
||||
polling_pattern = r'while.*stop_event.*\n.*asyncio\.sleep\(0\.5\)'
|
||||
assert not re.search(polling_pattern, monitor_section), \
|
||||
"Monitor loop should not use 0.5s polling pattern"
|
||||
|
||||
def test_verify_call_soon_threadsafe_in_stop(self):
|
||||
"""Verify that stop() uses call_soon_threadsafe."""
|
||||
source_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../src/RNS/Interfaces/linux_bluetooth_driver.py'
|
||||
)
|
||||
|
||||
with open(source_path, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
# The stop() method should use call_soon_threadsafe
|
||||
assert 'call_soon_threadsafe' in source, "stop() should use call_soon_threadsafe"
|
||||
assert '_async_stop_event.set' in source, "Should signal async stop event"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
129
tests/test_identity_hash.py
Normal file
129
tests/test_identity_hash.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for _compute_identity_hash() function.
|
||||
|
||||
This tests the fix for double-hashing bug where peer_identity (already a hash
|
||||
from BLE handshake) was incorrectly passed through RNS.Identity.full_hash(),
|
||||
producing a different value and causing "no reassembler for X" errors.
|
||||
|
||||
The tests verify the expected behavior without importing BLEInterface directly
|
||||
(which has heavy RNS dependencies), instead testing the core logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))
|
||||
|
||||
|
||||
def compute_identity_hash_fixed(peer_identity):
|
||||
"""
|
||||
The FIXED implementation of _compute_identity_hash().
|
||||
|
||||
This is what the code should do: just convert bytes to hex.
|
||||
"""
|
||||
# peer_identity is already the identity hash from BLE handshake
|
||||
# Just convert to hex, don't re-hash (that would corrupt the identity!)
|
||||
return peer_identity.hex()[:16]
|
||||
|
||||
|
||||
class TestComputeIdentityHash:
|
||||
"""Test _compute_identity_hash() returns correct hex representation."""
|
||||
|
||||
def test_identity_hash_returns_hex_of_input(self):
|
||||
"""
|
||||
_compute_identity_hash should return first 16 hex chars of input bytes.
|
||||
|
||||
The peer_identity parameter is already the identity hash from BLE handshake.
|
||||
We should NOT hash it again - just convert to hex.
|
||||
"""
|
||||
# Test identity bytes (16 bytes = 32 hex chars, we want first 16)
|
||||
test_identity = bytes.fromhex("232f48ba94a3142937c9a64714112ff3")
|
||||
|
||||
result = compute_identity_hash_fixed(test_identity)
|
||||
|
||||
# Should be first 16 hex chars of the input (8 bytes = 16 hex chars)
|
||||
assert result == "232f48ba94a31429"
|
||||
assert len(result) == 16
|
||||
|
||||
def test_identity_hash_does_not_double_hash(self):
|
||||
"""
|
||||
Verify the fix: _compute_identity_hash must NOT apply RNS.Identity.full_hash().
|
||||
|
||||
The old buggy code did:
|
||||
return RNS.Identity.full_hash(peer_identity)[:16].hex()[:16]
|
||||
|
||||
This would produce a completely different value, causing sender/receiver
|
||||
identity mismatch and "no reassembler" errors.
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# Real identity bytes from a test session
|
||||
test_identity = bytes.fromhex("232f48ba94a3142937c9a64714112ff3")
|
||||
|
||||
# Get the correct result (hex of input)
|
||||
correct_result = compute_identity_hash_fixed(test_identity)
|
||||
|
||||
# Simulate what RNS.Identity.full_hash does (SHA-256)
|
||||
# This is what the buggy code would have produced
|
||||
buggy_hash = hashlib.sha256(test_identity).digest()
|
||||
buggy_result = buggy_hash[:16].hex()[:16]
|
||||
|
||||
# The correct result should be hex of input, NOT a hash of the input
|
||||
assert correct_result == test_identity.hex()[:16]
|
||||
|
||||
# The buggy result would be different (a hash of the already-hashed identity)
|
||||
assert correct_result != buggy_result, \
|
||||
"If these are equal, the test identity accidentally produces same hash"
|
||||
|
||||
def test_identity_hash_with_various_inputs(self):
|
||||
"""Test with various identity byte patterns."""
|
||||
test_cases = [
|
||||
bytes.fromhex("00000000000000000000000000000000"),
|
||||
bytes.fromhex("ffffffffffffffffffffffffffffffff"),
|
||||
bytes.fromhex("0123456789abcdef0123456789abcdef"),
|
||||
bytes.fromhex("deadbeefcafebabe1234567890abcdef"),
|
||||
]
|
||||
|
||||
for identity in test_cases:
|
||||
result = compute_identity_hash_fixed(identity)
|
||||
|
||||
# Result should always be first 16 hex chars of input
|
||||
assert result == identity.hex()[:16]
|
||||
assert len(result) == 16
|
||||
# Should be valid hex
|
||||
int(result, 16)
|
||||
|
||||
def test_actual_bleinterface_implementation(self):
|
||||
"""
|
||||
Verify BLEInterface._compute_identity_hash matches expected behavior.
|
||||
|
||||
This test reads the actual source code and verifies it contains the fix.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Read the actual BLEInterface.py source
|
||||
ble_interface_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../src/RNS/Interfaces/BLEInterface.py'
|
||||
)
|
||||
|
||||
with open(ble_interface_path, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
# Find the _compute_identity_hash method
|
||||
# Look for the fixed implementation pattern
|
||||
fixed_pattern = r'def _compute_identity_hash.*?return peer_identity\.hex\(\)\[:16\]'
|
||||
|
||||
# Look for the buggy implementation pattern
|
||||
buggy_pattern = r'RNS\.Identity\.full_hash\(peer_identity\)'
|
||||
|
||||
# The fixed code should have peer_identity.hex()[:16]
|
||||
assert re.search(fixed_pattern, source, re.DOTALL), \
|
||||
"_compute_identity_hash should use peer_identity.hex()[:16]"
|
||||
|
||||
# The fixed code should NOT have the double-hash
|
||||
assert not re.search(buggy_pattern, source), \
|
||||
"_compute_identity_hash should NOT use RNS.Identity.full_hash(peer_identity)"
|
||||
|
|
@ -52,7 +52,7 @@ def mock_driver():
|
|||
driver = Mock()
|
||||
driver.loop = asyncio.new_event_loop()
|
||||
driver._peers = {} # address -> peer_conn
|
||||
driver._peers_lock = asyncio.Lock()
|
||||
driver._peers_lock = threading.RLock() # Use threading lock for mock (asyncio.Lock requires event loop in Py3.8/3.9)
|
||||
driver._log = Mock()
|
||||
driver.on_device_disconnected = Mock()
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ class TestPeripheralDisconnectCleanup:
|
|||
gatt_server = Mock()
|
||||
gatt_server.driver = mock_driver
|
||||
gatt_server.connected_centrals = {}
|
||||
gatt_server.centrals_lock = asyncio.Lock()
|
||||
gatt_server.centrals_lock = threading.RLock() # Use threading lock for mock
|
||||
gatt_server.running = True
|
||||
gatt_server._log = Mock()
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import pytest
|
|||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
# Add src to path
|
||||
|
|
@ -54,7 +55,7 @@ class TestScannerConnectionCoordination:
|
|||
driver = Mock()
|
||||
driver.loop = asyncio.new_event_loop()
|
||||
driver._connecting_peers = set()
|
||||
driver._connecting_lock = asyncio.Lock()
|
||||
driver._connecting_lock = threading.RLock() # Use threading lock for mock (asyncio.Lock requires event loop in Py3.8/3.9)
|
||||
driver._log = Mock()
|
||||
return driver
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue