Use stamp generation worker context manager on Python 3.14+
This commit is contained in:
parent
a8505eade9
commit
0cb62ddc36
1 changed files with 148 additions and 20 deletions
|
|
@ -2,9 +2,11 @@ import RNS
|
||||||
import RNS.vendor.umsgpack as msgpack
|
import RNS.vendor.umsgpack as msgpack
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import itertools
|
import itertools
|
||||||
|
import contextlib
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
WORKBLOCK_EXPAND_ROUNDS = 3000
|
WORKBLOCK_EXPAND_ROUNDS = 3000
|
||||||
|
|
@ -12,9 +14,38 @@ WORKBLOCK_EXPAND_ROUNDS_PN = 1000
|
||||||
WORKBLOCK_EXPAND_ROUNDS_PEERING = 25
|
WORKBLOCK_EXPAND_ROUNDS_PEERING = 25
|
||||||
STAMP_SIZE = RNS.Identity.HASHLENGTH//8
|
STAMP_SIZE = RNS.Identity.HASHLENGTH//8
|
||||||
PN_VALIDATION_POOL_MIN_SIZE = 256
|
PN_VALIDATION_POOL_MIN_SIZE = 256
|
||||||
|
USE_WORKER_MANAGER = False
|
||||||
|
|
||||||
active_jobs = {}
|
active_jobs = {}
|
||||||
|
|
||||||
|
if sys.version_info[0] >= 3 and sys.version_info[1] >= 14:
|
||||||
|
USE_WORKER_MANAGER = True
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def worker_context(ctx, stamp_cost, workblock, message_id):
|
||||||
|
stop_event = ctx.Event()
|
||||||
|
result_queue = ctx.Queue(1)
|
||||||
|
rounds_queue = ctx.Queue()
|
||||||
|
|
||||||
|
active_jobs[message_id] = [stop_event, result_queue]
|
||||||
|
job_procs = []
|
||||||
|
|
||||||
|
try: yield stop_event, result_queue, rounds_queue, job_procs
|
||||||
|
finally:
|
||||||
|
if message_id in active_jobs: del active_jobs[message_id]
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
for p in job_procs:
|
||||||
|
if p.is_alive(): p.terminate()
|
||||||
|
|
||||||
|
for p in job_procs: p.join(timeout=0.5)
|
||||||
|
|
||||||
|
result_queue.close()
|
||||||
|
result_queue.join_thread()
|
||||||
|
rounds_queue.close()
|
||||||
|
rounds_queue.join_thread()
|
||||||
|
stop_event.clear()
|
||||||
|
|
||||||
def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
|
def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
|
||||||
wb_st = time.time()
|
wb_st = time.time()
|
||||||
workblock = b""
|
workblock = b""
|
||||||
|
|
@ -100,6 +131,8 @@ def generate_stamp(message_id, stamp_cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS
|
||||||
|
|
||||||
if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id)
|
if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id)
|
||||||
elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id)
|
elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id)
|
||||||
|
else:
|
||||||
|
if USE_WORKER_MANAGER: stamp, rounds = job_linux_managed(stamp_cost, workblock, message_id)
|
||||||
else: stamp, rounds = job_linux(stamp_cost, workblock, message_id)
|
else: stamp, rounds = job_linux(stamp_cost, workblock, message_id)
|
||||||
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
@ -176,6 +209,90 @@ def job_simple(stamp_cost, workblock, message_id):
|
||||||
|
|
||||||
return pstamp, rounds
|
return pstamp, rounds
|
||||||
|
|
||||||
|
def job_linux_managed(stamp_cost, workblock, message_id):
|
||||||
|
ctx = multiprocessing.get_context("fork")
|
||||||
|
cores = multiprocessing.cpu_count()
|
||||||
|
jobs = cores if cores <= 12 else int(cores/2)
|
||||||
|
|
||||||
|
allow_kill = True
|
||||||
|
stamp = None
|
||||||
|
total_rounds = 0
|
||||||
|
|
||||||
|
with worker_context(ctx, stamp_cost, workblock, message_id) as (stop_event, result_queue, rounds_queue, job_procs):
|
||||||
|
RNS.log(f"Starting {jobs} stamp generation workers", RNS.LOG_DEBUG)
|
||||||
|
|
||||||
|
def job(stop_event, sc, wb, worker_id):
|
||||||
|
terminated = False
|
||||||
|
rounds = 0
|
||||||
|
pstamp = os.urandom(256//8)
|
||||||
|
|
||||||
|
def sv(s, c, w):
|
||||||
|
target = 0b1<<256-c; m = w+s
|
||||||
|
result = RNS.Identity.full_hash(m)
|
||||||
|
if int.from_bytes(result, byteorder="big") > target: return False
|
||||||
|
else: return True
|
||||||
|
|
||||||
|
while not stop_event.is_set() and not sv(pstamp, sc, wb):
|
||||||
|
pstamp = os.urandom(256//8); rounds += 1
|
||||||
|
|
||||||
|
if not stop_event.is_set():
|
||||||
|
stop_event.set()
|
||||||
|
try: result_queue.put_nowait(pstamp)
|
||||||
|
except: pass
|
||||||
|
|
||||||
|
try: rounds_queue.put_nowait(rounds)
|
||||||
|
except: pass
|
||||||
|
|
||||||
|
for jpn in range(jobs):
|
||||||
|
p = ctx.Process(target=job, args=(stop_event, stamp_cost, workblock, jpn), daemon=True)
|
||||||
|
job_procs.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
try: stamp = result_queue.get(timeout=None)
|
||||||
|
except Exception as e:
|
||||||
|
RNS.log(f"Failed to get result from workers: {e}", RNS.LOG_ERROR)
|
||||||
|
stamp = None
|
||||||
|
|
||||||
|
# Collect any potential spurious
|
||||||
|
# results from worker queue.
|
||||||
|
try:
|
||||||
|
while True: result_queue.get_nowait()
|
||||||
|
except: pass
|
||||||
|
|
||||||
|
for j in range(jobs):
|
||||||
|
nrounds = 0
|
||||||
|
try:
|
||||||
|
nrounds = rounds_queue.get(timeout=2)
|
||||||
|
except Exception as e:
|
||||||
|
RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR)
|
||||||
|
total_rounds += nrounds
|
||||||
|
|
||||||
|
all_exited = False
|
||||||
|
exit_timeout = time.time() + 5
|
||||||
|
while time.time() < exit_timeout:
|
||||||
|
if not any(p.is_alive() for p in job_procs):
|
||||||
|
all_exited = True
|
||||||
|
break
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
if not all_exited:
|
||||||
|
RNS.log("Stamp generation IPC timeout, possible worker deadlock. Terminating remaining processes.", RNS.LOG_ERROR)
|
||||||
|
if allow_kill:
|
||||||
|
for j in range(jobs):
|
||||||
|
process = job_procs[j]
|
||||||
|
process.kill()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
else:
|
||||||
|
for j in range(jobs):
|
||||||
|
process = job_procs[j]
|
||||||
|
process.join()
|
||||||
|
# RNS.log(f"Joined {j} / {process}", RNS.LOG_DEBUG) # TODO: Remove
|
||||||
|
|
||||||
|
return stamp, total_rounds
|
||||||
|
|
||||||
|
|
||||||
def job_linux(stamp_cost, workblock, message_id):
|
def job_linux(stamp_cost, workblock, message_id):
|
||||||
allow_kill = True
|
allow_kill = True
|
||||||
stamp = None
|
stamp = None
|
||||||
|
|
@ -366,13 +483,24 @@ if __name__ == "__main__":
|
||||||
RNS.log("No cost argument provided", RNS.LOG_ERROR)
|
RNS.log("No cost argument provided", RNS.LOG_ERROR)
|
||||||
exit(1)
|
exit(1)
|
||||||
else:
|
else:
|
||||||
try:
|
try: cost = int(sys.argv[1])
|
||||||
cost = int(sys.argv[1])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR)
|
RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
rounds = 1
|
||||||
|
if len(sys.argv) > 2:
|
||||||
|
try: rounds = int(sys.argv[2])
|
||||||
|
except Exception as e:
|
||||||
|
RNS.log(f"Invalid rounds argument provided: {e}", RNS.LOG_ERROR)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
RNS.loglevel = RNS.LOG_DEBUG
|
RNS.loglevel = RNS.LOG_DEBUG
|
||||||
|
|
||||||
|
for i in range(rounds):
|
||||||
|
RNS.log("", RNS.LOG_DEBUG)
|
||||||
|
RNS.log(f"Round {i+1} of {rounds}", RNS.LOG_DEBUG)
|
||||||
|
|
||||||
RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
|
RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
|
||||||
message_id = os.urandom(32)
|
message_id = os.urandom(32)
|
||||||
generate_stamp(message_id, cost)
|
generate_stamp(message_id, cost)
|
||||||
|
|
@ -387,9 +515,9 @@ if __name__ == "__main__":
|
||||||
message_id = os.urandom(32)
|
message_id = os.urandom(32)
|
||||||
generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)
|
generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)
|
||||||
|
|
||||||
transient_list = []
|
# transient_list = []
|
||||||
st = time.time(); count = 10000
|
# st = time.time(); count = 10000
|
||||||
for i in range(count): transient_list.append(os.urandom(256))
|
# for i in range(count): transient_list.append(os.urandom(256))
|
||||||
validate_pn_stamps(transient_list, 5)
|
# validate_pn_stamps(transient_list, 5)
|
||||||
dt = time.time()-st; mps = count/dt
|
# dt = time.time()-st; mps = count/dt
|
||||||
RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)
|
# RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue