Skip to content

Commit

Permalink
move SharedBitGen to random namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 6, 2024
1 parent 49afc4c commit adf9e81
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/trainer_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ bool TrainerInterface::IsValidSentencePiece(
}

template <typename T>
void AddDPNoise(const TrainerSpec &trainer_spec, absl::SharedBitGen &generator,
T *to_update) {
void AddDPNoise(const TrainerSpec &trainer_spec,
random::SharedBitGen &generator, T *to_update) {
if (trainer_spec.differential_privacy_noise_level() > 0) {
float random_num = absl::Gaussian<float>(
generator, 0, trainer_spec.differential_privacy_noise_level());
Expand Down Expand Up @@ -480,7 +480,7 @@ util::Status TrainerInterface::LoadSentences() {
for (int n = 0; n < num_workers; ++n) {
pool->Schedule([&, n]() {
// One per thread generator.
absl::SharedBitGen generator;
random::SharedBitGen generator;
for (size_t i = n; i < sentences_.size(); i += num_workers) {
AddDPNoise<int64>(trainer_spec_, generator,
&(sentences_[i].second));
Expand Down
5 changes: 5 additions & 0 deletions src/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ namespace random {

std::mt19937 *GetRandomGenerator();

class SharedBitGen {
public:
std::mt19937 *engine() { return GetRandomGenerator(); }
};

template <typename T>
class ReservoirSampler {
public:
Expand Down
4 changes: 2 additions & 2 deletions third_party/absl/random/distributions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

namespace absl {

template <typename T>
T Gaussian(SharedBitGen &generator, T mean, T stddev) {
template <typename T, typename G>
T Gaussian(G &generator, T mean, T stddev) {
std::normal_distribution<> dist(mean, stddev);
return dist(*generator.engine());
}
Expand Down
15 changes: 0 additions & 15 deletions third_party/absl/random/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,4 @@
#ifndef ABSL_CONTAINER_RANDOM_H_
#define ABSL_CONTAINER_RANDOM_H_

#include <random>

#include "../../../src/util.h"

using sentencepiece::random::GetRandomGenerator;

namespace absl {

class SharedBitGen {
public:
std::mt19937 *engine() { return GetRandomGenerator(); }
};

} // namespace absl

#endif // ABSL_CONTAINER_RANDOM_H_

0 comments on commit adf9e81

Please sign in to comment.