Skip to content

Commit

Permalink
add Iterable run method (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
superstator authored Jan 3, 2024
1 parent a7c81dc commit 17162c1
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 109 deletions.
7 changes: 7 additions & 0 deletions examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,12 @@ mod embedded {

fn main() {
let mut conn = Connection::open_in_memory().unwrap();

// run all migrations in one go
embedded::migrations::runner().run(&mut conn).unwrap();

// or create an iterator over migrations as they run
for migration in embedded::migrations::runner().run_iter(&mut conn) {
info!("Got a migration: {}", migration.expect("migration failed!"));
}
}
163 changes: 163 additions & 0 deletions refinery/tests/rusqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,34 @@ mod rusqlite {
assert_eq!(migrations[3].checksum(), applied_migrations[3].checksum());
}

#[test]
fn report_contains_applied_migrations_iter() {
let mut conn = Connection::open_in_memory().unwrap();
let applied_migrations = embedded::migrations::runner()
.run_iter(&mut conn)
.collect::<Result<Vec<_>, _>>()
.unwrap();

let migrations = get_migrations();

assert_eq!(4, applied_migrations.len());

assert_eq!(migrations[0].version(), applied_migrations[0].version());
assert_eq!(migrations[1].version(), applied_migrations[1].version());
assert_eq!(migrations[2].version(), applied_migrations[2].version());
assert_eq!(migrations[3].version(), applied_migrations[3].version());

assert_eq!(migrations[0].name(), migrations[0].name());
assert_eq!(migrations[1].name(), applied_migrations[1].name());
assert_eq!(migrations[2].name(), applied_migrations[2].name());
assert_eq!(migrations[3].name(), applied_migrations[3].name());

assert_eq!(migrations[0].checksum(), applied_migrations[0].checksum());
assert_eq!(migrations[1].checksum(), applied_migrations[1].checksum());
assert_eq!(migrations[2].checksum(), applied_migrations[2].checksum());
assert_eq!(migrations[3].checksum(), applied_migrations[3].checksum());
}

#[test]
fn creates_migration_table() {
let mut conn = Connection::open_in_memory().unwrap();
Expand All @@ -123,6 +151,26 @@ mod rusqlite {
assert_eq!(DEFAULT_TABLE_NAME, table_name);
}

#[test]
fn creates_migration_table_iter() {
let mut conn = Connection::open_in_memory().unwrap();
embedded::migrations::runner()
.run_iter(&mut conn)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let table_name: String = conn
.query_row(
&format!(
"SELECT name FROM sqlite_master WHERE type='table' AND name='{}'",
DEFAULT_TABLE_NAME
),
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(DEFAULT_TABLE_NAME, table_name);
}

#[test]
fn creates_migration_table_grouped_transaction() {
let mut conn = Connection::open_in_memory().unwrap();
Expand Down Expand Up @@ -163,6 +211,29 @@ mod rusqlite {
assert_eq!("New York", city);
}

#[test]
fn applies_migration_iter() {
let mut conn = Connection::open_in_memory().unwrap();

embedded::migrations::runner()
.run_iter(&mut conn)
.collect::<Result<Vec<_>, _>>()
.unwrap();

conn.execute(
"INSERT INTO persons (name, city) VALUES (?, ?)",
&["John Legend", "New York"],
)
.unwrap();
let (name, city): (String, String) = conn
.query_row("SELECT name, city FROM persons", [], |row| {
Ok((row.get(0).unwrap(), row.get(1).unwrap()))
})
.unwrap();
assert_eq!("John Legend", name);
assert_eq!("New York", city);
}

#[test]
fn applies_migration_grouped_transaction() {
let mut conn = Connection::open_in_memory().unwrap();
Expand Down Expand Up @@ -205,6 +276,28 @@ mod rusqlite {
);
}

#[test]
fn updates_schema_history_iter() {
let mut conn = Connection::open_in_memory().unwrap();

embedded::migrations::runner()
.run_iter(&mut conn)
.collect::<Result<Vec<_>, _>>()
.unwrap();

let current = conn
.get_last_applied_migration(DEFAULT_TABLE_NAME)
.unwrap()
.unwrap();

assert_eq!(4, current.version());

assert_eq!(
OffsetDateTime::now_utc().date(),
current.applied_on().unwrap().date()
);
}

#[test]
fn updates_schema_history_grouped_transaction() {
let mut conn = Connection::open_in_memory().unwrap();
Expand Down Expand Up @@ -259,6 +352,42 @@ mod rusqlite {
assert_eq!(2959965718684201605, applied_migrations[0].checksum());
assert_eq!(8238603820526370208, applied_migrations[1].checksum());
}
#[test]

fn updates_to_last_working_if_iter() {
let mut conn = Connection::open_in_memory().unwrap();

let result: Result<Vec<_>, _> = broken::migrations::runner().run_iter(&mut conn).collect();

assert!(result.is_err());
let current = conn
.get_last_applied_migration(DEFAULT_TABLE_NAME)
.unwrap()
.unwrap();

let err = result.unwrap_err();
let migrations = get_migrations();
let applied_migrations = broken::migrations::runner()
.get_applied_migrations(&mut conn)
.unwrap();

assert_eq!(
OffsetDateTime::now_utc().date(),
current.applied_on().unwrap().date()
);
assert_eq!(2, current.version());
assert!(err.report().unwrap().applied_migrations().is_empty());
assert_eq!(2, applied_migrations.len());

assert_eq!(1, applied_migrations[0].version());
assert_eq!(2, applied_migrations[1].version());

assert_eq!("initial", migrations[0].name());
assert_eq!("add_cars_table", applied_migrations[1].name());

assert_eq!(2959965718684201605, applied_migrations[0].checksum());
assert_eq!(8238603820526370208, applied_migrations[1].checksum());
}

#[test]
fn doesnt_update_to_last_working_if_grouped() {
Expand Down Expand Up @@ -366,6 +495,40 @@ mod rusqlite {
assert_eq!(migrations[2].checksum(), applied_migrations[2].checksum());
}

#[test]
fn migrates_to_target_migration_iter() {
let mut conn = Connection::open_in_memory().unwrap();

let applied_migrations = embedded::migrations::runner()
.set_target(Target::Version(3))
.run_iter(&mut conn)
.collect::<Result<Vec<_>, _>>()
.unwrap();

let current = conn
.get_last_applied_migration(DEFAULT_TABLE_NAME)
.unwrap()
.unwrap();

let migrations = get_migrations();

assert_eq!(3, current.version());

assert_eq!(3, applied_migrations.len());

assert_eq!(migrations[0].version(), applied_migrations[0].version());
assert_eq!(migrations[1].version(), applied_migrations[1].version());
assert_eq!(migrations[2].version(), applied_migrations[2].version());

assert_eq!(migrations[0].name(), migrations[0].name());
assert_eq!(migrations[1].name(), applied_migrations[1].name());
assert_eq!(migrations[2].name(), applied_migrations[2].name());

assert_eq!(migrations[0].checksum(), applied_migrations[0].checksum());
assert_eq!(migrations[1].checksum(), applied_migrations[1].checksum());
assert_eq!(migrations[2].checksum(), applied_migrations[2].checksum());
}

#[test]
fn migrates_to_target_migration_grouped() {
let mut conn = Connection::open_in_memory().unwrap();
Expand Down
84 changes: 79 additions & 5 deletions refinery_core/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ use regex::Regex;
use siphasher::sip::SipHasher13;
use time::OffsetDateTime;

use log::error;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;

use crate::error::Kind;
use crate::traits::DEFAULT_MIGRATION_TABLE_NAME;
use crate::traits::{sync::migrate as sync_migrate, DEFAULT_MIGRATION_TABLE_NAME};
use crate::{AsyncMigrate, Error, Migrate};
use std::fmt::Formatter;

Expand Down Expand Up @@ -364,13 +366,26 @@ impl Runner {
self
}

/// Creates an iterator over pending migrations, applying each before returning
/// the result from `next()`. If a migration fails, the iterator will return that
/// result and further calls to `next()` will return `None`.
pub fn run_iter<C>(
self,
connection: &mut C,
) -> impl Iterator<Item = Result<Migration, Error>> + '_
where
C: Migrate,
{
RunIterator::new(self, connection)
}

/// Runs the Migrations in the supplied database connection
pub fn run<C>(&self, conn: &'_ mut C) -> Result<Report, Error>
pub fn run<C>(&self, connection: &mut C) -> Result<Report, Error>
where
C: Migrate,
{
Migrate::migrate(
conn,
connection,
&self.migrations,
self.abort_divergent,
self.abort_missing,
Expand All @@ -381,12 +396,12 @@ impl Runner {
}

/// Runs the Migrations asynchronously in the supplied database connection
pub async fn run_async<C>(&self, conn: &mut C) -> Result<Report, Error>
pub async fn run_async<C>(&self, connection: &mut C) -> Result<Report, Error>
where
C: AsyncMigrate + Send,
{
AsyncMigrate::migrate(
conn,
connection,
&self.migrations,
self.abort_divergent,
self.abort_missing,
Expand All @@ -397,3 +412,62 @@ impl Runner {
.await
}
}

pub struct RunIterator<'a, C> {
connection: &'a mut C,
target: Target,
migration_table_name: String,
items: VecDeque<Migration>,
failed: bool,
}
impl<'a, C> RunIterator<'a, C>
where
C: Migrate,
{
pub(crate) fn new(runner: Runner, connection: &'a mut C) -> RunIterator<'a, C> {
RunIterator {
items: VecDeque::from(
Migrate::get_unapplied_migrations(
connection,
&runner.migrations,
runner.abort_divergent,
runner.abort_missing,
&runner.migration_table_name,
)
.unwrap(),
),
connection,
target: runner.target,
migration_table_name: runner.migration_table_name.clone(),
failed: false,
}
}
}
impl<C> Iterator for RunIterator<'_, C>
where
C: Migrate,
{
type Item = Result<Migration, Error>;

fn next(&mut self) -> Option<Self::Item> {
match self.failed {
true => None,
false => self.items.pop_front().and_then(|migration| {
sync_migrate(
self.connection,
vec![migration],
self.target,
&self.migration_table_name,
false,
)
.map(|r| r.applied_migrations.first().cloned())
.map_err(|e| {
error!("migration failed: {e:?}");
self.failed = true;
e
})
.transpose()
}),
}
}
}
27 changes: 5 additions & 22 deletions refinery_core/src/traits/async.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use crate::error::WrapMigrationError;
use crate::traits::{
verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY,
GET_LAST_APPLIED_MIGRATION_QUERY,
insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
};
use crate::{Error, Migration, Report, Target};

use async_trait::async_trait;
use std::string::ToString;
use time::format_description::well_known::Rfc3339;

#[async_trait]
pub trait AsyncTransaction {
Expand Down Expand Up @@ -42,19 +41,11 @@ async fn migrate<T: AsyncTransaction>(

log::info!("applying migration: {}", migration);
migration.set_applied();
let update_query = &format!(
"INSERT INTO {} (version, name, applied_on, checksum) VALUES ({}, '{}', '{}', '{}')",
// safe to call unwrap as we just converted it to applied, and we are sure it can be formatted according to RFC 33339
migration_table_name,
migration.version(),
migration.name(),
migration.applied_on().unwrap().format(&Rfc3339).unwrap(),
migration.checksum()
);
let update_query = insert_migration_query(&migration, migration_table_name);
transaction
.execute(&[
migration.sql().as_ref().expect("sql must be Some!"),
update_query,
&update_query,
])
.await
.migration_err(
Expand Down Expand Up @@ -83,15 +74,7 @@ async fn migrate_grouped<T: AsyncTransaction>(
}

migration.set_applied();
let query = format!(
"INSERT INTO {} (version, name, applied_on, checksum) VALUES ({}, '{}', '{}', '{}')",
// safe to call unwrap as we just converted it to applied, and we are sure it can be formatted according to RFC 33339
migration_table_name,
migration.version(),
migration.name(),
migration.applied_on().unwrap().format(&Rfc3339).unwrap(),
migration.checksum()
);
let query = insert_migration_query(&migration, migration_table_name);

let sql = migration.sql().expect("sql must be Some!").to_string();

Expand Down
Loading

0 comments on commit 17162c1

Please sign in to comment.