Skip to content

Commit

Permalink
wsample: refactor initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ptaffet-jump committed Sep 26, 2023
1 parent 4d64d88 commit 49bb23f
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 96 deletions.
66 changes: 43 additions & 23 deletions src/ballet/wsample/fd_wsample.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,11 @@ seed_recursive( treap_ele_t * pool,


void *
fd_wsample_new( void * shmem,
fd_chacha20rng_t * rng,
ulong * weights,
ulong ele_cnt,
int restore_enabled,
int opt_hint ) {
fd_wsample_new_init( void * shmem,
fd_chacha20rng_t * rng,
ulong ele_cnt,
int restore_enabled,
int opt_hint ) {
if( FD_UNLIKELY( !shmem ) ) {
FD_LOG_WARNING(( "NULL shmem" ));
return NULL;
Expand All @@ -147,7 +146,7 @@ fd_wsample_new( void * shmem,
fd_wsample_t * sampler = (fd_wsample_t *)shmem;

sampler->total_weight = 0UL;
sampler->unremoved_cnt = ele_cnt;
sampler->unremoved_cnt = 0UL;
sampler->unremoved_weight = 0UL;
sampler->restore_enabled = restore_enabled;
sampler->rng = rng;
Expand All @@ -163,25 +162,45 @@ fd_wsample_new( void * shmem,
/* 100 is fine as a starting prio. See note above. */
if( opt_hint==FD_WSAMPLE_HINT_POWERLAW_NOREMOVE ) seed_recursive( pool, 1U, (uint)ele_cnt, 100U );
else treap_seed ( pool, ele_cnt, ele_cnt^0xBADF00DU );
return shmem;
}

ulong weight_sum = 0UL;
for( ulong i=0UL; i<ele_cnt; i++ ) {
ulong w = weights[i];
void *
fd_wsample_new_add( void * shmem,
ulong weight ) {
fd_wsample_t * sampler = (fd_wsample_t *)shmem;
if( FD_UNLIKELY( !sampler ) ) return NULL;

if( FD_UNLIKELY( w==0UL ) ) {
FD_LOG_WARNING(( "zero weight entry found" ));
return NULL;
}
if( FD_UNLIKELY( weight_sum+w<w ) ) {
FD_LOG_WARNING(( "total weight too large" ));
return NULL;
}
if( FD_UNLIKELY( weight==0UL ) ) {
FD_LOG_WARNING(( "zero weight entry found" ));
return NULL;
}
if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
FD_LOG_WARNING(( "total weight too large" ));
return NULL;
}

weight_sum += w;
pool[i].weight = w;
treap_idx_insert( sampler->treap, i, pool );
treap_ele_t * pool = sampler->pool;
ulong i = sampler->unremoved_cnt++;
sampler->total_weight += weight;
pool[i].weight = weight;
treap_idx_insert( sampler->treap, i, pool );

return shmem;
}

void *
fd_wsample_new_fini( void * shmem ) {
fd_wsample_t * sampler = (fd_wsample_t *)shmem;
if( FD_UNLIKELY( !sampler ) ) return NULL;

if( FD_UNLIKELY( sampler->unremoved_cnt != treap_ele_max( sampler->treap ) ) ) {
FD_LOG_WARNING(( "fd_wsample_new_add_weight called %lu times, but expected %lu weights", sampler->unremoved_cnt,
treap_ele_max( sampler->treap ) ));
return NULL;
}

treap_ele_t * pool = sampler->pool;
/* Populate left_sum values */

ulong nodesum = 0UL; /* Tracks sum of current node and all its children */
Expand Down Expand Up @@ -236,11 +255,12 @@ fd_wsample_new( void * shmem,
}
}

sampler->total_weight = nodesum;
FD_TEST( sampler->total_weight == nodesum );
sampler->unremoved_weight = nodesum;

if( restore_enabled ) {
if( sampler->restore_enabled ) {
/* Copy the sampler to make restore fast. */
ulong ele_cnt = treap_ele_max( sampler->treap );
fd_memcpy( pool+ele_cnt, pool, ele_cnt*sizeof(treap_ele_t) );
}

Expand Down
72 changes: 48 additions & 24 deletions src/ballet/wsample/fd_wsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,33 +62,57 @@ void * fd_wsample_delete ( void * shmem );
#define FD_WSAMPLE_HINT_POWERLAW_NOREMOVE 2
#define FD_WSAMPLE_HINT_POWERLAW_REMOVE 3

/* fd_wsample_new formats a memory region with the appropriate alignment
and footprint to be usable as a weighted sampler. shmem is a
pointer to the first byte of the memory region to use. rng must be a
local join of a ChaCha20 RNG struct. The weighted sampler will use
/* fd_wsample_new_init, fd_wsample_new_add, and
fd_wsample_new_fini format a memory region with the appropriate
alignment and footprint to be usable as a weighted sampler. This
multi-function initialization process prevents needing to construct a
flat array of weights, which is often inconvenient.
The caller must first call fd_wsample_new_init, then fd_wsample_new_add
ele_cnt times, and finally fd_wsample_new_fini. Only at that point will the
region of memory be ready to be joined.
fd_wsample_new_init begins formatting a memory region. shmem is
a pointer to the first byte of the memory region to use. rng must be
a local join of a ChaCha20 RNG struct. The weighted sampler will use
rng to generate random numbers. It may seem more natural for the
weighted sampler to own its own rng, but this is done to facilitate
sharing of rngs between weighted samplers, which is useful for
Turbine. weights points to the first element of an array of length
ele_cnt. If restore_enabled is set to 0, fd_wsample_restore_all will
not work but the required footprint is smaller. All elements in
weights must be strictly positive, and the sum must be less than
ULONG_MAX. ele_cnt must be less than UINT_MAX. opt_hint gives a
hint of the shape of the weights and the style of queries that will
be most common; this hint impacts query performance but not
correctness. opt_hint must be one of FD_WSAMPLE_HINT_*. Retains
read/write interest in rng but not in the memory pointed to by
weights.
On successful return, the weighted sampler contains an element
corresponding to each provided weight. Returns shmem on success and
NULL on failure. Caller is not joined on return. */
void * fd_wsample_new( void * shmem,
fd_chacha20rng_t * rng,
ulong * weights,
ulong ele_cnt,
int restore_enabled,
int opt_hint );
Turbine. ele_cnt specifies the number of elements that can be
sampled from and must be less than UINT_MAX. If restore_enabled is
set to 0, fd_wsample_restore_all will not work but the required
footprint is smaller. opt_hint gives a hint of the shape of the
weights and the style of queries that will be most common; this hint
impacts query performance but not correctness. opt_hint must be one
of FD_WSAMPLE_HINT_*.
fd_wsample_new_add adds a weight to a partially formatted memory
region. shmem must be a partially constructed region of memory, as
returned by fd_wsample_new_init or fd_wsample_new_add. weight
must be strictly positive, and the cumulative sum of this weight and
all other weights must be less than ULONG_MAX.
fd_wsample_new_fini finalizes the formatting of a partially formatted
memory region. shmem must be a partially constructed region of
memory, as returned by fd_wsample_new_add (or
fd_wsample_new_init if ele_cnt==0).
Retains read/write interest in rng.
Each function returns shmem on success and NULL on failure. It's
safe to pass NULL as shmem, in which case NULL will be returned, so
you only need to check the final result.
On successful completion of the formatting process, the weighted
sampler will contain an element corresponding to each provided
weight. Caller is not joined on return. */
void * fd_wsample_new_init( void * shmem,
fd_chacha20rng_t * rng,
ulong ele_cnt,
int restore_enabled,
int opt_hint );
void * fd_wsample_new_add ( void * shmem, ulong weight );
void * fd_wsample_new_fini( void * shmem );

/* fd_wsample_get_rng returns the value provided for rng in new. */
fd_chacha20rng_t * fd_wsample_get_rng( fd_wsample_t * sampler );
Expand Down
56 changes: 30 additions & 26 deletions src/ballet/wsample/test_wsample.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ test_probability_dist_replacement( void ) {
fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_SHIFT ) );
fd_chacha20rng_init( rng, seed );

for( ulong i=0UL; i<1024UL; i++ ) weights[i] = 2000000UL / (i+1UL);

for( ulong sz=1UL; sz<1024UL; sz+=113UL ) {
for( ulong i=0UL; i<sz; i++ ) weights[i] = 2000000UL / (i+1UL);
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, sz, 0, FD_WSAMPLE_HINT_POWERLAW_NOREMOVE ) );
void * partial = fd_wsample_new_init( _shmem, rng, sz, 0, FD_WSAMPLE_HINT_POWERLAW_NOREMOVE );
for( ulong i=0UL; i<sz; i++ ) partial = fd_wsample_new_add( partial, weights[i] );
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new_fini( partial ) );

ulong weight_sum = 0UL;
for( ulong i=0UL; i<sz; i++ ) weight_sum += weights[i];
Expand Down Expand Up @@ -100,8 +103,9 @@ test_probability_dist_noreplacement( void ) {
fd_chacha20rng_init( rng, seed );

for( ulong sz=1UL; sz<1024UL; sz+=113UL ) {
for( ulong i=0UL; i<sz; i++ ) weights[i] = 2000000UL / (i+1UL);
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, sz, 1, FD_WSAMPLE_HINT_POWERLAW_REMOVE ) );
void * partial = fd_wsample_new_init( _shmem, rng, sz, 1, FD_WSAMPLE_HINT_POWERLAW_REMOVE );
for( ulong i=0UL; i<sz; i++ ) partial = fd_wsample_new_add( partial, 2000000UL / (i+1UL) );
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new_fini( partial ) );

memset( counts, 0, MAX*sizeof(ulong) );
for( ulong j=0UL; j<sz; j++ ) {
Expand Down Expand Up @@ -129,8 +133,9 @@ test_probability_dist_noreplacement( void ) {
/* Expected probabilities of sampling without replacement get
complicated. We're going to use a 4-element set, and make sure the
distribution of returned 4-tuples matches what we manually compute. */
weights[0] = 40UL; weights[1] = 30UL; weights[2] = 20UL; weights[3] = 10UL;
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, 4UL, 1, FD_WSAMPLE_HINT_FLAT ) );
void * partial = fd_wsample_new_init( _shmem, rng, 4UL, 1, FD_WSAMPLE_HINT_FLAT );
partial = fd_wsample_new_add( fd_wsample_new_add( fd_wsample_new_add( fd_wsample_new_add( partial, 40UL ), 30UL ), 20UL ), 10UL );
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new_fini( partial ) );
memset( counts, 0, MAX*sizeof(ulong) );

for( ulong sample=0UL; sample<302400UL; sample++ ) {
Expand Down Expand Up @@ -183,10 +188,8 @@ test_matches_solana( void ) {
fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_MOD ) );
uchar zero_seed[32] = {0};

weights[0] = 2UL;
weights[1] = 1UL;

fd_wsample_t * tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, 2UL, 0, FD_WSAMPLE_HINT_FLAT ) );
void * partial = fd_wsample_new_init( _shmem, rng, 2UL, 0, FD_WSAMPLE_HINT_FLAT );
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( partial, 2UL ), 1UL ) ) );
fd_wsample_seed_rng( fd_wsample_get_rng( tree ), zero_seed );

FD_TEST( fd_wsample_sample( tree ) == 0UL );
Expand All @@ -212,7 +215,9 @@ test_matches_solana( void ) {
memset( zero_seed, 48, 32UL );
fd_chacha20rng_init( rng, zero_seed );

tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights2, 18UL, 0, FD_WSAMPLE_HINT_FLAT ) );
partial = fd_wsample_new_init( _shmem, rng, 18UL, 0, FD_WSAMPLE_HINT_FLAT );
for( ulong i=0UL; i<18UL; i++ ) partial = fd_wsample_new_add( partial, weights2[i] );
tree = fd_wsample_join( fd_wsample_new_fini( partial ) );
fd_wsample_seed_rng( fd_wsample_get_rng( tree ), zero_seed );

FD_TEST( fd_wsample_sample_and_remove( tree ) == 9UL );
Expand Down Expand Up @@ -242,16 +247,15 @@ static void
test_sharing( void ) {
fd_chacha20rng_t _rng[1];
uchar zero_seed[32] = {0};
weights[0] = 2UL;
weights[1] = 1UL;

for( ulong i=0UL; i<0x100UL; i++ ) {
fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_SHIFT ) );
fd_chacha20rng_init( rng, zero_seed );


fd_wsample_t * sample1 = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, 2UL, 0, FD_WSAMPLE_HINT_FLAT ) );
fd_wsample_t * sample2 = fd_wsample_join( fd_wsample_new( _shmem+MAX_FOOTPRINT/2UL, rng, weights, 2UL, 0, FD_WSAMPLE_HINT_FLAT ) );
void * pl1 = fd_wsample_new_init( _shmem, rng, 2UL, 0, FD_WSAMPLE_HINT_FLAT );
void * pl2 = fd_wsample_new_init( _shmem+MAX_FOOTPRINT/2UL, rng, 2UL, 0, FD_WSAMPLE_HINT_FLAT );
fd_wsample_t * sample1 = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( pl1, 2UL ), 1UL ) ) );
fd_wsample_t * sample2 = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( pl2, 2UL ), 1UL ) ) );

/* Since they're using the same weights, they are interchangeable. */

Expand All @@ -275,14 +279,14 @@ static void
test_restore_disabled( void ) {
fd_chacha20rng_t _rng[1];
uchar zero_seed[32] = {0};
weights[0] = 2UL;
weights[1] = 1UL;

fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_SHIFT ) );
fd_chacha20rng_init( rng, zero_seed );

fd_wsample_t * sample1 = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, 2UL, 0, FD_WSAMPLE_HINT_FLAT ) );
fd_wsample_t * sample2 = fd_wsample_join( fd_wsample_new( _shmem+MAX_FOOTPRINT/2UL, rng, weights, 2UL, 1, FD_WSAMPLE_HINT_FLAT ) );
void * partial1 = fd_wsample_new_init( _shmem, rng, 2UL, 0, FD_WSAMPLE_HINT_FLAT );
void * partial2 = fd_wsample_new_init( _shmem+MAX_FOOTPRINT/2UL, rng, 2UL, 1, FD_WSAMPLE_HINT_FLAT );
fd_wsample_t * sample1 = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( partial1, 2UL ), 1UL ) ) );
fd_wsample_t * sample2 = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( partial2, 2UL ), 1UL ) ) );

FD_TEST( fd_wsample_sample_and_remove( sample1 ) != FD_WSAMPLE_EMPTY );
FD_TEST( fd_wsample_sample_and_remove( sample1 ) != FD_WSAMPLE_EMPTY );
Expand All @@ -308,13 +312,13 @@ static void
test_remove_idx( void ) {
fd_chacha20rng_t _rng[1];
uchar zero_seed[32] = {0};
weights[0] = 2UL;
weights[1] = 1UL;

fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_SHIFT ) );
fd_chacha20rng_init( rng, zero_seed );

fd_wsample_t * sample = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, 2UL, 1, FD_WSAMPLE_HINT_FLAT ) );
void * partial = fd_wsample_new_init( _shmem, rng, 2UL, 1, FD_WSAMPLE_HINT_FLAT );
fd_wsample_t * sample = fd_wsample_join( fd_wsample_new_fini( fd_wsample_new_add( fd_wsample_new_add( partial, 2UL ), 1UL ) ) );
FD_TEST( sample );

fd_wsample_remove_idx( sample, 1UL );

Expand Down Expand Up @@ -348,9 +352,9 @@ test_map( void ) {
fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_SHIFT ) );

ulong sz=1018UL;
for( ulong i=0UL; i<sz; i++ ) weights[i] = 2000000UL / (i+1UL);
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new( _shmem, rng, weights, sz, 0, FD_WSAMPLE_HINT_POWERLAW_NOREMOVE ) );
fd_wsample_seed_rng( fd_wsample_get_rng( tree ), seed );
void * partial = fd_wsample_new_init( _shmem, rng, sz, 0, FD_WSAMPLE_HINT_POWERLAW_NOREMOVE );
for( ulong i=0UL; i<sz; i++ ) partial = fd_wsample_new_add( partial, 2000000UL / (i+1UL) );
fd_wsample_t * tree = fd_wsample_join( fd_wsample_new_fini( partial ) );

ulong x = 0UL;
for( ulong i=0UL; i<sz; i++ ) for( ulong j=0UL; j<weights[i]; j++ ) FD_TEST( fd_wsample_map_sample( tree, x++ )==i );
Expand Down
26 changes: 8 additions & 18 deletions src/flamenco/leaders/fd_leaders.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,30 @@ fd_epoch_leaders_new( void * shmem,
all the memory we need right here in shmem, we just need to be
careful about how we use it.
In order to construct a wsample object, we need a footprint of
64+32*pub_cnt bytes as well as a list of weights, which is
8*pub_cnt bytes. The footprint fits nicely in the space we'll use
for the struct and the list of pubkeys, while the list of weights
probably fits in the space we'll use for indices.
64+32*pub_cnt bytes. The footprint fits nicely in the space we'll use
for the struct and the list of pubkeys.
This works out because we only need the list of weights until we've
finished constructing the wsample object, and we can delay copying
the pubkeys until we're done with the wsample object.
There's a lot of type punning going on here, so watch out. */
This works out because we can delay copying the pubkeys until we're
done with the wsample object. There's a lot of type punning going
on here, so watch out. */

laddr = (ulong)shmem;
laddr = fd_ulong_align_up( laddr, fd_wsample_align() );
void * wsample_mem = (void *)fd_type_pun( (void *)laddr );
laddr += fd_wsample_footprint( pub_cnt, 0 );

laddr = fd_ulong_align_up( laddr, alignof(ulong) );
ulong * weights = (ulong *)fd_type_pun( (void *)laddr );
laddr += pub_cnt*sizeof(ulong);

FD_TEST( laddr-(ulong)shmem <= fd_epoch_leaders_footprint( pub_cnt, slot_cnt ) );

for( ulong i=0UL; i<pub_cnt; i++ ) weights[i] = stakes[i].stake;

/* Create and seed ChaCha20Rng */
fd_chacha20rng_t _rng[1];
fd_chacha20rng_t * rng = fd_chacha20rng_join( fd_chacha20rng_new( _rng, FD_CHACHA20RNG_MODE_MOD ) );
uchar key[ 32 ] = {0};
memcpy( key, &epoch, sizeof(ulong) );
fd_chacha20rng_init( rng, key );

fd_wsample_t * wsample = fd_wsample_join( fd_wsample_new( wsample_mem, rng, weights, pub_cnt, 0,
FD_WSAMPLE_HINT_POWERLAW_NODELETE ) );
void * _wsample = fd_wsample_new_init( wsample_mem, rng, pub_cnt, 0, FD_WSAMPLE_HINT_POWERLAW_NOREMOVE );
for( ulong i=0UL; i<pub_cnt; i++ ) _wsample = fd_wsample_new_add( _wsample, stakes[i].stake );
fd_wsample_t * wsample = fd_wsample_join( fd_wsample_new_fini( _wsample ) );

/* Compute the eventual addresses */
laddr = (ulong)shmem;
Expand All @@ -93,8 +85,6 @@ fd_epoch_leaders_new( void * shmem,
uint * sched = (uint *)fd_type_pun( (void *)laddr );
ulong sched_cnt = (slot_cnt+FD_EPOCH_SLOTS_PER_ROTATION-1UL)/FD_EPOCH_SLOTS_PER_ROTATION;

FD_TEST( (ulong)sched >= (ulong)weights );

/* Generate samples. We need uints, so we can't use sample_many. */
for( ulong i=0UL; i<sched_cnt; i++ ) sched[ i ] = (uint)fd_wsample_sample( wsample );

Expand Down
7 changes: 2 additions & 5 deletions src/flamenco/leaders/fd_leaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ typedef struct fd_stake_weight fd_stake_weight_t;
alignof(fd_epoch_leaders_t), sizeof(fd_epoch_leaders_t) ), \
32UL, (pub_cnt )*32UL ), \
alignof(ulong), ( \
(slot_cnt+FD_EPOCH_SLOTS_PER_ROTATION-1UL)/FD_EPOCH_SLOTS_PER_ROTATION*sizeof(uint)> \
(pub_cnt)*sizeof(ulong) ? /* Take the MAX of the two */ \
(slot_cnt+FD_EPOCH_SLOTS_PER_ROTATION-1UL)/FD_EPOCH_SLOTS_PER_ROTATION*sizeof(uint) : \
(pub_cnt)*sizeof(ulong) ) \
), \
(slot_cnt+FD_EPOCH_SLOTS_PER_ROTATION-1UL)/FD_EPOCH_SLOTS_PER_ROTATION*sizeof(uint) \
) ), \
FD_EPOCH_LEADERS_ALIGN )

#define FD_EPOCH_SLOTS_PER_ROTATION (4UL)
Expand Down

0 comments on commit 49bb23f

Please sign in to comment.