diff --git a/LXMF/LXStamper.py b/LXMF/LXStamper.py index 3d7a1c2..ea25f08 100644 --- a/LXMF/LXStamper.py +++ b/LXMF/LXStamper.py @@ -2,9 +2,11 @@ import RNS import RNS.vendor.umsgpack as msgpack import os +import sys import time import math import itertools +import contextlib import multiprocessing WORKBLOCK_EXPAND_ROUNDS = 3000 @@ -12,9 +14,38 @@ WORKBLOCK_EXPAND_ROUNDS_PN = 1000 WORKBLOCK_EXPAND_ROUNDS_PEERING = 25 STAMP_SIZE = RNS.Identity.HASHLENGTH//8 PN_VALIDATION_POOL_MIN_SIZE = 256 +USE_WORKER_MANAGER = False active_jobs = {} +if sys.version_info[0] >= 3 and sys.version_info[1] >= 14: + USE_WORKER_MANAGER = True + +@contextlib.contextmanager +def worker_context(ctx, stamp_cost, workblock, message_id): + stop_event = ctx.Event() + result_queue = ctx.Queue(1) + rounds_queue = ctx.Queue() + + active_jobs[message_id] = [stop_event, result_queue] + job_procs = [] + + try: yield stop_event, result_queue, rounds_queue, job_procs + finally: + if message_id in active_jobs: del active_jobs[message_id] + stop_event.set() + + for p in job_procs: + if p.is_alive(): p.terminate() + + for p in job_procs: p.join(timeout=0.5) + + result_queue.close() + result_queue.join_thread() + rounds_queue.close() + rounds_queue.join_thread() + stop_event.clear() + def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS): wb_st = time.time() workblock = b"" @@ -100,7 +131,9 @@ def generate_stamp(message_id, stamp_cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id) elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id) - else: stamp, rounds = job_linux(stamp_cost, workblock, message_id) + else: + if USE_WORKER_MANAGER: stamp, rounds = job_linux_managed(stamp_cost, workblock, message_id) + else: stamp, rounds = job_linux(stamp_cost, workblock, message_id) duration = time.time() - start_time speed = rounds/duration @@ -176,6 +209,90 @@ def job_simple(stamp_cost, workblock, message_id): return pstamp, rounds +def job_linux_managed(stamp_cost, workblock, message_id): + ctx = multiprocessing.get_context("fork") + cores = multiprocessing.cpu_count() + jobs = cores if cores <= 12 else int(cores/2) + + allow_kill = True + stamp = None + total_rounds = 0 + + with worker_context(ctx, stamp_cost, workblock, message_id) as (stop_event, result_queue, rounds_queue, job_procs): + RNS.log(f"Starting {jobs} stamp generation workers", RNS.LOG_DEBUG) + + def job(stop_event, sc, wb, worker_id): + terminated = False + rounds = 0 + pstamp = os.urandom(256//8) + + def sv(s, c, w): + target = 0b1<<256-c; m = w+s + result = RNS.Identity.full_hash(m) + if int.from_bytes(result, byteorder="big") > target: return False + else: return True + + while not stop_event.is_set() and not sv(pstamp, sc, wb): + pstamp = os.urandom(256//8); rounds += 1 + + if not stop_event.is_set(): + stop_event.set() + try: result_queue.put_nowait(pstamp) + except: pass + + try: rounds_queue.put_nowait(rounds) + except: pass + + for jpn in range(jobs): + p = ctx.Process(target=job, args=(stop_event, stamp_cost, workblock, jpn), daemon=True) + job_procs.append(p) + p.start() + + try: stamp = result_queue.get(timeout=None) + except Exception as e: + RNS.log(f"Failed to get result from workers: {e}", RNS.LOG_ERROR) + stamp = None + + # Collect any potential spurious + # results from worker queue. + try: + while True: result_queue.get_nowait() + except: pass + + for j in range(jobs): + nrounds = 0 + try: + nrounds = rounds_queue.get(timeout=2) + except Exception as e: + RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR) + total_rounds += nrounds + + all_exited = False + exit_timeout = time.time() + 5 + while time.time() < exit_timeout: + if not any(p.is_alive() for p in job_procs): + all_exited = True + break + time.sleep(0.1) + + if not all_exited: + RNS.log("Stamp generation IPC timeout, possible worker deadlock. Terminating remaining processes.", RNS.LOG_ERROR) + if allow_kill: + for j in range(jobs): + process = job_procs[j] + process.kill() + else: + return None + + else: + for j in range(jobs): + process = job_procs[j] + process.join() + # RNS.log(f"Joined {j} / {process}", RNS.LOG_DEBUG) # TODO: Remove + + return stamp, total_rounds + + def job_linux(stamp_cost, workblock, message_id): allow_kill = True stamp = None @@ -366,30 +483,41 @@ if __name__ == "__main__": RNS.log("No cost argument provided", RNS.LOG_ERROR) exit(1) else: - try: - cost = int(sys.argv[1]) + try: cost = int(sys.argv[1]) except Exception as e: RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR) exit(1) + rounds = 1 + if len(sys.argv) > 2: + try: rounds = int(sys.argv[2]) + except Exception as e: + RNS.log(f"Invalid rounds argument provided: {e}", RNS.LOG_ERROR) + exit(1) + RNS.loglevel = RNS.LOG_DEBUG - RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG) - message_id = os.urandom(32) - generate_stamp(message_id, cost) + + for i in range(rounds): + RNS.log("", RNS.LOG_DEBUG) + RNS.log(f"Round {i+1} of {rounds}", RNS.LOG_DEBUG) - RNS.log("", RNS.LOG_DEBUG) - RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG) - message_id = os.urandom(32) - generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN) + RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG) + message_id = os.urandom(32) + generate_stamp(message_id, cost) - RNS.log("", RNS.LOG_DEBUG) - RNS.log("Testing peering key generation", RNS.LOG_DEBUG) - message_id = os.urandom(32) - generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING) + RNS.log("", RNS.LOG_DEBUG) + RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG) + message_id = os.urandom(32) + generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN) - transient_list = [] - st = time.time(); count = 10000 - for i in range(count): transient_list.append(os.urandom(256)) - validate_pn_stamps(transient_list, 5) - dt = time.time()-st; mps = count/dt - RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG) + RNS.log("", RNS.LOG_DEBUG) + RNS.log("Testing peering key generation", RNS.LOG_DEBUG) + message_id = os.urandom(32) + generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING) + + # transient_list = [] + # st = time.time(); count = 10000 + # for i in range(count): transient_list.append(os.urandom(256)) + # validate_pn_stamps(transient_list, 5) + # dt = time.time()-st; mps = count/dt + # RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)