Skip to content

Commit

Permalink
Fix overflow in concurrent_hash_map (#704)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexei Katranov <[email protected]>
  • Loading branch information
alexey-katranov committed Jan 13, 2022
1 parent 1eaccf7 commit ec39c54
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 55 deletions.
20 changes: 10 additions & 10 deletions include/oneapi/tbb/concurrent_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,12 @@ class hash_map_iterator {
template <typename Key, typename T, typename HashCompare, typename A
#if __TBB_PREVIEW_CONCURRENT_HASH_MAP_EXTENSIONS
, typename M
>
__TBB_requires(tbb::detail::hash_compare<HashCompare, Key> &&
ch_map_rw_scoped_lockable<M>)
>
__TBB_requires(tbb::detail::hash_compare<HashCompare, Key> &&
ch_map_rw_scoped_lockable<M>)
#else
>
__TBB_requires(tbb::detail::hash_compare<HashCompare, Key>)
>
__TBB_requires(tbb::detail::hash_compare<HashCompare, Key>)
#endif
friend class concurrent_hash_map;

Expand Down Expand Up @@ -726,7 +726,7 @@ class concurrent_hash_map
void rehash_bucket( bucket *b_new, const hashcode_type hash ) {
__TBB_ASSERT( hash > 1, "The lowermost buckets can't be rehashed" );
b_new->node_list.store(reinterpret_cast<node_base*>(empty_rehashed_flag), std::memory_order_release); // mark rehashed
hashcode_type mask = (1u << tbb::detail::log2(hash)) - 1; // get parent mask from the topmost bit
hashcode_type mask = (hashcode_type(1) << tbb::detail::log2(hash)) - 1; // get parent mask from the topmost bit
bucket_accessor b_old( this, hash & mask );

mask = (mask<<1) | 1; // get full mask for new bucket
Expand Down Expand Up @@ -784,7 +784,7 @@ class concurrent_hash_map
void release() {
if( my_node ) {
node::scoped_type::release();
my_node = 0;
my_node = nullptr;
}
}

Expand All @@ -800,7 +800,7 @@ class concurrent_hash_map
}

// Create empty result
const_accessor() : my_node(nullptr) {}
const_accessor() : my_node(nullptr), my_hash() {}

// Destroy result after releasing the underlying reference.
~const_accessor() {
Expand Down Expand Up @@ -971,7 +971,7 @@ class concurrent_hash_map
hashcode_type h = b; bucket *b_old = bp;
do {
__TBB_ASSERT( h > 1, "The lowermost buckets can't be rehashed" );
hashcode_type m = ( 1u<<tbb::detail::log2( h ) ) - 1; // get parent mask from the topmost bit
hashcode_type m = ( hashcode_type(1) << tbb::detail::log2( h ) ) - 1; // get parent mask from the topmost bit
b_old = this->get_bucket( h &= m );
} while( rehash_required(b_old->node_list.load(std::memory_order_relaxed)) );
// now h - is index of the root rehashed bucket b_old
Expand Down Expand Up @@ -1470,7 +1470,7 @@ class concurrent_hash_map
h &= m;
bucket *b = this->get_bucket( h );
while (rehash_required(b->node_list.load(std::memory_order_relaxed))) {
m = ( 1u<<tbb::detail::log2( h ) ) - 1; // get parent mask from the topmost bit
m = ( hashcode_type(1) << tbb::detail::log2( h ) ) - 1; // get parent mask from the topmost bit
b = this->get_bucket( h &= m );
}
node *n = search_bucket( key, b );
Expand Down
90 changes: 45 additions & 45 deletions test/conformance/conformance_concurrent_hash_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class MyData {
}

MyData( const MyData& other ) {
CHECK(other.my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
my_state = LIVE;
data = other.data;
if(MyDataCountLimit && MyDataCount + 1 >= MyDataCountLimit) {
Expand All @@ -103,18 +103,18 @@ class MyData {
}

int value_of() const {
CHECK(my_state==LIVE);
CHECK_FAST(my_state==LIVE);
return data;
}

void set_value( int i ) {
CHECK(my_state==LIVE);
CHECK_FAST(my_state==LIVE);
data = i;
}

bool operator==( const MyData& other ) const {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
return data == other.data;
}
};
Expand All @@ -124,32 +124,32 @@ class MyData2 : public MyData {
MyData2( ) {}

MyData2( const MyData2& other ) : MyData() {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
data = other.data;
}

MyData2( const MyData& other ) {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
data = other.data;
}

void operator=( const MyData& other ) {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
data = other.data;
}

void operator=( const MyData2& other ) {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
data = other.data;
}

bool operator==( const MyData2& other ) const {
CHECK(other.my_state==LIVE);
CHECK(my_state==LIVE);
CHECK_FAST(other.my_state==LIVE);
CHECK_FAST(my_state==LIVE);
return data == other.data;
}
};
Expand Down Expand Up @@ -223,7 +223,7 @@ void FillTable( test_table_type& x, int n ) {
MyKey key( MyKey::make(-i) ); // hash values must not be specified in direct order
typename test_table_type::accessor a;
bool b = x.insert(a,key);
CHECK(b);
CHECK_FAST(b);
a->second.set_value( i*i );
}
}
Expand All @@ -237,8 +237,8 @@ static void CheckTable( const test_table_type& x, int n ) {
MyKey key( MyKey::make(-i) );
typename test_table_type::const_accessor a;
bool b = x.find(a,key);
CHECK(b);
CHECK(a->second.value_of()==i*i);
CHECK_FAST(b);
CHECK_FAST(a->second.value_of()==i*i);
}
int count = 0;
int key_sum = 0;
Expand Down Expand Up @@ -609,21 +609,21 @@ struct RvalueInsert {
static void apply( DataStateTrackedTable& table, int i ) {
DataStateTrackedTable::accessor a;
int next = i + 1;
REQUIRE_MESSAGE((table.insert( a, std::make_pair(MyKey::make(i), move_support_tests::Foo(next)))),
CHECK_FAST_MESSAGE((table.insert( a, std::make_pair(MyKey::make(i), move_support_tests::Foo(next)))),
"already present while should not ?" );
CHECK((*a).second == next);
CHECK((*a).second.state == StateTrackableBase::MoveInitialized);
CHECK_FAST((*a).second == next);
CHECK_FAST((*a).second.state == StateTrackableBase::MoveInitialized);
}
};

struct Emplace {
template <typename Accessor>
static void apply_impl( DataStateTrackedTable& table, int i) {
Accessor a;
REQUIRE_MESSAGE((table.emplace( a, MyKey::make(i), (i + 1))),
CHECK_FAST_MESSAGE((table.emplace( a, MyKey::make(i), (i + 1))),
"already present while should not ?" );
CHECK((*a).second == i + 1);
CHECK((*a).second.state == StateTrackableBase::DirectInitialized);
CHECK_FAST((*a).second == i + 1);
CHECK_FAST((*a).second.state == StateTrackableBase::DirectInitialized);
}

static void apply( DataStateTrackedTable& table, int i ) {
Expand Down Expand Up @@ -665,11 +665,11 @@ struct Insert {
if( i&1 ) {
test_table_type::accessor a;
table.insert( a, std::make_pair(MyKey::make(i), MyData(i*i)) );
CHECK((*a).second.value_of()==i*i);
CHECK_FAST((*a).second.value_of()==i*i);
} else {
test_table_type::const_accessor ca;
table.insert( ca, std::make_pair(MyKey::make(i), MyData(i*i)) );
CHECK(ca->second.value_of()==i*i);
CHECK_FAST(ca->second.value_of()==i*i);
}
}
}
Expand All @@ -680,12 +680,12 @@ struct Find {
test_table_type::accessor a;
const test_table_type::accessor& ca = a;
bool b = table.find( a, MyKey::make(i) );
CHECK(b==!a.empty());
CHECK_FAST(b==!a.empty());
if( b ) {
if( !UseKey(i) )
REPORT("Line %d: unexpected key %d present\n",__LINE__,i);
CHECK(ca->second.value_of()==i*i);
CHECK((*ca).second.value_of()==i*i);
CHECK_FAST(ca->second.value_of()==i*i);
CHECK_FAST((*ca).second.value_of()==i*i);
if( i&1 )
ca->second.set_value( ~ca->second.value_of() );
else
Expand All @@ -702,12 +702,12 @@ struct FindConst {
test_table_type::const_accessor a;
const test_table_type::const_accessor& ca = a;
bool b = table.find( a, MyKey::make(i) );
CHECK(b==(table.count(MyKey::make(i))>0));
CHECK(b==!a.empty());
CHECK(b==UseKey(i));
CHECK_FAST(b==(table.count(MyKey::make(i))>0));
CHECK_FAST(b==!a.empty());
CHECK_FAST(b==UseKey(i));
if( b ) {
CHECK(ca->second.value_of()==~(i*i));
CHECK((*ca).second.value_of()==~(i*i));
CHECK_FAST(ca->second.value_of()==~(i*i));
CHECK_FAST((*ca).second.value_of()==~(i*i));
}
}
};
Expand Down Expand Up @@ -746,24 +746,24 @@ void TraverseTable( test_table_type& table, size_t n, size_t expected_size ) {
for( test_table_type::iterator i = table.begin(); i!=table.end(); ++i ) {
// Check iterator
int k = i->first.value_of();
CHECK(UseKey(k));
CHECK((*i).first.value_of()==k);
REQUIRE_MESSAGE((0<=k && size_t(k)<n), "out of bounds key" );
REQUIRE_MESSAGE( !array[k], "duplicate key" );
CHECK_FAST(UseKey(k));
CHECK_FAST((*i).first.value_of()==k);
CHECK_FAST_MESSAGE((0<=k && size_t(k)<n), "out of bounds key" );
CHECK_FAST_MESSAGE( !array[k], "duplicate key" );
array[k] = true;
++count;

// Check lower/upper bounds
std::pair<test_table_type::iterator, test_table_type::iterator> er = table.equal_range(i->first);
std::pair<test_table_type::const_iterator, test_table_type::const_iterator> cer = const_table.equal_range(i->first);
CHECK((cer.first == er.first && cer.second == er.second));
CHECK(cer.first == i);
CHECK(std::distance(cer.first, cer.second) == 1);
CHECK_FAST((cer.first == er.first && cer.second == er.second));
CHECK_FAST(cer.first == i);
CHECK_FAST(std::distance(cer.first, cer.second) == 1);

// Check const_iterator
test_table_type::const_iterator cic = ci++;
CHECK(cic->first.value_of()==k);
CHECK((*cic).first.value_of()==k);
CHECK_FAST(cic->first.value_of()==k);
CHECK_FAST((*cic).first.value_of()==k);
}
CHECK(ci==const_table.end());
delete[] array;
Expand All @@ -788,7 +788,7 @@ struct Erase {
} else
b = table.erase( MyKey::make(i) );
if( b ) ++EraseCount;
CHECK(table.count(MyKey::make(i)) == 0);
CHECK_FAST(table.count(MyKey::make(i)) == 0);
}
};

Expand Down Expand Up @@ -865,7 +865,7 @@ struct ParallelTraverseBody {
void operator()( const RangeType& range ) const {
for( typename RangeType::iterator i = range.begin(); i!=range.end(); ++i ) {
int k = i->first.value_of();
CHECK((0<=k && size_t(k)<n));
CHECK_FAST((0<=k && size_t(k)<n));
++array[k];
}
}
Expand Down

0 comments on commit ec39c54

Please sign in to comment.