oxide_sql_core/migrations/
migration.rs

//! Migration trait and runner.
//!
//! Provides the `Migration` trait that all migrations implement, and the
//! `MigrationRunner` that orders migrations by dependency and generates their SQL.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::dialect::MigrationDialect;
9use super::operation::Operation;
10use super::state::MigrationState;
11
/// A database migration with typed up/down operations.
///
/// Implement this trait for each migration in your application. A migration
/// is identified by [`Migration::ID`], may declare migrations it must run
/// after via [`Migration::DEPENDENCIES`], and describes its schema changes as
/// [`Operation`] values rather than raw SQL; a [`MigrationRunner`] turns those
/// operations into SQL for a concrete dialect.
///
/// The trait has no `self` methods, so implementors are typically empty
/// marker structs, as in the example below.
///
/// # Example
///
/// ```rust
/// use oxide_sql_core::migrations::{
///     Migration, Operation, CreateTableBuilder,
///     bigint, varchar, timestamp,
/// };
///
/// pub struct Migration0001;
///
/// impl Migration for Migration0001 {
///     const ID: &'static str = "0001_create_users";
///
///     fn up() -> Vec<Operation> {
///         vec![
///             CreateTableBuilder::new()
///                 .name("users")
///                 .column(bigint("id").primary_key().autoincrement().build())
///                 .column(varchar("username", 255).not_null().unique().build())
///                 .build()
///                 .into(),
///         ]
///     }
///
///     fn down() -> Vec<Operation> {
///         vec![
///             Operation::drop_table("users"),
///         ]
///     }
/// }
/// ```
pub trait Migration {
    /// Unique migration identifier (e.g., "0001_initial", "0002_add_email").
    ///
    /// This ID is stored in the migrations table to track which migrations
    /// have been applied. It is also the name other migrations use in their
    /// [`Migration::DEPENDENCIES`], so it must not collide with the ID of any
    /// other registered migration (the runner refuses to order duplicates).
    const ID: &'static str;

    /// Dependencies on other migrations (must run first).
    ///
    /// Each string should be the `ID` of another registered migration;
    /// `MigrationRunner::validate` reports an unknown ID as
    /// `MigrationError::MissingDependency`. Defaults to no dependencies.
    const DEPENDENCIES: &'static [&'static str] = &[];

    /// Apply the migration (forward).
    ///
    /// Returns the operations to execute, in the order they should run.
    fn up() -> Vec<Operation>;

    /// Reverse the migration (backward).
    ///
    /// Returns the operations that undo [`Migration::up`], in the order they
    /// should run. Return an empty vec if the migration is not reversible;
    /// the runner then refuses to generate rollback SQL for it and reports
    /// `MigrationError::NotReversible` instead.
    fn down() -> Vec<Operation>;
}
70
71/// A registered migration with runtime-accessible metadata.
72pub struct RegisteredMigration {
73    /// Migration ID.
74    pub id: &'static str,
75    /// Dependencies.
76    pub dependencies: &'static [&'static str],
77    /// Function to get up operations.
78    pub up: fn() -> Vec<Operation>,
79    /// Function to get down operations.
80    pub down: fn() -> Vec<Operation>,
81}
82
83impl RegisteredMigration {
84    /// Creates a new registered migration from a `Migration` implementor.
85    #[must_use]
86    pub const fn new<M: Migration>() -> Self {
87        Self {
88            id: M::ID,
89            dependencies: M::DEPENDENCIES,
90            up: M::up,
91            down: M::down,
92        }
93    }
94}
95
/// Applied/pending status of a single registered migration, as reported by
/// `MigrationRunner::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// The migration ID.
    pub id: &'static str,
    /// Whether the migration has been applied.
    pub applied: bool,
    /// When the migration was applied (if known).
    ///
    /// `MigrationRunner::status` currently always sets this to `None`; it only
    /// asks the `MigrationState` whether a migration is applied, not when.
    pub applied_at: Option<String>,
}
106
/// Orders registered migrations by dependency and renders their operations
/// as SQL for dialect `D`.
///
/// The runner tracks which migrations are registered and uses the provided
/// `MigrationState` to determine which migrations need to be applied. It does
/// not execute SQL itself: methods such as `sql_for_pending` and
/// `sql_for_rollback` return the statements for the caller to run.
///
/// # Example
///
/// ```rust
/// use oxide_sql_core::migrations::{
///     Migration, MigrationRunner, MigrationState, Operation,
///     CreateTableBuilder, SqliteDialect, bigint,
/// };
///
/// // Define a migration
/// pub struct Migration0001;
/// impl Migration for Migration0001 {
///     const ID: &'static str = "0001_initial";
///     fn up() -> Vec<Operation> {
///         vec![CreateTableBuilder::new()
///             .name("test")
///             .column(bigint("id").primary_key().build())
///             .build()
///             .into()]
///     }
///     fn down() -> Vec<Operation> {
///         vec![Operation::drop_table("test")]
///     }
/// }
///
/// // Create runner
/// let mut runner = MigrationRunner::new(SqliteDialect::new());
/// runner.register::<Migration0001>();
///
/// // Check status
/// let state = MigrationState::new();
/// let pending = runner.pending_migrations(&state);
/// assert_eq!(pending.len(), 1);
/// ```
pub struct MigrationRunner<D: MigrationDialect> {
    /// Registered migrations, in registration order.
    migrations: Vec<RegisteredMigration>,
    /// Dialect used to turn each `Operation` into a SQL string.
    dialect: D,
}
149
150impl<D: MigrationDialect> MigrationRunner<D> {
151    /// Creates a new migration runner with the given dialect.
152    #[must_use]
153    pub fn new(dialect: D) -> Self {
154        Self {
155            migrations: Vec::new(),
156            dialect,
157        }
158    }
159
160    /// Registers a migration.
161    pub fn register<M: Migration>(&mut self) -> &mut Self {
162        self.migrations.push(RegisteredMigration::new::<M>());
163        self
164    }
165
166    /// Returns all registered migrations.
167    #[must_use]
168    pub fn migrations(&self) -> &[RegisteredMigration] {
169        &self.migrations
170    }
171
172    /// Returns the dialect.
173    #[must_use]
174    pub fn dialect(&self) -> &D {
175        &self.dialect
176    }
177
178    /// Returns migrations that haven't been applied yet.
179    #[must_use]
180    pub fn pending_migrations(&self, state: &MigrationState) -> Vec<&RegisteredMigration> {
181        self.migrations
182            .iter()
183            .filter(|m| !state.is_applied(m.id))
184            .collect()
185    }
186
187    /// Returns the status of all migrations.
188    #[must_use]
189    pub fn status(&self, state: &MigrationState) -> Vec<MigrationStatus> {
190        self.migrations
191            .iter()
192            .map(|m| MigrationStatus {
193                id: m.id,
194                applied: state.is_applied(m.id),
195                applied_at: None, // Would need to query the DB for this
196            })
197            .collect()
198    }
199
200    /// Returns migrations in dependency order (topological sort).
201    ///
202    /// Returns `Err` if there's a circular dependency.
203    pub fn sorted_migrations(&self) -> Result<Vec<&RegisteredMigration>, MigrationError> {
204        // Build dependency graph
205        let mut in_degree: HashMap<&str, usize> = HashMap::new();
206        let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
207        let migration_map: HashMap<&str, &RegisteredMigration> =
208            self.migrations.iter().map(|m| (m.id, m)).collect();
209
210        for m in &self.migrations {
211            in_degree.entry(m.id).or_insert(0);
212            for dep in m.dependencies {
213                *in_degree.entry(m.id).or_insert(0) += 1;
214                dependents.entry(*dep).or_default().push(m.id);
215            }
216        }
217
218        // Kahn's algorithm for topological sort
219        let mut queue: VecDeque<&str> = in_degree
220            .iter()
221            .filter(|(_, deg)| **deg == 0)
222            .map(|(id, _)| *id)
223            .collect();
224        let mut result = Vec::new();
225
226        while let Some(id) = queue.pop_front() {
227            if let Some(m) = migration_map.get(id) {
228                result.push(*m);
229            }
230
231            if let Some(deps) = dependents.get(id) {
232                for dep in deps {
233                    if let Some(deg) = in_degree.get_mut(dep) {
234                        *deg -= 1;
235                        if *deg == 0 {
236                            queue.push_back(dep);
237                        }
238                    }
239                }
240            }
241        }
242
243        if result.len() != self.migrations.len() {
244            return Err(MigrationError::CircularDependency);
245        }
246
247        Ok(result)
248    }
249
250    /// Generates SQL for all pending migrations.
251    ///
252    /// Returns a list of (migration_id, sql_statements) pairs.
253    pub fn sql_for_pending(
254        &self,
255        state: &MigrationState,
256    ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
257        let sorted = self.sorted_migrations()?;
258        let pending: Vec<_> = sorted
259            .into_iter()
260            .filter(|m| !state.is_applied(m.id))
261            .collect();
262
263        let mut result = Vec::new();
264        for migration in pending {
265            let operations = (migration.up)();
266            let sqls: Vec<String> = operations
267                .iter()
268                .map(|op| self.dialect.generate_sql(op))
269                .collect();
270            result.push((migration.id, sqls));
271        }
272
273        Ok(result)
274    }
275
276    /// Generates SQL for rolling back migrations.
277    ///
278    /// Returns a list of (migration_id, sql_statements) pairs in reverse order.
279    pub fn sql_for_rollback(
280        &self,
281        state: &MigrationState,
282        count: usize,
283    ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
284        let sorted = self.sorted_migrations()?;
285
286        // Get applied migrations in reverse order
287        let applied: Vec<_> = sorted
288            .into_iter()
289            .rev()
290            .filter(|m| state.is_applied(m.id))
291            .take(count)
292            .collect();
293
294        let mut result = Vec::new();
295        for migration in applied {
296            let operations = (migration.down)();
297            if operations.is_empty() {
298                return Err(MigrationError::NotReversible(migration.id.to_string()));
299            }
300            let sqls: Vec<String> = operations
301                .iter()
302                .map(|op| self.dialect.generate_sql(op))
303                .collect();
304            result.push((migration.id, sqls));
305        }
306
307        Ok(result)
308    }
309
310    /// Validates that all dependencies exist and are registered.
311    pub fn validate(&self) -> Result<(), MigrationError> {
312        let ids: HashSet<&str> = self.migrations.iter().map(|m| m.id).collect();
313
314        for m in &self.migrations {
315            for dep in m.dependencies {
316                if !ids.contains(dep) {
317                    return Err(MigrationError::MissingDependency {
318                        migration: m.id.to_string(),
319                        dependency: (*dep).to_string(),
320                    });
321                }
322            }
323        }
324
325        // Check for circular dependencies
326        let _ = self.sorted_migrations()?;
327
328        Ok(())
329    }
330}
331
/// Errors that can occur while validating, ordering, or generating SQL for
/// migrations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationError {
    /// The migrations cannot be put in dependency order: their dependencies
    /// form a cycle. The runner also reports this when two registered
    /// migrations share the same ID.
    CircularDependency,
    /// A migration depends on an ID that no registered migration has.
    MissingDependency {
        /// ID of the migration that declares the dependency.
        migration: String,
        /// The dependency ID that is not registered.
        dependency: String,
    },
    /// A migration's `down` returned no operations, so it cannot be rolled
    /// back. Holds the migration ID.
    NotReversible(String),
    /// Database error, carrying a message.
    ///
    /// NOTE(review): not constructed in this module; presumably produced by
    /// code that executes migrations against a database. Confirm.
    DatabaseError(String),
}
349
350impl std::fmt::Display for MigrationError {
351    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        match self {
353            Self::CircularDependency => write!(f, "Circular dependency detected in migrations"),
354            Self::MissingDependency {
355                migration,
356                dependency,
357            } => write!(
358                f,
359                "Migration '{}' depends on '{}' which doesn't exist",
360                migration, dependency
361            ),
362            Self::NotReversible(id) => write!(f, "Migration '{}' is not reversible", id),
363            Self::DatabaseError(msg) => write!(f, "Database error: {}", msg),
364        }
365    }
366}
367
368impl std::error::Error for MigrationError {}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, boolean, varchar};
    use crate::migrations::dialect::SqliteDialect;
    use crate::migrations::table_builder::CreateTableBuilder;

    // Test migrations: a linear chain 0001 <- 0002 <- 0003 on a `users` table.
    struct Migration0001;
    impl Migration for Migration0001 {
        const ID: &'static str = "0001_initial";
        fn up() -> Vec<Operation> {
            vec![
                CreateTableBuilder::new()
                    .name("users")
                    .column(bigint("id").primary_key().autoincrement().build())
                    .column(varchar("username", 255).not_null().build())
                    .build()
                    .into(),
            ]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_table("users")]
        }
    }

    struct Migration0002;
    impl Migration for Migration0002 {
        const ID: &'static str = "0002_add_email";
        const DEPENDENCIES: &'static [&'static str] = &["0001_initial"];
        fn up() -> Vec<Operation> {
            vec![Operation::add_column(
                "users",
                varchar("email", 255).build(),
            )]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "email")]
        }
    }

    struct Migration0003;
    impl Migration for Migration0003 {
        const ID: &'static str = "0003_add_active";
        const DEPENDENCIES: &'static [&'static str] = &["0002_add_email"];
        fn up() -> Vec<Operation> {
            vec![Operation::add_column(
                "users",
                boolean("active").not_null().default_bool(true).build(),
            )]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "active")]
        }
    }

    #[test]
    fn test_register_migrations() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        assert_eq!(runner.migrations().len(), 2);
    }

    #[test]
    fn test_pending_migrations() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let state = MigrationState::new();
        let pending = runner.pending_migrations(&state);
        assert_eq!(pending.len(), 2);

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        let pending = runner.pending_migrations(&state);
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, "0002_add_email");
    }

    #[test]
    fn test_topological_sort() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        // Register in reverse order
        runner.register::<Migration0003>();
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let sorted = runner.sorted_migrations().unwrap();
        let ids: Vec<_> = sorted.iter().map(|m| m.id).collect();

        // 0001 must come before 0002, 0002 must come before 0003
        let pos_0001 = ids.iter().position(|&id| id == "0001_initial").unwrap();
        let pos_0002 = ids.iter().position(|&id| id == "0002_add_email").unwrap();
        let pos_0003 = ids.iter().position(|&id| id == "0003_add_active").unwrap();

        assert!(pos_0001 < pos_0002);
        assert!(pos_0002 < pos_0003);
    }

    #[test]
    fn test_sql_generation() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();

        let state = MigrationState::new();
        let sql = runner.sql_for_pending(&state).unwrap();

        assert_eq!(sql.len(), 1);
        assert_eq!(sql[0].0, "0001_initial");
        assert!(!sql[0].1.is_empty());
        assert!(sql[0].1[0].contains("CREATE TABLE"));
    }

    #[test]
    fn test_sql_for_pending_follows_dependencies() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        // Registered out of order on purpose: 0002 depends on 0001.
        runner.register::<Migration0002>();
        runner.register::<Migration0001>();

        let state = MigrationState::new();
        let sql = runner.sql_for_pending(&state).unwrap();
        let ids: Vec<_> = sql.iter().map(|(id, _)| *id).collect();

        assert_eq!(ids, vec!["0001_initial", "0002_add_email"]);
    }

    #[test]
    fn test_rollback_sql() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        state.mark_applied("0002_add_email");

        let sql = runner.sql_for_rollback(&state, 1).unwrap();
        assert_eq!(sql.len(), 1);
        assert_eq!(sql[0].0, "0002_add_email");
        assert!(sql[0].1[0].contains("DROP COLUMN"));
    }

    #[test]
    fn test_rollback_not_reversible() {
        struct Irreversible;
        impl Migration for Irreversible {
            const ID: &'static str = "0001_irreversible";
            fn up() -> Vec<Operation> {
                vec![Operation::drop_table("legacy")]
            }
            fn down() -> Vec<Operation> {
                vec![]
            }
        }

        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Irreversible>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_irreversible");

        assert_eq!(
            runner.sql_for_rollback(&state, 1),
            Err(MigrationError::NotReversible("0001_irreversible".to_string()))
        );
    }

    #[test]
    fn test_missing_dependency() {
        struct BadMigration;
        impl Migration for BadMigration {
            const ID: &'static str = "bad_migration";
            const DEPENDENCIES: &'static [&'static str] = &["nonexistent"];
            fn up() -> Vec<Operation> {
                vec![]
            }
            fn down() -> Vec<Operation> {
                vec![]
            }
        }

        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<BadMigration>();

        let result = runner.validate();
        assert!(matches!(
            result,
            Err(MigrationError::MissingDependency { .. })
        ));
    }

    #[test]
    fn test_circular_dependency() {
        struct CycleA;
        impl Migration for CycleA {
            const ID: &'static str = "cycle_a";
            const DEPENDENCIES: &'static [&'static str] = &["cycle_b"];
            fn up() -> Vec<Operation> {
                vec![]
            }
            fn down() -> Vec<Operation> {
                vec![]
            }
        }

        struct CycleB;
        impl Migration for CycleB {
            const ID: &'static str = "cycle_b";
            const DEPENDENCIES: &'static [&'static str] = &["cycle_a"];
            fn up() -> Vec<Operation> {
                vec![]
            }
            fn down() -> Vec<Operation> {
                vec![]
            }
        }

        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<CycleA>();
        runner.register::<CycleB>();

        assert_eq!(runner.validate(), Err(MigrationError::CircularDependency));
        assert!(matches!(
            runner.sorted_migrations(),
            Err(MigrationError::CircularDependency)
        ));
    }

    #[test]
    fn test_status() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");

        let status = runner.status(&state);
        assert_eq!(status.len(), 2);

        let s1 = status.iter().find(|s| s.id == "0001_initial").unwrap();
        assert!(s1.applied);

        let s2 = status.iter().find(|s| s.id == "0002_add_email").unwrap();
        assert!(!s2.applied);
    }
}
544}