Skip to content

Commit

Permalink
Add Opts::setup and OptsBuilder::setup
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Apr 14, 2023
1 parent 7c6572d commit 73dbb96
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
37 changes: 37 additions & 0 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,16 @@ impl Conn {
Ok(())
}

async fn run_setup_commands(&mut self) -> Result<()> {
let mut setup = self.inner.opts.setup().to_vec();

while let Some(query) = setup.pop() {
self.query_drop(query).await?;
}

Ok(())
}

/// Returns a future that resolves to [`Conn`].
pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
let opts = opts.into();
Expand Down Expand Up @@ -913,6 +923,7 @@ impl Conn {
conn.read_max_allowed_packet().await?;
conn.read_wait_timeout().await?;
conn.run_init_commands().await?;
conn.run_setup_commands().await?;

Ok(conn)
}
Expand Down Expand Up @@ -1011,6 +1022,7 @@ impl Conn {
self.routine(routines::ResetRoutine).await?;
self.inner.stmt_cache.clear();
self.inner.infile_handler = None;
self.run_setup_commands().await?;
}

Ok(supports_com_reset_connection)
Expand Down Expand Up @@ -1052,6 +1064,7 @@ impl Conn {
self.routine(routines::ChangeUser).await?;
self.inner.stmt_cache.clear();
self.inner.infile_handler = None;
self.run_setup_commands().await?;
Ok(())
}

Expand Down Expand Up @@ -1548,6 +1561,30 @@ mod test {
Ok(())
}

#[tokio::test]
async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
let mut conn = Conn::new(opts).await?;

// initial run
let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
assert_eq!(result, vec![(42, "foo".into())]);

// after reset
if conn.reset().await? {
result = conn.query("SELECT @a, @b").await?;
assert_eq!(result, vec![(42, "foo".into())]);
}

// after change user
conn.change_user(Default::default()).await?;
result = conn.query("SELECT @a, @b").await?;
assert_eq!(result, vec![(42, "foo".into())]);

conn.disconnect().await?;
Ok(())
}

#[tokio::test]
async fn should_reset_the_connection() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
Expand Down
22 changes: 20 additions & 2 deletions src/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,13 @@ pub(crate) struct MysqlOpts {
/// (defaults to `wait_timeout`).
conn_ttl: Option<Duration>,

/// Commands to execute on each new database connection.
/// Commands to execute once new connection is established.
init: Vec<String>,

/// Commands to execute on new connection and every time
/// [`Conn::reset`] or [`Conn::change_user`] is invoked.
setup: Vec<String>,

/// Number of prepared statements cached on the client side (per connection). Defaults to `10`.
stmt_cache_size: usize,

Expand Down Expand Up @@ -577,11 +581,17 @@ impl Opts {
self.inner.mysql_opts.db_name.as_ref().map(AsRef::as_ref)
}

/// Commands to execute on each new database connection.
/// Commands to execute once new connection is established.
pub fn init(&self) -> &[String] {
self.inner.mysql_opts.init.as_ref()
}

/// Commands to execute on new connection and every time
/// [`Conn::reset`] or [`Conn::change_user`] is invoked.
pub fn setup(&self) -> &[String] {
self.inner.mysql_opts.setup.as_ref()
}

/// TCP keep alive timeout in milliseconds (defaults to `None`).
///
/// # Connection URL
Expand Down Expand Up @@ -871,6 +881,7 @@ impl Default for MysqlOpts {
pass: None,
db_name: None,
init: vec![],
setup: vec![],
tcp_keepalive: None,
tcp_nodelay: true,
local_infile_handler: None,
Expand Down Expand Up @@ -1037,6 +1048,12 @@ impl OptsBuilder {
self
}

/// Defines setup queries. See [`Opts::setup`].
pub fn setup<T: Into<String>>(mut self, setup: Vec<T>) -> Self {
self.opts.setup = setup.into_iter().map(Into::into).collect();
self
}

/// Defines `tcp_keepalive` option. See [`Opts::tcp_keepalive`].
pub fn tcp_keepalive<T: Into<u32>>(mut self, tcp_keepalive: Option<T>) -> Self {
self.opts.tcp_keepalive = tcp_keepalive.map(Into::into);
Expand Down Expand Up @@ -1654,6 +1671,7 @@ mod test {
assert_eq!(url_opts.pass(), builder_opts.pass());
assert_eq!(url_opts.db_name(), builder_opts.db_name());
assert_eq!(url_opts.init(), builder_opts.init());
assert_eq!(url_opts.setup(), builder_opts.setup());
assert_eq!(url_opts.tcp_keepalive(), builder_opts.tcp_keepalive());
assert_eq!(url_opts.tcp_nodelay(), builder_opts.tcp_nodelay());
assert_eq!(url_opts.pool_opts(), builder_opts.pool_opts());
Expand Down

0 comments on commit 73dbb96

Please sign in to comment.