Skip to content

Commit

Permalink
refactor: opaque Error
Browse files Browse the repository at this point in the history
Shaves off the `thiserror` dependency and should improve compile times slightly.
Unfortunately this does mean we can't match on `Error` anymore, though I'm not sure if that was ever useful to begin with.
  • Loading branch information
decahedron1 committed Aug 31, 2024
1 parent 18feafe commit 2d26f1f
Show file tree
Hide file tree
Showing 44 changed files with 424 additions and 695 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ qnn = [ "ort-sys/qnn" ]

[dependencies]
ndarray = { version = "0.16", optional = true }
thiserror = "1.0"
ort-sys = { version = "2.0.0-rc.5", path = "ort-sys" }
libloading = { version = "0.8", optional = true }

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis, Ix2};
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session};
use ort::{CUDAExecutionProvider, Error, GraphOptimizationLevel, Session};
use tokenizers::Tokenizer;

/// Example usage of a text embedding model like Sentence Transformers' `all-mini-lm-l6` model for semantic textual
Expand Down Expand Up @@ -31,7 +31,7 @@ fn main() -> ort::Result<()> {
let inputs = vec!["The weather outside is lovely.", "It's so sunny outside!", "She drove to the stadium."];

// Encode our input strings. `encode_batch` will pad each input to be the same length.
let encodings = tokenizer.encode_batch(inputs.clone(), false)?;
let encodings = tokenizer.encode_batch(inputs.clone(), false).map_err(|e| Error::new(e.to_string()))?;

// Get the padded length of each encoding.
let padded_token_length = encodings[0].len();
Expand Down
24 changes: 10 additions & 14 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ use tracing::{debug, Level};

#[cfg(feature = "load-dynamic")]
use crate::G_ORT_DYLIB_PATH;
use crate::{
error::{Error, Result},
execution_providers::ExecutionProviderDispatch,
extern_system_fn, ortsys
};
use crate::{error::Result, execution_providers::ExecutionProviderDispatch, extern_system_fn, ortsys};

struct EnvironmentSingleton {
lock: RwLock<Option<Arc<Environment>>>
Expand Down Expand Up @@ -154,19 +150,19 @@ impl EnvironmentBuilder {
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());

let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut thread_options) -> Error::CreateEnvironment; nonNull(thread_options)];
ortsys![unsafe CreateThreadingOptions(&mut thread_options)?; nonNull(thread_options)];
if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism {
ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism) -> Error::CreateEnvironment];
ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism)?];
}
if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism {
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment];
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism)?];
}
if let Some(spin_control) = global_thread_pool.spin_control {
ortsys![unsafe SetGlobalSpinControl(thread_options, i32::from(spin_control)) -> Error::CreateEnvironment];
ortsys![unsafe SetGlobalSpinControl(thread_options, i32::from(spin_control))?];
}
if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity {
let cstr = CString::new(intra_op_thread_affinity).unwrap_or_else(|_| unreachable!());
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr())?];
}

ortsys![
Expand All @@ -177,7 +173,7 @@ impl EnvironmentBuilder {
cname.as_ptr(),
thread_options,
&mut env_ptr
) -> Error::CreateEnvironment;
)?;
nonNull(env_ptr)
];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
Expand All @@ -195,17 +191,17 @@ impl EnvironmentBuilder {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
) -> Error::CreateEnvironment;
)?;
nonNull(env_ptr)
];
(env_ptr, false)
};
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");

if self.telemetry {
ortsys![unsafe EnableTelemetryEvents(env_ptr) -> Error::CreateEnvironment];
ortsys![unsafe EnableTelemetryEvents(env_ptr)?];
} else {
ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment];
ortsys![unsafe DisableTelemetryEvents(env_ptr)?];
}

let mut env_lock = G_ENV.lock.write().expect("poisoned lock");
Expand Down
Loading

0 comments on commit 2d26f1f

Please sign in to comment.