Skip to content

Commit

Permalink
add more advanced SentencePieceNormalizer class
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 13, 2024
1 parent f5c7363 commit ed76ecc
Show file tree
Hide file tree
Showing 7 changed files with 906 additions and 10 deletions.
95 changes: 95 additions & 0 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,98 @@ def Train(arg=None, logstream=None, **kwargs):

# Register SentencePieceTrainer in _sentencepiece:
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
class SentencePieceNormalizer(object):
    """Python wrapper around the native SentencePieceNormalizer object.

    Most methods forward directly to the `_sentencepiece` extension module.
    `Init`, `Normalize`, `__getstate__` and `__setstate__` are the
    hand-written Python layer on top of those bindings.
    """

    thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
    __repr__ = _swig_repr

    def __init__(self):
        # SWIG-generated native constructor. The module saves it as
        # `_sentencepiece_normalizer_init_native` and then rebinds
        # `__init__` to `Init`.
        _sentencepiece.SentencePieceNormalizer_swiginit(self, _sentencepiece.new_SentencePieceNormalizer())
    __swig_destroy__ = _sentencepiece.delete_SentencePieceNormalizer

    def LoadFromSerializedProto(self, serialized):
        """Forward to the native LoadFromSerializedProto binding."""
        return _sentencepiece.SentencePieceNormalizer_LoadFromSerializedProto(self, serialized)

    def LoadFromRuleTSV(self, filename):
        """Forward to the native LoadFromRuleTSV binding."""
        return _sentencepiece.SentencePieceNormalizer_LoadFromRuleTSV(self, filename)

    def LoadFromRuleName(self, name):
        """Forward to the native LoadFromRuleName binding."""
        return _sentencepiece.SentencePieceNormalizer_LoadFromRuleName(self, name)

    def serialized_model_proto(self):
        """Forward to the native serialized_model_proto binding."""
        return _sentencepiece.SentencePieceNormalizer_serialized_model_proto(self)

    def LoadFromFile(self, arg):
        """Forward to the native LoadFromFile binding."""
        return _sentencepiece.SentencePieceNormalizer_LoadFromFile(self, arg)

    def _Normalize(self, text):
        """Forward to the native _Normalize binding (single text)."""
        return _sentencepiece.SentencePieceNormalizer__Normalize(self, text)

    def _NormalizeWithOffsets(self, text):
        """Forward to the native _NormalizeWithOffsets binding (single text)."""
        return _sentencepiece.SentencePieceNormalizer__NormalizeWithOffsets(self, text)

    def _SetProtoField(self, name, value):
        """Forward to the native _SetProtoField binding."""
        return _sentencepiece.SentencePieceNormalizer__SetProtoField(self, name, value)

    def Init(self,
             model_file=None,
             model_proto=None,
             rule_tsv=None,
             rule_name=None,
             add_dummy_prefix=False,
             escape_whitespaces=False,
             remove_extra_whitespaces=False):
        """Initialize SentencePieceNormalizer.

        Exactly one model source is used. When several are given, the first
        truthy one wins in this order: model_file, model_proto, rule_tsv,
        rule_name.

        Args:
          model_file: The sentencepiece model file path.
          model_proto: The sentencepiece model serialized proto.
          rule_tsv: The normalization rule file in TSV format.
          rule_name: Pre-defined normalization name.
          add_dummy_prefix: add dummy prefix.
          escape_whitespaces: escape whitespaces.
          remove_extra_whitespaces: remove extra whitespaces.

        Raises:
          RuntimeError: if none of the four model sources is given.
        """

        # Build the underlying native object with the original constructor.
        _sentencepiece_normalizer_init_native(self)

        if model_file:
            status = self.LoadFromFile(model_file)
        elif model_proto:
            status = self.LoadFromSerializedProto(model_proto)
        elif rule_tsv:
            status = self.LoadFromRuleTSV(rule_tsv)
        elif rule_name:
            status = self.LoadFromRuleName(rule_name)
        else:
            raise RuntimeError('no model is specified')

        # The three flags are applied only after a successful load.
        if status:
            self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
            self._SetProtoField('escape_whitespaces', escape_whitespaces)
            self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)

    def Normalize(self, input, with_offsets=None):
        """Normalize a single text or a list of texts.

        Args:
          input: one text, or a `list` of texts. Only an exact `list` is
            treated as a batch.
          with_offsets: when truthy, each result is whatever
            `_NormalizeWithOffsets` returns instead of the plain
            normalized text.

        Returns:
          One result for a single input, or a list of results for a list.
        """
        def _normalize(text):
            if with_offsets:
                return self._NormalizeWithOffsets(text)
            return self._Normalize(text)

        if type(input) is list:
            return [_normalize(x) for x in input]
        return _normalize(input)

    def __getstate__(self):
        """Return the serialized model proto as the pickle state."""
        return self.serialized_model_proto()

    def __setstate__(self, serialized_model_proto):
        """Rebuild the native object from a pickled serialized model proto."""
        # Do not call self.__init__() here: __init__ is rebound to Init at
        # module level, and Init() without a model source raises
        # RuntimeError('no model is specified'), which made unpickling fail.
        # Call the saved native constructor directly instead.
        _sentencepiece_normalizer_init_native(self)
        self.LoadFromSerializedProto(serialized_model_proto)


# Register SentencePieceNormalizer in _sentencepiece:
_sentencepiece.SentencePieceNormalizer_swigregister(SentencePieceNormalizer)


import re
Expand Down Expand Up @@ -1045,7 +1137,9 @@ def _batched_func(self, arg):


# Save the SWIG-generated constructors before they are replaced below.
# SentencePieceNormalizer.Init calls its saved constructor to build the
# native object (presumably SentencePieceProcessor.Init does the same --
# confirm against its definition).
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
# From here on, constructing either class runs its Init method.
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)

# Tokenize / Detokenize are plain aliases of Encode / Decode.
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
Expand All @@ -1058,6 +1152,7 @@ def _batched_func(self, arg):

# NOTE(review): _add_snake_case is defined outside this view; presumably it
# adds snake_case aliases for each class's CamelCase methods -- confirm.
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
# Module-level snake_case aliases for the CamelCase functions.
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel

Expand Down
93 changes: 93 additions & 0 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;

// Do not wrap these SentencePieceNormalizer members directly; the
// %extend block for SentencePieceNormalizer provides LoadFromFile,
// _Normalize / _NormalizeWithOffsets and _SetProtoField on top of them.
%ignore sentencepiece::SentencePieceNormalizer::Load;
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
%ignore sentencepiece::SentencePieceNormalizer::mutable_normalizer_spec;

%ignore sentencepiece::io::LoadModelProto;
%ignore sentencepiece::io::SaveModelProto;

Expand Down Expand Up @@ -1293,6 +1297,92 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
}
}

%extend sentencepiece::SentencePieceNormalizer {
  // Wrapper exposing Load() under the name LoadFromFile.
  sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
    return $self->Load(arg);
  }

  // Normalizes `text` and returns the normalized string.
  // A non-OK status is thrown as-is.
  std::string _Normalize(absl::string_view text) {
    std::string result;
    const auto _status = $self->Normalize(text, &result);
    if (!_status.ok()) throw _status;
    return result;
  }

  // Normalizes `text` and returns the pair (normalized string, offsets
  // vector) filled in by Normalize(). A non-OK status is thrown as-is.
  std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
    std::pair<std::string, std::vector<size_t>> result;
    const auto _status = $self->Normalize(text, &result.first, &result.second);
    if (!_status.ok()) throw _status;
    return result;
  }

  // Sets the boolean field `name` of the normalizer spec, passing the value
  // as the string "1" or "0". Any error from SetProtoField is deliberately
  // ignored.
  void _SetProtoField(absl::string_view name, bool value) {
    sentencepiece::SentencePieceTrainer::SetProtoField(
        name,
        value ? "1" : "0",
        $self->mutable_normalizer_spec()).IgnoreError();
  }

  %pythoncode %{
    def Init(self,
             model_file=None,
             model_proto=None,
             rule_tsv=None,
             rule_name=None,
             add_dummy_prefix=False,
             escape_whitespaces=False,
             remove_extra_whitespaces=False):
      """Initialize SentencePieceNormalizer.

      Exactly one model source is used. When several are given, the first
      truthy one wins in this order: model_file, model_proto, rule_tsv,
      rule_name.

      Args:
        model_file: The sentencepiece model file path.
        model_proto: The sentencepiece model serialized proto.
        rule_tsv: The normalization rule file in TSV format.
        rule_name: Pre-defined normalization name.
        add_dummy_prefix: add dummy prefix.
        escape_whitespaces: escape whitespaces.
        remove_extra_whitespaces: remove extra whitespaces.

      Raises:
        RuntimeError: if none of the four model sources is given.
      """

      # Build the underlying native object with the original constructor.
      _sentencepiece_normalizer_init_native(self)

      if model_file:
        status = self.LoadFromFile(model_file)
      elif model_proto:
        status = self.LoadFromSerializedProto(model_proto)
      elif rule_tsv:
        status = self.LoadFromRuleTSV(rule_tsv)
      elif rule_name:
        status = self.LoadFromRuleName(rule_name)
      else:
        raise RuntimeError('no model is specified')

      # The three flags are applied only after a successful load.
      if status:
        self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
        self._SetProtoField('escape_whitespaces', escape_whitespaces)
        self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)

    def Normalize(self, input, with_offsets=None):
      """Normalize a single text or a list of texts.

      Args:
        input: one text, or a `list` of texts. Only an exact `list` is
          treated as a batch.
        with_offsets: when truthy, each result is the (text, offsets) pair
          from `_NormalizeWithOffsets` instead of the plain normalized text.

      Returns:
        One result for a single input, or a list of results for a list.
      """
      def _normalize(text):
        if with_offsets:
          return self._NormalizeWithOffsets(text)
        return self._Normalize(text)

      if type(input) is list:
        return [_normalize(x) for x in input]
      return _normalize(input)


    def __getstate__(self):
      """Return the serialized model proto as the pickle state."""
      return self.serialized_model_proto()


    def __setstate__(self, serialized_model_proto):
      """Rebuild the native object from a pickled serialized model proto."""
      # Do not call self.__init__() here: __init__ is rebound to Init at
      # module level, and Init() without a model source raises
      # RuntimeError('no model is specified'), which made unpickling fail.
      # Call the saved native constructor directly instead.
      _sentencepiece_normalizer_init_native(self)
      self.LoadFromSerializedProto(serialized_model_proto)
  %}
}

%extend sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece {
%rename(_piece) piece;
%rename(_id) id;
Expand Down Expand Up @@ -1790,7 +1880,9 @@ def _batchnize(classname, name):


# Save the SWIG-generated constructors before they are replaced below.
# SentencePieceNormalizer.Init calls its saved constructor to build the
# native object (presumably SentencePieceProcessor.Init does the same --
# confirm against its definition).
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
# From here on, constructing either class runs its Init method.
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)

# Tokenize / Detokenize are plain aliases of Encode / Decode.
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
Expand All @@ -1803,6 +1895,7 @@ for m in [

# NOTE(review): _add_snake_case is defined outside this view; presumably it
# adds snake_case aliases for each class's CamelCase methods -- confirm.
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
# Module-level snake_case aliases for the CamelCase functions.
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel

Expand Down
Loading

0 comments on commit ed76ecc

Please sign in to comment.