Skip to content

Commit

Permalink
Improves the thread utilization in batch encoding/decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Aug 5, 2023
1 parent 635fe84 commit 8cbdf13
Show file tree
Hide file tree
Showing 3 changed files with 655 additions and 610 deletions.
36 changes: 6 additions & 30 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# This file was automatically generated by SWIG (http://www.swig.org).
# Version 4.0.2
# This file was automatically generated by SWIG (https://www.swig.org).
# Version 4.1.0
#
# Do not make changes to this file unless you know what you are doing--modify
# Do not make changes to this file unless you know what you are doing - modify
# the SWIG interface file instead.

from sys import version_info as _swig_python_version_info
if _swig_python_version_info < (2, 7, 0):
raise RuntimeError("Python 2.7 or later required")

# Import the low-level C/C++ module
if __package__ or "." in __name__:
from . import _sentencepiece
Expand All @@ -29,10 +26,10 @@ def _swig_repr(self):

def _swig_setattr_nondynamic_instance_variable(set):
def set_instance_attr(self, name, value):
if name == "thisown":
self.this.own(value)
elif name == "this":
if name == "this":
set(self, name, value)
elif name == "thisown":
self.this.own(value)
elif hasattr(self, name) and isinstance(getattr(type(self), name), property):
set(self, name, value)
else:
Expand Down Expand Up @@ -109,7 +106,6 @@ def __hash__(self):

# Register ImmutableSentencePieceText_ImmutableSentencePiece in _sentencepiece:
_sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece_swigregister(ImmutableSentencePieceText_ImmutableSentencePiece)

class ImmutableSentencePieceText(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
Expand Down Expand Up @@ -179,7 +175,6 @@ def __str__(self):

# Register ImmutableSentencePieceText in _sentencepiece:
_sentencepiece.ImmutableSentencePieceText_swigregister(ImmutableSentencePieceText)

class ImmutableNBestSentencePieceText(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
Expand Down Expand Up @@ -237,7 +232,6 @@ def __str__(self):

# Register ImmutableNBestSentencePieceText in _sentencepiece:
_sentencepiece.ImmutableNBestSentencePieceText_swigregister(ImmutableNBestSentencePieceText)

class SentencePieceProcessor(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
Expand Down Expand Up @@ -908,7 +902,6 @@ def Load(self, model_file=None, model_proto=None):
# Register SentencePieceProcessor in _sentencepiece:
_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)


def SetRandomGeneratorSeed(seed):
return _sentencepiece.SetRandomGeneratorSeed(seed)
class SentencePieceTrainer(object):
Expand Down Expand Up @@ -992,22 +985,6 @@ def Train(arg=None, logstream=None, **kwargs):
# Register SentencePieceTrainer in _sentencepiece:
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)

def SentencePieceTrainer__TrainFromString(arg):
return _sentencepiece.SentencePieceTrainer__TrainFromString(arg)

def SentencePieceTrainer__TrainFromMap(args):
return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)

def SentencePieceTrainer__TrainFromMap2(args, iter):
return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter)

def SentencePieceTrainer__TrainFromMap3(args):
return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args)

def SentencePieceTrainer__TrainFromMap4(args, iter):
return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter)



import re
import csv
Expand Down Expand Up @@ -1084,4 +1061,3 @@ def __exit__(self, type, value, traceback):
self.ostream.close()



25 changes: 16 additions & 9 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

%{

#include <atomic>
#include <iostream>
#include <algorithm>
#include <functional>
Expand Down Expand Up @@ -246,9 +247,11 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
InitNumThreads(ins, &num_threads); \
{ \
ThreadPool pool(ins.size()); \
std::atomic<size_t> index = 0; \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&, n]() { \
for (size_t i = n; i < ins.size(); i += num_threads) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
auto out = enable_sampling ? \
self->Sample##FuncName(ins[i], \
nbest_size, alpha) : \
Expand All @@ -267,10 +270,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
std::vector<OutType> outs(ins.size()); \
InitNumThreads(ins, &num_threads); \
{ \
std::atomic<size_t> index = 0; \
ThreadPool pool(ins.size()); \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&, n]() { \
for (size_t i = n; i < ins.size(); i += num_threads) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
CheckIds(ins[i], self->GetPieceSize()); \
auto out = self->FuncName(ins[i]); \
ConvertToUnicodeSpans(&out); \
Expand Down Expand Up @@ -655,12 +660,14 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
InitNumThreads(ins, &num_threads);
{
ThreadPool pool(ins.size());
std::atomic<size_t> index = 0;
for (int n = 0; n < num_threads; ++n) {
pool.Schedule([&, n]() {
for (size_t i = n; i < ins.size(); i += num_threads) {
outs[i] = self->CalculateEntropy(ins[i], alpha);
}
});
pool.Schedule([&]() {
size_t i = 0;
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) {
outs[i] = self->CalculateEntropy(ins[i], alpha);
}
});
}
}
return outs;
Expand Down
Loading

0 comments on commit 8cbdf13

Please sign in to comment.