merges internal changes to github external repos
taku910 committed Dec 23, 2023
1 parent 022f8c3 commit 6b32c01
Showing 12 changed files with 177 additions and 46 deletions.
6 changes: 3 additions & 3 deletions src/bpe_model.cc
@@ -194,9 +194,9 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(

EncodeResult output;
for (int index = 0; index != -1; index = symbols[index].next) {
CHECK_GE(index, 0);
CHECK_LT(index, static_cast<int>(symbols.size()));
resegment(symbols[index].piece, &output);
if (index >= 0 && index < static_cast<int>(symbols.size())) {
resegment(symbols[index].piece, &output);
}
}

return output;
6 changes: 0 additions & 6 deletions src/common.h
@@ -53,12 +53,6 @@ typedef uint64_t uint64;

static constexpr uint32 kUnicodeError = 0xFFFD;

#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (::sentencepiece::win32::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path)
#endif

template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];

13 changes: 8 additions & 5 deletions src/filesystem.cc
@@ -12,16 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "filesystem.h"

#include <fstream>
#include <iostream>
#include <memory>

#include "filesystem.h"
#include "third_party/absl/memory/memory.h"
#include "util.h"

#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (::sentencepiece::win32::Utf8ToWide(path).c_str())
#define WPATH(path) (::sentencepiece::util::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path)
#define WPATH(path) (path.data())
#endif

namespace sentencepiece {
@@ -32,7 +35,7 @@ class PosixReadableFile : public ReadableFile {
PosixReadableFile(absl::string_view filename, bool is_binary = false)
: is_(filename.empty()
? &std::cin
: new std::ifstream(WPATH(filename.data()),
: new std::ifstream(WPATH(filename),
is_binary ? std::ios::binary | std::ios::in
: std::ios::in)) {
if (!*is_)
@@ -70,7 +73,7 @@ class PosixWritableFile : public WritableFile {
PosixWritableFile(absl::string_view filename, bool is_binary = false)
: os_(filename.empty()
? &std::cout
: new std::ofstream(WPATH(filename.data()),
: new std::ofstream(WPATH(filename),
is_binary ? std::ios::binary | std::ios::out
: std::ios::out)) {
if (!*os_)
20 changes: 19 additions & 1 deletion src/model_interface.cc
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "model_interface.h"

#include <algorithm>

#include "model_interface.h"
#include "sentencepiece_model.pb.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/str_format.h"
@@ -68,6 +69,23 @@ void ModelInterface::InitializePieces() {
std::set<absl::string_view> user_defined_symbols;
std::vector<bool> byte_found(256, false);

int pieces_size = 0;
int reserved_id_map_size = 0;
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
const bool is_normal_piece =
(sp.type() == ModelProto::SentencePiece::NORMAL ||
sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
sp.type() == ModelProto::SentencePiece::UNUSED);
if (is_normal_piece) {
++pieces_size;
} else {
++reserved_id_map_size;
}
}
pieces_.reserve(pieces_size);
reserved_id_map_.reserve(reserved_id_map_size);

for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
if (sp.piece().empty()) {
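The hunk above pre-counts piece types so that both containers can be reserved before the fill loop. A minimal standalone sketch of that count-then-reserve pattern, with a hypothetical Piece struct standing in for the proto types (not code from this commit):

#include <string>
#include <unordered_map>
#include <vector>

struct Piece {
  std::string text;
  bool is_normal;  // stands in for NORMAL / USER_DEFINED / UNUSED piece types
};

// Count first, reserve, then fill: avoids repeated vector reallocation and
// hash-map rehashing when the final sizes are known in advance.
void BuildIndexes(const std::vector<Piece> &pieces,
                  std::vector<std::string> *normal_pieces,
                  std::unordered_map<std::string, int> *reserved_ids) {
  size_t normal_count = 0;
  for (const auto &p : pieces) {
    if (p.is_normal) ++normal_count;
  }
  normal_pieces->reserve(normal_count);
  reserved_ids->reserve(pieces.size() - normal_count);

  for (size_t i = 0; i < pieces.size(); ++i) {
    if (pieces[i].is_normal) {
      normal_pieces->push_back(pieces[i].text);
    } else {
      reserved_ids->emplace(pieces[i].text, static_cast<int>(i));
    }
  }
}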
7 changes: 5 additions & 2 deletions src/normalizer.cc
@@ -315,8 +315,11 @@ PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
key.reserve(dic.size());
for (const auto &it : dic) key.push_back(it.data());
trie_ = absl::make_unique<Darts::DoubleArray>();
CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr));
if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr) != 0) {
LOG(ERROR) << "Failed to build the TRIE for PrefixMatcher";
trie_.reset();
}
}

int PrefixMatcher::PrefixMatch(absl::string_view w, bool *found) const {
43 changes: 32 additions & 11 deletions src/sentencepiece_processor.cc
@@ -14,9 +14,15 @@

#include "sentencepiece_processor.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "common.h"
#include "filesystem.h"
@@ -409,7 +415,7 @@ util::Status SentencePieceProcessor::Decode(

SentencePieceText spt;
RETURN_IF_ERROR(Decode(pieces, &spt));
*detokenized = std::move(spt.text());
*detokenized = std::move(*spt.mutable_text());

return util::OkStatus();
}
@@ -420,7 +426,7 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,

SentencePieceText spt;
RETURN_IF_ERROR(Decode(ids, &spt));
*detokenized = std::move(spt.text());
*detokenized = std::move(*spt.mutable_text());

return util::OkStatus();
}
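The two Decode changes above swap std::move(spt.text()) for std::move(*spt.mutable_text()). A short sketch of why that matters, using a hypothetical Message struct in place of the generated SentencePieceText class: moving from a const accessor silently degrades to a copy, while the mutable accessor allows a real move.

#include <string>
#include <utility>

struct Message {  // stand-in for the generated protobuf class
  const std::string &text() const { return text_; }
  std::string *mutable_text() { return &text_; }
  std::string text_;
};

void ExtractText(Message *m, std::string *out) {
  // *out = std::move(m->text());  // const ref: overload resolution picks the
  //                               // copy assignment, so this still copies.
  *out = std::move(*m->mutable_text());  // non-const ref: a genuine move.
}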
@@ -623,10 +629,10 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText(
CHECK_EQ_OR_RETURN(consumed, normalized.size())
<< "all normalized characters are not consumed.";

RETURN_IF_ERROR(ApplyExtraOptions(encode_extra_options_, spt));

spt->set_text(input.data(), input.size());

RETURN_IF_ERROR(ApplyExtraOptions(encode_extra_options_, spt));

return util::OkStatus();
} // namespace sentencepiece

@@ -695,10 +701,17 @@ util::Status SentencePieceProcessor::SampleEncode(
const auto nbests = model_->NBestEncode(normalized, nbest_size);
CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";

std::vector<float> probs(nbests.size(), 0.0);
for (size_t i = 0; i < nbests.size(); ++i) {
probs[i] = std::exp(alpha * nbests[i].second);
}
std::vector<double> log_probs;
log_probs.reserve(nbests.size());
std::transform(nbests.begin(), nbests.end(), std::back_inserter(log_probs),
[alpha](const auto &nbest) { return alpha * nbest.second; });

const double Z = log_domain::LogSum(log_probs);
std::vector<double> probs;
probs.reserve(log_probs.size());
std::transform(
log_probs.begin(), log_probs.end(), std::back_inserter(probs),
[Z](const auto &log_prob) { return std::exp(log_prob - Z); });

auto *mt = random::GetRandomGenerator();
std::discrete_distribution<int> dist(probs.begin(), probs.end());
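The sampling code above now normalizes in the log domain (via log_domain::LogSum) before exponentiating, instead of feeding raw std::exp(alpha * score) weights to the distribution. A minimal sketch of the underlying log-sum-exp trick using only the standard library; std::mt19937 stands in for the library's random::GetRandomGenerator():

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Stable log(sum(exp(x_i))): subtract the max before exponentiating so that
// exp() can neither overflow nor underflow to an all-zero weight vector.
double LogSumExp(const std::vector<double> &log_probs) {
  const double max_lp = *std::max_element(log_probs.begin(), log_probs.end());
  double sum = 0.0;
  for (const double lp : log_probs) sum += std::exp(lp - max_lp);
  return max_lp + std::log(sum);
}

// Samples an n-best index with probability proportional to exp(alpha * score).
int SampleIndex(const std::vector<double> &scores, double alpha,
                std::mt19937 *rng) {
  std::vector<double> log_probs;
  log_probs.reserve(scores.size());
  for (const double s : scores) log_probs.push_back(alpha * s);

  const double z = LogSumExp(log_probs);
  std::vector<double> probs;
  probs.reserve(log_probs.size());
  for (const double lp : log_probs) probs.push_back(std::exp(lp - z));

  std::discrete_distribution<int> dist(probs.begin(), probs.end());
  return dist(*rng);
}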
@@ -998,6 +1011,8 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
piece->set_id(PieceToId(absl::string_view(model_->eos_piece().data())));
piece->set_piece(model_->eos_piece().data(),
model_->eos_piece().size());
piece->set_begin(spt->text().size());
piece->set_end(spt->text().size());
} break;
case BOS: {
auto *array = spt->mutable_pieces();
@@ -1009,6 +1024,8 @@ piece->set_id(PieceToId(absl::string_view(model_->bos_piece().data())));
piece->set_id(PieceToId(absl::string_view(model_->bos_piece().data())));
piece->set_piece(model_->bos_piece().data(),
model_->bos_piece().size());
piece->set_begin(0);
piece->set_end(0);
} break;
case UNK_PIECE: {
for (int i = 0; i < spt->pieces_size(); ++i) {
@@ -1097,9 +1114,13 @@ util::Status LoadModelProto(absl::string_view filename,
auto input = filesystem::NewReadableFile(filename, true);
RETURN_IF_ERROR(input->status());
std::string serialized;
CHECK_OR_RETURN(input->ReadAll(&serialized));
CHECK_OR_RETURN(
model_proto->ParseFromArray(serialized.data(), serialized.size()));
if (!input->ReadAll(&serialized)) {
return util::InternalError(absl::StrCat("could not read ", filename));
}
if (!model_proto->ParseFromArray(serialized.data(), serialized.size())) {
return util::InternalError(
absl::StrCat("could not parse ModelProto from ", filename));
}

return util::OkStatus();
}
37 changes: 36 additions & 1 deletion src/sentencepiece_trainer.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "sentencepiece_trainer.h"

#include <string>
#include <vector>

@@ -20,7 +22,6 @@
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "spec_parser.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/numbers.h"
@@ -197,6 +198,40 @@ util::Status SentencePieceTrainer::Train(
sentence_iterator, serialized_model_proto);
}

namespace {
class VectorSentenceIterator : public SentenceIterator {
public:
explicit VectorSentenceIterator(const std::vector<std::string> &values)
: iter_(values.begin()), end_(values.end()) {}
virtual ~VectorSentenceIterator() {}
virtual bool done() const { return iter_ == end_; }
void Next() override { ++iter_; }
const std::string &value() const override { return *iter_; }
util::Status status() const override { return util::OkStatus(); }

private:
std::vector<std::string>::const_iterator iter_;
std::vector<std::string>::const_iterator end_;
};
} // namespace

// static
util::Status SentencePieceTrainer::Train(
absl::string_view args, const std::vector<std::string> &sentences,
std::string *serialized_model_proto) {
VectorSentenceIterator iter(sentences);
return Train(args, &iter, serialized_model_proto);
}

// static
util::Status SentencePieceTrainer::Train(
const std::unordered_map<std::string, std::string> &kwargs,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto) {
VectorSentenceIterator iter(sentences);
return Train(kwargs, &iter, serialized_model_proto);
}

// static
util::Status SentencePieceTrainer::PopulateNormalizerSpec(
NormalizerSpec *normalizer_spec, bool is_denormalizer) {
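VectorSentenceIterator above adapts an in-memory vector to the SentenceIterator interface. A small sketch of how such an iterator is driven through the done()/value()/Next()/status() protocol; the CollectAll helper is illustrative, not part of the library:

#include <string>
#include <vector>

#include "sentencepiece_trainer.h"

// Drains any SentenceIterator into a vector, stopping on the first error.
std::vector<std::string> CollectAll(sentencepiece::SentenceIterator *it) {
  std::vector<std::string> out;
  for (; !it->done(); it->Next()) {
    if (!it->status().ok()) break;
    out.push_back(it->value());
  }
  return out;
}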
11 changes: 11 additions & 0 deletions src/sentencepiece_trainer.h
@@ -89,6 +89,17 @@ class SentencePieceTrainer {
SentenceIterator *sentence_iterator = nullptr,
std::string *serialized_model_proto = nullptr);

// The same as above, but passes the list of sentences.
static util::Status Train(absl::string_view args,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto = nullptr);

// The same as above, but passes the list of sentences.
static util::Status Train(
const std::unordered_map<std::string, std::string> &kwargs,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto = nullptr);

// Handy function to make a normalizer spec from the pre-compiled
// normalization name. Do not use this method in production as it crashes
// When `name` is invalid. Useful for unittesting.
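A usage sketch for the two overloads declared above, mirroring the test added later in this commit; the corpus contents, model prefix, and flag values here are placeholders:

#include <string>
#include <vector>

#include "sentencepiece_trainer.h"

// Trains the same model twice: once via the flag-string overload and once via
// the keyword-map overload. Both write <model_prefix>.model / .vocab on success.
void TrainFromMemory(const std::vector<std::string> &corpus) {
  auto status = sentencepiece::SentencePieceTrainer::Train(
      "--model_prefix=m --vocab_size=1000", corpus);
  if (!status.ok()) return;

  status = sentencepiece::SentencePieceTrainer::Train(
      {{"model_prefix", "m"}, {"vocab_size", "1000"}}, corpus);
}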
16 changes: 15 additions & 1 deletion src/sentencepiece_trainer_test.cc
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "sentencepiece_trainer.h"

#include "filesystem.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"
@@ -129,6 +130,19 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
while (fs->ReadLine(&line)) sentences.emplace_back(line);
}

ASSERT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--model_prefix=", model, " --vocab_size=1000"),
sentences)
.ok());
CheckVocab(model + ".model", 1000);
CheckNormalizer(model + ".model", true, false);

ASSERT_TRUE(SentencePieceTrainer::Train(
{{"model_prefix", model}, {"vocab_size", "1000"}}, sentences)
.ok());
CheckVocab(model + ".model", 1000);
CheckNormalizer(model + ".model", true, false);

VectorIterator it(std::move(sentences));
ASSERT_TRUE(
SentencePieceTrainer::Train(
1 change: 0 additions & 1 deletion src/trainer_interface.cc
@@ -575,7 +575,6 @@ util::Status TrainerInterface::LoadSentences() {
w.first = string_util::UnicodeTextToUTF8(uw2);
}

// +3 for meta pieces.
if (trainer_spec_.model_type() != TrainerSpec::WORD &&
trainer_spec_.model_type() != TrainerSpec::CHAR) {
CHECK_LE_OR_RETURN(