Stamp cost API functions and multi-process stamp generation on Android

This commit is contained in:
Mark Qvist 2024-09-07 11:35:17 +02:00
commit 0d76eee6cd
4 changed files with 189 additions and 49 deletions

View file

@ -10,13 +10,13 @@ import multiprocessing
from .LXMF import APP_NAME
class LXMessage:
DRAFT = 0x00
GENERATING = 0x00
OUTBOUND = 0x01
SENDING = 0x02
SENT = 0x04
DELIVERED = 0x08
FAILED = 0xFF
states = [DRAFT, OUTBOUND, SENDING, SENT, DELIVERED, FAILED]
states = [GENERATING, OUTBOUND, SENDING, SENT, DELIVERED, FAILED]
UNKNOWN = 0x00
PACKET = 0x01
@ -126,7 +126,8 @@ class LXMessage:
self.stamp = None
self.stamp_cost = stamp_cost
self.stamp_valid = False
self.state = LXMessage.DRAFT
self.defer_stamp = False
self.state = LXMessage.GENERATING
self.method = LXMessage.UNKNOWN
self.progress = 0.0
self.rssi = None
@ -277,53 +278,128 @@ class LXMessage:
start_time = time.time()
total_rounds = 0
stop_event = multiprocessing.Event()
result_queue = multiprocessing.Queue(maxsize=1)
rounds_queue = multiprocessing.Queue()
def job(stop_event):
terminated = False
rounds = 0
stamp = os.urandom(256//8)
while not LXMessage.stamp_valid(stamp, self.stamp_cost, workblock):
if stop_event.is_set():
break
if timeout != None and rounds % 10000 == 0:
if time.time() > start_time + timeout:
RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR)
return None
if not RNS.vendor.platformutils.is_android():
stop_event = multiprocessing.Event()
result_queue = multiprocessing.Queue(maxsize=1)
rounds_queue = multiprocessing.Queue()
def job(stop_event):
terminated = False
rounds = 0
stamp = os.urandom(256//8)
rounds += 1
while not LXMessage.stamp_valid(stamp, self.stamp_cost, workblock):
if stop_event.is_set():
break
rounds_queue.put(rounds)
if not stop_event.is_set():
result_queue.put(stamp)
if timeout != None and rounds % 10000 == 0:
if time.time() > start_time + timeout:
RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR)
return None
job_procs = []
jobs = multiprocessing.cpu_count()
for _ in range(jobs):
process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event},)
job_procs.append(process)
process.start()
stamp = os.urandom(256//8)
rounds += 1
stamp = result_queue.get()
stop_event.set()
rounds_queue.put(rounds)
if not stop_event.is_set():
result_queue.put(stamp)
for j in range(jobs):
process = job_procs[j]
process.join()
total_rounds += rounds_queue.get()
job_procs = []
jobs = multiprocessing.cpu_count()
for _ in range(jobs):
process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event},)
job_procs.append(process)
process.start()
duration = time.time() - start_time
rounds = total_rounds
stamp = result_queue.get()
stop_event.set()
for j in range(jobs):
process = job_procs[j]
process.join()
total_rounds += rounds_queue.get()
duration = time.time() - start_time
rounds = total_rounds
else:
# Semaphore support is flaky to non-existent on
# Android, so we need to manually dispatch and
# manage workloads here, while periodically
# checking in on the progress.
use_nacl = False
try:
import nacl.encoding
import nacl.hash
use_nacl = True
except:
pass
def full_hash(m):
if use_nacl:
return nacl.hash.sha256(m, encoder=nacl.encoding.RawEncoder)
else:
return RNS.Identity.full_hash(m)
def sv(s, c, w):
target = 0b1<<256-c
m = w+s
result = full_hash(m)
if int.from_bytes(result, byteorder="big") > target:
return False
else:
return True
stamp = None
wm = multiprocessing.Manager()
jobs = multiprocessing.cpu_count()
RNS.log(f"Dispatching {jobs} workers for stamp generation...") # TODO: Remove
results_dict = wm.dict()
while stamp == None:
job_procs = []
def job(procnum=None, results_dict=None, wb=None):
RNS.log(f"Worker {procnum} starting...") # TODO: Remove
rounds = 0
stamp = os.urandom(256//8)
while not sv(stamp, self.stamp_cost, wb):
if rounds >= 500:
stamp = None
RNS.log(f"Worker {procnum} found no result in {rounds} rounds") # TODO: Remove
break
stamp = os.urandom(256//8)
rounds += 1
results_dict[procnum] = [stamp, rounds]
for pnum in range(jobs):
process = multiprocessing.Process(target=job, kwargs={"procnum":pnum, "results_dict": results_dict, "wb": workblock},)
job_procs.append(process)
process.start()
for process in job_procs:
process.join()
for j in results_dict:
r = results_dict[j]
RNS.log(f"Result from {r}: {r[1]} rounds, stamp: {r[0]}") # TODO: Remove
total_rounds += r[1]
if r[0] != None:
stamp = r[0]
RNS.log(f"Found stamp: {stamp}") # TODO: Remove
duration = time.time() - start_time
rounds = total_rounds
# TODO: Remove stats output
RNS.log(f"Stamp generated in {RNS.prettytime(duration)} / {rounds} rounds", RNS.LOG_DEBUG)
RNS.log(f"Rounds per second {int(rounds/duration)}", RNS.LOG_DEBUG)
RNS.log(f"Stamp: {RNS.hexrep(stamp)}", RNS.LOG_DEBUG)
RNS.log(f"Resulting hash: {RNS.hexrep(RNS.Identity.full_hash(workblock+stamp))}", RNS.LOG_DEBUG)
# RNS.log(f"Rounds per second {int(rounds/duration)}", RNS.LOG_DEBUG)
# RNS.log(f"Stamp: {RNS.hexrep(stamp)}", RNS.LOG_DEBUG)
# RNS.log(f"Resulting hash: {RNS.hexrep(RNS.Identity.full_hash(workblock+stamp))}", RNS.LOG_DEBUG)
###########################
return stamp
@ -344,9 +420,11 @@ class LXMessage:
hashed_part += msgpack.packb(self.payload)
self.hash = RNS.Identity.full_hash(hashed_part)
self.message_id = self.hash
self.stamp = self.get_stamp()
if self.stamp != None:
self.payload.append(self.stamp)
if not self.defer_stamp:
self.stamp = self.get_stamp()
if self.stamp != None:
self.payload.append(self.stamp)
signed_part = b""
signed_part += hashed_part