Use stamp generation worker context manager on Python 3.14+

This commit is contained in:
Mark Qvist 2026-05-10 17:16:38 +02:00
commit 0cb62ddc36

View file

@ -2,9 +2,11 @@ import RNS
import RNS.vendor.umsgpack as msgpack import RNS.vendor.umsgpack as msgpack
import os import os
import sys
import time import time
import math import math
import itertools import itertools
import contextlib
import multiprocessing import multiprocessing
WORKBLOCK_EXPAND_ROUNDS = 3000 WORKBLOCK_EXPAND_ROUNDS = 3000
@ -12,9 +14,38 @@ WORKBLOCK_EXPAND_ROUNDS_PN = 1000
WORKBLOCK_EXPAND_ROUNDS_PEERING = 25 WORKBLOCK_EXPAND_ROUNDS_PEERING = 25
STAMP_SIZE = RNS.Identity.HASHLENGTH//8 STAMP_SIZE = RNS.Identity.HASHLENGTH//8
PN_VALIDATION_POOL_MIN_SIZE = 256 PN_VALIDATION_POOL_MIN_SIZE = 256
USE_WORKER_MANAGER = False
active_jobs = {} active_jobs = {}
if sys.version_info[0] >= 3 and sys.version_info[1] >= 14:
USE_WORKER_MANAGER = True
@contextlib.contextmanager
def worker_context(ctx, stamp_cost, workblock, message_id):
stop_event = ctx.Event()
result_queue = ctx.Queue(1)
rounds_queue = ctx.Queue()
active_jobs[message_id] = [stop_event, result_queue]
job_procs = []
try: yield stop_event, result_queue, rounds_queue, job_procs
finally:
if message_id in active_jobs: del active_jobs[message_id]
stop_event.set()
for p in job_procs:
if p.is_alive(): p.terminate()
for p in job_procs: p.join(timeout=0.5)
result_queue.close()
result_queue.join_thread()
rounds_queue.close()
rounds_queue.join_thread()
stop_event.clear()
def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS): def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
wb_st = time.time() wb_st = time.time()
workblock = b"" workblock = b""
@ -100,7 +131,9 @@ def generate_stamp(message_id, stamp_cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS
if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id) if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id)
elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id) elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id)
else: stamp, rounds = job_linux(stamp_cost, workblock, message_id) else:
if USE_WORKER_MANAGER: stamp, rounds = job_linux_managed(stamp_cost, workblock, message_id)
else: stamp, rounds = job_linux(stamp_cost, workblock, message_id)
duration = time.time() - start_time duration = time.time() - start_time
speed = rounds/duration speed = rounds/duration
@ -176,6 +209,90 @@ def job_simple(stamp_cost, workblock, message_id):
return pstamp, rounds return pstamp, rounds
def job_linux_managed(stamp_cost, workblock, message_id):
ctx = multiprocessing.get_context("fork")
cores = multiprocessing.cpu_count()
jobs = cores if cores <= 12 else int(cores/2)
allow_kill = True
stamp = None
total_rounds = 0
with worker_context(ctx, stamp_cost, workblock, message_id) as (stop_event, result_queue, rounds_queue, job_procs):
RNS.log(f"Starting {jobs} stamp generation workers", RNS.LOG_DEBUG)
def job(stop_event, sc, wb, worker_id):
terminated = False
rounds = 0
pstamp = os.urandom(256//8)
def sv(s, c, w):
target = 0b1<<256-c; m = w+s
result = RNS.Identity.full_hash(m)
if int.from_bytes(result, byteorder="big") > target: return False
else: return True
while not stop_event.is_set() and not sv(pstamp, sc, wb):
pstamp = os.urandom(256//8); rounds += 1
if not stop_event.is_set():
stop_event.set()
try: result_queue.put_nowait(pstamp)
except: pass
try: rounds_queue.put_nowait(rounds)
except: pass
for jpn in range(jobs):
p = ctx.Process(target=job, args=(stop_event, stamp_cost, workblock, jpn), daemon=True)
job_procs.append(p)
p.start()
try: stamp = result_queue.get(timeout=None)
except Exception as e:
RNS.log(f"Failed to get result from workers: {e}", RNS.LOG_ERROR)
stamp = None
# Collect any potential spurious
# results from worker queue.
try:
while True: result_queue.get_nowait()
except: pass
for j in range(jobs):
nrounds = 0
try:
nrounds = rounds_queue.get(timeout=2)
except Exception as e:
RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR)
total_rounds += nrounds
all_exited = False
exit_timeout = time.time() + 5
while time.time() < exit_timeout:
if not any(p.is_alive() for p in job_procs):
all_exited = True
break
time.sleep(0.1)
if not all_exited:
RNS.log("Stamp generation IPC timeout, possible worker deadlock. Terminating remaining processes.", RNS.LOG_ERROR)
if allow_kill:
for j in range(jobs):
process = job_procs[j]
process.kill()
else:
return None
else:
for j in range(jobs):
process = job_procs[j]
process.join()
# RNS.log(f"Joined {j} / {process}", RNS.LOG_DEBUG) # TODO: Remove
return stamp, total_rounds
def job_linux(stamp_cost, workblock, message_id): def job_linux(stamp_cost, workblock, message_id):
allow_kill = True allow_kill = True
stamp = None stamp = None
@ -366,30 +483,41 @@ if __name__ == "__main__":
RNS.log("No cost argument provided", RNS.LOG_ERROR) RNS.log("No cost argument provided", RNS.LOG_ERROR)
exit(1) exit(1)
else: else:
try: try: cost = int(sys.argv[1])
cost = int(sys.argv[1])
except Exception as e: except Exception as e:
RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR) RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR)
exit(1) exit(1)
rounds = 1
if len(sys.argv) > 2:
try: rounds = int(sys.argv[2])
except Exception as e:
RNS.log(f"Invalid rounds argument provided: {e}", RNS.LOG_ERROR)
exit(1)
RNS.loglevel = RNS.LOG_DEBUG RNS.loglevel = RNS.LOG_DEBUG
RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
message_id = os.urandom(32) for i in range(rounds):
generate_stamp(message_id, cost) RNS.log("", RNS.LOG_DEBUG)
RNS.log(f"Round {i+1} of {rounds}", RNS.LOG_DEBUG)
RNS.log("", RNS.LOG_DEBUG) RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG) message_id = os.urandom(32)
message_id = os.urandom(32) generate_stamp(message_id, cost)
generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)
RNS.log("", RNS.LOG_DEBUG) RNS.log("", RNS.LOG_DEBUG)
RNS.log("Testing peering key generation", RNS.LOG_DEBUG) RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG)
message_id = os.urandom(32) message_id = os.urandom(32)
generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING) generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)
transient_list = [] RNS.log("", RNS.LOG_DEBUG)
st = time.time(); count = 10000 RNS.log("Testing peering key generation", RNS.LOG_DEBUG)
for i in range(count): transient_list.append(os.urandom(256)) message_id = os.urandom(32)
validate_pn_stamps(transient_list, 5) generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)
dt = time.time()-st; mps = count/dt
RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG) # transient_list = []
# st = time.time(); count = 10000
# for i in range(count): transient_list.append(os.urandom(256))
# validate_pn_stamps(transient_list, 5)
# dt = time.time()-st; mps = count/dt
# RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)