Rename some types and methods
danieldk committed Oct 22, 2023
1 parent eaf5cc4 commit 28e1783
Showing 5 changed files with 36 additions and 30 deletions.
10 changes: 5 additions & 5 deletions syntaxdot-cli/src/subcommands/distill.rs
@@ -12,8 +12,8 @@ use itertools::Itertools;
 use ordered_float::NotNan;
 use syntaxdot::config::{Config, PretrainConfig};
 use syntaxdot::dataset::{
-    BatchedTensors, ConlluDataSet, DataSet, PairedBatchedTensors, PairedDataSet, PlainTextDataSet,
-    SentenceIterTools, SequenceLength,
+    BatchedTensors, ConlluDataSet, DataSet, MaxSentenceLen, PairedBatchedTensors, PairedDataSet,
+    PlainTextDataSet, SentenceIterTools,
 };
 use syntaxdot::encoders::Encoders;
 use syntaxdot::error::SyntaxDotError;
@@ -108,7 +108,7 @@ pub struct DistillApp {
     eval_steps: usize,
     hidden_loss: Option<Vec<(usize, usize)>>,
     keep_best_steps: Option<usize>,
-    max_len: SequenceLength,
+    max_len: MaxSentenceLen,
     mixed_precision: bool,
     lr_schedules: RefCell<LearningRateSchedules>,
     student_config: String,
@@ -1178,8 +1178,8 @@ impl SyntaxDotApp for DistillApp {
             .get_one::<String>(MAX_LEN)
             .map(|v| v.parse().context("Cannot parse maximum sentence length"))
             .transpose()?
-            .map(SequenceLength::Pieces)
-            .unwrap_or(SequenceLength::Unbounded);
+            .map(MaxSentenceLen::Pieces)
+            .unwrap_or(MaxSentenceLen::Unbounded);
         let mixed_precision = matches.get_flag(MIXED_PRECISION);
         let warmup_steps = matches
             .get_one::<String>(WARMUP)
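
For context, the chain in the last hunk turns the optional MAX_LEN argument into the renamed enum. A minimal standalone sketch of the same pattern, assuming a hypothetical `arg: Option<&str>` (not part of this commit):

    use syntaxdot::dataset::MaxSentenceLen;

    // None -> MaxSentenceLen::Unbounded; Some("128") -> MaxSentenceLen::Pieces(128).
    fn parse_max_len(arg: Option<&str>) -> Result<MaxSentenceLen, std::num::ParseIntError> {
        Ok(arg
            .map(|v| v.parse::<usize>()) // parse the raw CLI string
            .transpose()?                // surface a parse error, keep the Option
            .map(MaxSentenceLen::Pieces) // interpret the number as a piece limit
            .unwrap_or(MaxSentenceLen::Unbounded))
    }
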
8 changes: 4 additions & 4 deletions syntaxdot-cli/src/subcommands/finetune.rs
@@ -8,7 +8,7 @@ use clap::{Arg, ArgAction, ArgMatches, Command};
 use indicatif::ProgressStyle;
 use ordered_float::NotNan;
 use syntaxdot::dataset::{
-    BatchedTensors, ConlluDataSet, DataSet, SentenceIterTools, SequenceLength,
+    BatchedTensors, ConlluDataSet, DataSet, MaxSentenceLen, SentenceIterTools,
 };
 use syntaxdot::encoders::Encoders;
 use syntaxdot::lr::{ExponentialDecay, LearningRateSchedule, PlateauLearningRate};
@@ -62,7 +62,7 @@ pub struct FinetuneApp {
     continue_finetune: bool,
     device: Device,
     finetune_embeds: bool,
-    max_len: SequenceLength,
+    max_len: MaxSentenceLen,
     label_smoothing: Option<f64>,
     mixed_precision: bool,
     summary_writer: Box<dyn ScalarWriter>,
@@ -606,8 +606,8 @@ impl SyntaxDotApp for FinetuneApp {
                     .context(format!("Cannot parse maximum sentence length: {}", v))
             })
             .transpose()?
-            .map(SequenceLength::Pieces)
-            .unwrap_or(SequenceLength::Unbounded);
+            .map(MaxSentenceLen::Pieces)
+            .unwrap_or(MaxSentenceLen::Unbounded);

         let keep_best_epochs = matches
             .get_one::<String>(KEEP_BEST_EPOCHS)
2 changes: 1 addition & 1 deletion syntaxdot/src/dataset/mod.rs
@@ -14,7 +14,7 @@ pub(crate) mod tensor_iter;
 pub use tensor_iter::{BatchedTensors, PairedBatchedTensors};

 mod sentence_itertools;
-pub use sentence_itertools::{SentenceIterTools, SequenceLength};
+pub use sentence_itertools::{MaxSentenceLen, SentenceIterTools};

 /// A data set consisting of annotated or unannotated sentences.
 ///
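
With this re-export, downstream code imports the renamed type from `syntaxdot::dataset` rather than from the private module, as the CLI subcommands in this commit do:

    use syntaxdot::dataset::{MaxSentenceLen, SentenceIterTools};
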
44 changes: 24 additions & 20 deletions syntaxdot/src/dataset/sentence_itertools.rs
@@ -5,12 +5,12 @@ use syntaxdot_tokenizers::SentenceWithPieces;
 use crate::error::SyntaxDotError;
 use crate::util::RandomRemoveVec;

-/// The length of a sequence.
+/// The maximum length of a sentence.
 ///
 /// This enum can be used to express the (maximum) length of a
 /// sentence in tokens or in pieces.
 #[derive(Debug, Clone, Copy, Eq, PartialEq)]
-pub enum SequenceLength {
+pub enum MaxSentenceLen {
     Tokens(usize),
     Pieces(usize),
     Unbounded,
@@ -25,7 +25,7 @@ where
     ///
     /// If `max_len` is `None`, then the sentences will not be
     /// filtered by length.
-    fn filter_by_len(self, max_len: SequenceLength) -> LengthFilter<Self>;
+    fn filter_by_len(self, max_len: MaxSentenceLen) -> LengthFilter<Self>;

     /// Shuffle sentences.
     ///
@@ -39,7 +39,7 @@ impl<'a, I> SentenceIterTools<'a> for I
 where
     I: 'a + Iterator,
 {
-    fn filter_by_len(self, max_len: SequenceLength) -> LengthFilter<Self> {
+    fn filter_by_len(self, max_len: MaxSentenceLen) -> LengthFilter<Self> {
         LengthFilter {
             inner: self,
             max_len,
@@ -55,41 +55,45 @@ where
     }
 }

-trait SentenceLength {
-    fn pieces_length(&self) -> usize;
-    fn tokens_length(&self) -> usize;
+/// Get the length of a tokenized sentence.
+trait SentenceLen {
+    /// Get the length of the sentence in pieces.
+    fn pieces_len(&self) -> usize;
+
+    /// Get the length of the sentence in tokens.
+    fn tokens_len(&self) -> usize;
 }

-impl SentenceLength for SentenceWithPieces {
-    fn pieces_length(&self) -> usize {
+impl SentenceLen for SentenceWithPieces {
+    fn pieces_len(&self) -> usize {
         self.pieces.len()
     }

-    fn tokens_length(&self) -> usize {
+    fn tokens_len(&self) -> usize {
         self.token_offsets.len()
     }
 }

-impl SentenceLength for (SentenceWithPieces, SentenceWithPieces) {
-    fn pieces_length(&self) -> usize {
+impl SentenceLen for (SentenceWithPieces, SentenceWithPieces) {
+    fn pieces_len(&self) -> usize {
         self.0.pieces.len().max(self.1.pieces.len())
     }

-    fn tokens_length(&self) -> usize {
+    fn tokens_len(&self) -> usize {
         self.0.token_offsets.len().max(self.1.token_offsets.len())
     }
 }

 /// An Iterator adapter filtering sentences by maximum length.
 pub struct LengthFilter<I> {
     inner: I,
-    max_len: SequenceLength,
+    max_len: MaxSentenceLen,
 }

 impl<I, S> Iterator for LengthFilter<I>
 where
     I: Iterator<Item = Result<S, SyntaxDotError>>,
-    S: SentenceLength,
+    S: SentenceLen,
 {
     type Item = Result<S, SyntaxDotError>;

@@ -98,13 +102,13 @@ where
         // Treat Err as length 0 to keep our type as Result<Sentence, Error>. The iterator
         // will properly return the Error at a later point.
         let too_long = match self.max_len {
-            SequenceLength::Pieces(max_len) => {
-                sent.as_ref().map(|s| s.pieces_length()).unwrap_or(0) > max_len
+            MaxSentenceLen::Pieces(max_len) => {
+                sent.as_ref().map(|s| s.pieces_len()).unwrap_or(0) > max_len
             }
-            SequenceLength::Tokens(max_len) => {
-                sent.as_ref().map(|s| s.tokens_length()).unwrap_or(0) > max_len
+            MaxSentenceLen::Tokens(max_len) => {
+                sent.as_ref().map(|s| s.tokens_len()).unwrap_or(0) > max_len
             }
-            SequenceLength::Unbounded => false,
+            MaxSentenceLen::Unbounded => false,
         };

         if too_long {
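
Putting the renamed pieces together, a minimal usage sketch (a sketch only, not part of this commit; it assumes the caller has some iterator over `Result<SentenceWithPieces, SyntaxDotError>`, e.g. from one of the data sets):

    use syntaxdot::dataset::{MaxSentenceLen, SentenceIterTools};
    use syntaxdot::error::SyntaxDotError;
    use syntaxdot_tokenizers::SentenceWithPieces;

    // Keep only sentences of at most 128 pieces. `Err` items pass through
    // unchanged, so the consumer still observes tokenization/read errors.
    fn keep_short<I>(
        sentences: I,
    ) -> impl Iterator<Item = Result<SentenceWithPieces, SyntaxDotError>>
    where
        I: Iterator<Item = Result<SentenceWithPieces, SyntaxDotError>>,
    {
        sentences.filter_by_len(MaxSentenceLen::Pieces(128))
    }
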
2 changes: 2 additions & 0 deletions syntaxdot/src/dataset/tensor_iter.rs
@@ -89,6 +89,8 @@ where
     }
 }

+/// An iterator returning input and (optionally) output tensors for
+/// pairs of tokenized sentences.
 pub trait PairedBatchedTensors<'a> {
     /// Get an iterator over batch tensors.
     ///
