Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
integrate fasttextpy
Browse files Browse the repository at this point in the history
Summary: Changes to support Python bindings.

Reviewed By: EdouardGrave

Differential Revision: D6177403

fbshipit-source-id: 51b271089496f72137ac7ef934f787a8c03e4e00
  • Loading branch information
cpuhrsch authored and facebook-github-bot committed Nov 1, 2017
1 parent 431c9e2 commit f10ec1f
Show file tree
Hide file tree
Showing 18 changed files with 1,951 additions and 174 deletions.
140 changes: 140 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# fastText

[fastText](https://fasttext.cc/) is a library for efficient learning of word representations and sentence classification.

## Requirements

**fastText** builds on modern Mac OS and Linux distributions.
Since it uses C\++11 features, it requires a compiler with good C++11 support.
These include :

* (gcc-4.8 or newer) or (clang-3.3 or newer)

You will need

* python 2.7 or newer
* numpy & scipy
* [pybind11](https://github.com/pybind/pybind11)

## Building fastTextpy

In order to build `fastTextpy`, do the following:

```
$ python setup.py install
```

This will add the module fastTextpy to your python interpreter.
Depending on your system you might need to use 'sudo', for example

```
$ sudo python setup.py install
```

Now you can import this library with

```
import fastText
```


## Examples

If you're already largely familiar with fastText you could skip this section
and take a look at the examples within the doc folder.

## Using models

First, you'll need to train a model with fastText. For example

```
./fasttext skipgram -input data/fil9 -output result/fil9
```

You can see more examples within the scripts in the [fastText repository](https://github.com/facebookresearch/fastText).

Next, you can load this model from Python and query it.

```
from fastText import load_model
f = load_model('result/model.bin')
words, frequency = f.get_words()
subwords = f.get_subwords("Paris")
```

If you trained an unsupervised model, you can get word vectors with

```
vector = f.get_word_vector("London")
```

If you trained a supervised model, you can get the top k labels and get their probabilities with

```
k = 5
labels, probabilities = f.predict("I like this Product", k)
```

A more advanced application might look like this:

Getting the word vectors of all words:

```
words, frequency = f.get_words()
for w in words:
print((w, f.get_word_vector(w))
```

## Training models

Training a model is easy. For example

```
from fastTextpy import train_supervised
from fastTextpy import train_unsupervised
model_unsup = train_unsupervised(
input=<data>,
epoch=1,
model="cbow",
thread=10
)
model_unsup.save_model(<path>)
model_sup = train_supervised(
input=<labeled_data>
epoch=1,
thread=10
)
```

You can then use the model objects just as exemplified above.

To get extended help on these functions use the python help functions.

For example

```
Help on function train_unsupervised in module fastTextpy.FastText:
train_unsupervised(input, output=u'model', model=model_name.skipgram, lr=0.05, dim=100, ws=5, epoch=5, minCount=5, minCountLabel=0, minn=3, maxn=6, neg=5, wordNgrams=1, loss=loss_name.ns, bucket=2000000, thread=12, lrUpdateRate=100, t=0.0001, label=u'__label__', verbose=2, pretrainedVectors=u'', saveOutput=0)
```

## Processing data

You can tokenize using the fastText Dictionary method readWord.

This will give you a list of tokens split on the same whitespace characters that fastText splits on.

It will also add the EOS character as necessary, which is exposed via fastTextpy.EOS

Then resulting text is then stored entirely in memory.

For example:

```
from fastTextpy import tokenize
with open(<PATH>, 'r') as f:
tokens = tokenize(f.read())
```
3 changes: 3 additions & 0 deletions python/benchmarks/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
These programs allow us to compare the performance of a few key operations when consindering changes.

It is important to run these to make sure a change doesn't introduce a regression.
50 changes: 50 additions & 0 deletions python/benchmarks/get_word_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from fastText import load_model
from fastText import tokenize
import sys
import time
import tempfile
import argparse


def get_word_vector(data, model):
t1 = time.time()
print("Reading")
with open(data, 'r') as f:
tokens = tokenize(f.read())
t2 = time.time()
print("Read TIME: " + str(t2 - t1))
print("Read NUM : " + str(len(tokens)))
f = load_model(model)
# This is not equivalent to piping the data into
# print-word-vector, because the data is tokenized
# first.
t3 = time.time()
i = 0
for t in tokens:
vec = f.get_word_vector(t)
i += 1
if i % 10000 == 0:
sys.stderr.write("\ri: " + str(float(i / len(tokens))))
sys.stderr.flush()
t4 = time.time()
print("\nVectoring: " + str(t4 - t3))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Simple benchmark for get_word_vector.')
parser.add_argument('model', help='A model file to use for benchmarking.')
parser.add_argument('data', help='A data file to use for benchmarking.')
args = parser.parse_args()
get_word_vector(args.data, args.model)
43 changes: 43 additions & 0 deletions python/doc/examples/bin_to_vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division, absolute_import, print_function

from fastText import load_model
import argparse
import errno

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"Print fasttext .vec file to stdout from .bin file"
)
)
parser.add_argument(
"model", help="Model to use",
)
args = parser.parse_args()

f = load_model(args.model)
words = f.get_words()
print(str(len(words)) + " " + str(f.get_dimension()))
for w in words:
v = f.get_word_vector(w)
vstr = ""
for vi in v:
vstr += " " + str(vi)
try:
print(w + vstr)
except IOError as e:
if e.errno == errno.EPIPE:
pass
153 changes: 153 additions & 0 deletions python/doc/examples/compute_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#!/usr/bin/env python

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division, absolute_import, print_function

from fastText import load_model
import argparse
import numpy as np


def closest_ind(query, vectors, cossims):
np.matmul(vectors, query, out=cossims)


def print_score(
question, correct, num_qs, total_accuracy, semantic_accuracy, syntactic_accuracy
):
print(
(
"{0:>30}: ACCURACY TOP1: {3:.2f} % ({1} / {2})\t Total accuracy: {4:.2f} % Semantic accuracy: {5:.2f} % Syntactic accuracy: {6:.2f} %"
).format(
question, correct, num_qs, correct / float(num_qs) * 100, total_accuracy * 100,
semantic_accuracy * 100, syntactic_accuracy * 100,
)
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"compute_accuracy equivalent in Python. "
"See https://github.com/tmikolov/word2vec/blob/master/demo-word-accuracy.sh"
)
)
parser.add_argument(
"model",
help="Model to use",
)
parser.add_argument(
"question_words",
help="word questions similar to tmikolov's file (see help for link)",
)
parser.add_argument(
"threshold",
help="threshold used to limit number of words used",
)
args = parser.parse_args()
args.threshold = int(args.threshold)

f = load_model(args.model)
words, freq = f.get_words(include_freq=True)
words = words[:args.threshold]
vectors = np.zeros((len(words), f.get_dimension()), dtype=float)
for i in range(len(words)):
wv = f.get_word_vector(words[i])
wv = wv / np.linalg.norm(wv)
vectors[i] = wv

total_correct = 0
total_qs = 0
num_lines = 0

total_se_correct = 0
total_se_qs = 0

total_sy_correct = 0
total_sy_qs = 0

qid = 0
with open(args.question_words, 'r') as fqw:
correct = 0
num_qs = 0
question = ""
# For efficiency
cossims = np.zeros(len(words), dtype=float)
for line in fqw:
if line[0] == ":":
if question != "":
total_qs += num_qs
total_correct += correct
score = correct / num_qs
if (qid <= 5):
total_se_correct += correct
total_se_qs += num_qs
else:
total_sy_correct += correct
total_sy_qs += num_qs
print_score(
question,
correct,
num_qs,
total_correct / float(total_qs),
total_se_correct / float(total_se_qs) if total_se_qs > 0 else 0,
total_sy_correct / float(total_sy_qs) if total_sy_qs > 0 else 0,
)
correct = 0
num_qs = 0
question = line.strip().replace(":", "")
qid += 1
else:
num_lines += 1
qwords = line.split()
qwords = [x.lower().strip() for x in qwords]
found = True
for w in qwords:
if w not in words:
found = False
break
if not found:
continue
query = qwords[:3]
query = [f.get_word_vector(x) for x in query]
query = [x / np.linalg.norm(x) for x in query]
query = query[1] - query[0] + query[2]
ban_set = qwords[:3]
closest_ind(query, vectors, cossims)
rank = len(cossims) - 1
result_i = np.argpartition(cossims, rank)[rank]
result = words[result_i]
while result in ban_set:
rank -= 1
result_i = np.argpartition(cossims, rank)[rank]
result = words[result_i]
if result == qwords[3]:
correct += 1
num_qs += 1

total_qs += num_qs
total_correct += correct
total_sy_correct += correct
total_sy_qs += num_qs
print_score(
question,
correct,
num_qs,
total_correct / float(total_qs),
total_se_correct / float(total_se_qs) if total_se_qs > 0 else 0,
total_sy_correct / float(total_sy_qs) if total_sy_qs > 0 else 0,
)
print(
"Questions seen / total: {0} {1} {2:.2f} %".format(total_qs, num_lines,
total_qs / num_lines * 100)
)
Loading

0 comments on commit f10ec1f

Please sign in to comment.