Commit eaf5cc4: WIP

danieldk committed Oct 22, 2023
1 parent d9a8333
Showing 7 changed files with 334 additions and 129 deletions.
51 changes: 17 additions & 34 deletions syntaxdot-cli/src/subcommands/distill.rs
@@ -12,7 +12,8 @@ use itertools::Itertools;
 use ordered_float::NotNan;
 use syntaxdot::config::{Config, PretrainConfig};
 use syntaxdot::dataset::{
-    BatchedTensors, ConlluDataSet, DataSet, PlainTextDataSet, SentenceIterTools, SequenceLength,
+    BatchedTensors, ConlluDataSet, DataSet, PairedBatchedTensors, PairedDataSet, PlainTextDataSet,
+    SentenceIterTools, SequenceLength,
 };
 use syntaxdot::encoders::Encoders;
 use syntaxdot::error::SyntaxDotError;
@@ -283,8 +284,7 @@ impl DistillApp {
         auxiliary_params: &AuxiliaryParameters,
         teacher: &Model,
         student: &StudentModel,
-        teacher_train_file: &File,
-        student_train_file: &File,
+        distill_file: &File,
         validation_file: &mut File,
     ) -> Result<()> {
         let mut best_step = 0;
@@ -296,7 +296,7 @@

         let n_steps = self
             .train_duration
-            .as_steps(teacher_train_file, self.batch_size)
+            .as_steps(distill_file, self.batch_size)
             .context("Cannot determine number of training steps")?;

         let train_progress = ProgressBar::new(n_steps as u64);
@@ -305,29 +305,17 @@
         )?);

         while global_step < n_steps - 1 {
-            let mut teacher_train_dataset = Self::open_dataset(teacher_train_file)?;
-            let mut student_train_dataset = Self::open_dataset(student_train_file)?;
-
-            let teacher_train_batches = teacher_train_dataset
-                .sentences(&*teacher.tokenizer)?
-                .filter_by_len(self.max_len)
-                .batched_tensors(None, None, self.batch_size);
-
-            let student_train_batches = student_train_dataset
-                .sentences(&*student.tokenizer)?
+            let mut dataset = Self::open_dataset(distill_file)?;
+            let distill_batches = dataset
+                .tokenize_pair(&*teacher.tokenizer, &*student.tokenizer)?
                 .filter_by_len(self.max_len)
-                .batched_tensors(None, None, self.batch_size);
+                .paired_batched_tensors(None, None, self.batch_size);

-            for (teacher_steps, student_steps) in teacher_train_batches
-                .chunks(self.eval_steps)
-                .into_iter()
-                .zip(student_train_batches.chunks(self.eval_steps).into_iter())
-            {
+            for steps in distill_batches.chunks(self.eval_steps).into_iter() {
                 self.train_steps(
                     auxiliary_params,
                     &train_progress,
-                    teacher_steps,
-                    student_steps,
+                    steps,
                     &mut global_step,
                     grad_scaler,
                     &teacher.model,
@@ -612,16 +600,14 @@ impl DistillApp {
         &self,
         auxiliary_params: &AuxiliaryParameters,
         progress: &ProgressBar,
-        teacher_batches: impl Iterator<Item = Result<Tensors, SyntaxDotError>>,
-        student_batches: impl Iterator<Item = Result<Tensors, SyntaxDotError>>,
+        batches: impl Iterator<Item = Result<(Tensors, Tensors), SyntaxDotError>>,
         global_step: &mut usize,
         grad_scaler: &mut GradScaler<impl Optimizer>,
         teacher: &BertModel,
         student: &BertModel,
     ) -> Result<()> {
-        for (teacher_batch, student_batch) in teacher_batches.zip(student_batches) {
-            let teacher_batch = teacher_batch.context("Cannot read teacher batch")?;
-            let student_batch = student_batch.context("Cannot read student batch")?;
+        for batch in batches {
+            let (teacher_batch, student_batch) = batch.context("Cannot read batch")?;

             let distill_loss = self.student_loss(
                 auxiliary_params,
@@ -866,7 +852,7 @@ impl DistillApp {
         let mut n_tokens = 0;

         for batch in dataset
-            .sentences(tokenizer)?
+            .tokenize(tokenizer)?
             .filter_by_len(self.max_len)
             .batched_tensors(biaffine_encoder, Some(encoders), self.batch_size)
         {
@@ -1192,7 +1178,7 @@ impl SyntaxDotApp for DistillApp {
             .get_one::<String>(MAX_LEN)
             .map(|v| v.parse().context("Cannot parse maximum sentence length"))
             .transpose()?
-            .map(SequenceLength::Tokens)
+            .map(SequenceLength::Pieces)
             .unwrap_or(SequenceLength::Unbounded);
         let mixed_precision = matches.get_flag(MIXED_PRECISION);
         let warmup_steps = matches
@@ -1251,9 +1237,7 @@
         let student_config = load_config(&self.student_config)?;
         let teacher = Model::load(&self.teacher_config, self.device, true, false, |_| 0)?;

-        let teacher_train_file = File::open(&self.train_data)
-            .context(format!("Cannot open train data file: {}", self.train_data))?;
-        let student_train_file = File::open(&self.train_data)
+        let distill_file = File::open(&self.train_data)
             .context(format!("Cannot open train data file: {}", self.train_data))?;
         let mut validation_file = File::open(&self.validation_data).context(format!(
             "Cannot open validation data file: {}",
@@ -1276,8 +1260,7 @@
             &auxiliary_params,
             &teacher,
             &student,
-            &teacher_train_file,
-            &student_train_file,
+            &distill_file,
             &mut validation_file,
         )
         .context("Model distillation failed")
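Taken together, the distill.rs changes replace the two per-tokenizer pipelines, which read the same training file twice and had to stay aligned by construction, with a single paired pipeline. Below is a minimal sketch of the new data flow, using a concrete `PlainTextDataSet` for illustration; `distill_pass` is hypothetical, and the real method on `DistillApp` additionally chunks batches by `eval_steps` and delegates to `train_steps`:

```rust
use std::io::{BufRead, Seek};

use syntaxdot::dataset::{
    PairedBatchedTensors, PairedDataSet, PlainTextDataSet, SentenceIterTools, SequenceLength,
};
use syntaxdot::error::SyntaxDotError;
use syntaxdot_tokenizers::Tokenize;

// Hypothetical free-standing version of the new loop body; not part of
// the commit itself.
fn distill_pass<R>(
    mut dataset: PlainTextDataSet<R>,
    teacher_tokenizer: &dyn Tokenize,
    student_tokenizer: &dyn Tokenize,
    max_len: SequenceLength,
    batch_size: usize,
) -> Result<(), SyntaxDotError>
where
    R: BufRead + Seek,
{
    // One file, one pass: both tokenizations come from the same sentence
    // stream, so teacher and student batches cannot drift apart.
    let batches = dataset
        .tokenize_pair(teacher_tokenizer, student_tokenizer)?
        .filter_by_len(max_len)
        .paired_batched_tensors(None, None, batch_size);

    for batch in batches {
        let (teacher_batch, student_batch) = batch?;
        // ... forward both models on the pair and compute the
        // distillation loss here ...
        let _ = (teacher_batch, student_batch);
    }

    Ok(())
}
```

Pairing at the dataset level turns a cross-iterator invariant into a type-level guarantee: a length mismatch between the teacher and student streams is no longer representable.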
2 changes: 1 addition & 1 deletion syntaxdot-cli/src/subcommands/finetune.rs
@@ -267,7 +267,7 @@ impl FinetuneApp {
         let mut encoder_loss = BTreeMap::new();

         for batch in dataset
-            .sentences(tokenizer)?
+            .tokenize(tokenizer)?
             .filter_by_len(self.max_len)
             .batched_tensors(biaffine_encoder, Some(encoders), self.batch_size)
         {
2 changes: 1 addition & 1 deletion syntaxdot/src/dataset/conll.rs
@@ -22,7 +22,7 @@ where
 {
     type Iter = ConllIter<'a, Reader<&'a mut R>>;

-    fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError> {
+    fn tokenize(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError> {
         // Rewind to the beginning of the dataset (if necessary).
         self.0.seek(SeekFrom::Start(0))?;

24 changes: 21 additions & 3 deletions syntaxdot/src/dataset/mod.rs
@@ -11,7 +11,7 @@ mod plaintext;
 pub use plaintext::PlainTextDataSet;

 pub(crate) mod tensor_iter;
-pub use tensor_iter::BatchedTensors;
+pub use tensor_iter::{BatchedTensors, PairedBatchedTensors};

 mod sentence_itertools;
 pub use sentence_itertools::{SentenceIterTools, SequenceLength};
@@ -26,7 +26,25 @@ pub trait DataSet<'a> {
     /// Get an iterator over the sentences and pieces in a dataset.
     ///
     /// The tokens are split in pieces with the given `tokenizer`.
-    fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError>;
+    fn tokenize(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError>;
 }
+
+/// A data set consisting of annotated or unannotated sentences.
+///
+/// A `PairedDataSet` provides an iterator over pairs of tokenized
+/// sentences. The pairing is useful in applications where two tokenizer
+/// models are used, such as distillation.
+pub trait PairedDataSet<'a> {
+    type Iter: Iterator<Item = Result<(SentenceWithPieces, SentenceWithPieces), SyntaxDotError>>;
+
+    /// Get an iterator over the sentences and pieces in a dataset.
+    ///
+    /// The tokens are split in pieces with the given tokenizers.
+    fn tokenize_pair(
+        self,
+        tokenizer1: &'a dyn Tokenize,
+        tokenizer2: &'a dyn Tokenize,
+    ) -> Result<Self::Iter, SyntaxDotError>;
+}

 #[cfg(test)]
@@ -70,7 +88,7 @@ nu"#;
        I: Iterator<Item = Result<SentenceWithPieces, SyntaxDotError>>,
    {
        dataset
-            .sentences(tokenizer)?
+            .tokenize(tokenizer)?
            .map(|s| s.map(|s| s.pieces))
            .collect::<Result<Vec<_>, _>>()
    }
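The trait added above mirrors `DataSet`, but each item carries both tokenizations of one sentence. A hedged consumer sketch in the spirit of the crate's test helper shown in the last hunk; `paired_sentences` is hypothetical and not part of the commit:

```rust
use syntaxdot::dataset::PairedDataSet;
use syntaxdot::error::SyntaxDotError;
use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize};

// Materialize the paired tokenizations so both views of each sentence
// can be inspected side by side.
fn paired_sentences<'a, D>(
    dataset: D,
    tokenizer1: &'a dyn Tokenize,
    tokenizer2: &'a dyn Tokenize,
) -> Result<Vec<(SentenceWithPieces, SentenceWithPieces)>, SyntaxDotError>
where
    D: PairedDataSet<'a>,
{
    // Each iterator item is fallible; collecting into Result<Vec<_>, _>
    // stops at the first error.
    dataset.tokenize_pair(tokenizer1, tokenizer2)?.collect()
}
```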
80 changes: 72 additions & 8 deletions syntaxdot/src/dataset/plaintext.rs
@@ -1,10 +1,10 @@
-use std::io::{BufRead, Lines, Seek, SeekFrom};
+use std::io::{BufRead, Lines, Seek};

 use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize};
 use udgraph::graph::Sentence;
 use udgraph::token::Token;

-use crate::dataset::DataSet;
+use crate::dataset::{DataSet, PairedDataSet};
 use crate::error::SyntaxDotError;

 /// A CoNLL-X data set.
@@ -21,25 +21,46 @@ impl<'a, R> DataSet<'a> for &'a mut PlainTextDataSet<R>
 where
     R: BufRead + Seek,
 {
-    type Iter = PlainTextIter<'a, &'a mut R>;
+    type Iter = TokenizeIter<'a, &'a mut R>;

-    fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError> {
+    fn tokenize(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError> {
         // Rewind to the beginning of the dataset (if necessary).
-        self.0.seek(SeekFrom::Start(0))?;
+        self.0.rewind()?;

-        Ok(PlainTextIter {
+        Ok(TokenizeIter {
             lines: (&mut self.0).lines(),
             tokenizer,
         })
     }
 }

-pub struct PlainTextIter<'a, R> {
+impl<'a, R> PairedDataSet<'a> for &'a mut PlainTextDataSet<R>
+where
+    R: BufRead + Seek,
+{
+    type Iter = TokenizePairIter<'a, &'a mut R>;
+
+    fn tokenize_pair(
+        self,
+        tokenizer1: &'a dyn Tokenize,
+        tokenizer2: &'a dyn Tokenize,
+    ) -> Result<Self::Iter, SyntaxDotError> {
+        self.0.rewind()?;
+
+        Ok(TokenizePairIter {
+            lines: (&mut self.0).lines(),
+            tokenizer1,
+            tokenizer2,
+        })
+    }
+}
+
+pub struct TokenizeIter<'a, R> {
     lines: Lines<R>,
     tokenizer: &'a dyn Tokenize,
 }

-impl<'a, R> Iterator for PlainTextIter<'a, R>
+impl<'a, R> Iterator for TokenizeIter<'a, R>
 where
     R: BufRead,
 {
@@ -73,6 +94,49 @@
     }
 }

+pub struct TokenizePairIter<'a, R> {
+    lines: Lines<R>,
+    tokenizer1: &'a dyn Tokenize,
+    tokenizer2: &'a dyn Tokenize,
+}
+
+impl<'a, R> Iterator for TokenizePairIter<'a, R>
+where
+    R: BufRead,
+{
+    type Item = Result<(SentenceWithPieces, SentenceWithPieces), SyntaxDotError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        for line in &mut self.lines {
+            // Bubble up read errors.
+            let line = match line {
+                Ok(line) => line,
+                Err(err) => return Some(Err(SyntaxDotError::IoError(err))),
+            };
+
+            let line_trimmed = line.trim();
+
+            // Skip empty lines
+            if line_trimmed.is_empty() {
+                continue;
+            }
+
+            let sentence = line_trimmed
+                .split_terminator(' ')
+                .map(ToString::to_string)
+                .map(Token::new)
+                .collect::<Sentence>();
+
+            return Some(Ok((
+                self.tokenizer1.tokenize(sentence.clone()),
+                self.tokenizer2.tokenize(sentence),
+            )));
+        }
+
+        None
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::{BufReader, Cursor};
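Note the `sentence.clone()` in `TokenizePairIter::next`: `Tokenize::tokenize` consumes its `Sentence`, so the iterator clones each sentence once, handing the clone to `tokenizer1` and the original to `tokenizer2`. A minimal usage sketch under two assumptions: `PlainTextDataSet::new` is hypothetical (no constructor appears in this diff), and `SentenceWithPieces` exposes `pieces` as in the mod.rs test helper:

```rust
use std::io::{BufReader, Cursor};

use syntaxdot::dataset::{PairedDataSet, PlainTextDataSet};
use syntaxdot::error::SyntaxDotError;
use syntaxdot_tokenizers::Tokenize;

fn show_piece_counts(
    tokenizer1: &dyn Tokenize,
    tokenizer2: &dyn Tokenize,
) -> Result<(), SyntaxDotError> {
    // One sentence per line, tokens separated by spaces.
    let data = "Dit is een zin .\nNog een zin .\n";
    let mut dataset = PlainTextDataSet::new(BufReader::new(Cursor::new(data)));

    for pair in (&mut dataset).tokenize_pair(tokenizer1, tokenizer2)? {
        let (sent1, sent2) = pair?;
        // The two piece sequences may differ in length, but both were
        // built from one clone of the same Sentence.
        println!("{} vs {} pieces", sent1.pieces.len(), sent2.pieces.len());
    }

    Ok(())
}
```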
[Diffs for the remaining 2 changed files are not shown.]
