oxide_sql_core/migrations/
migration.rs1use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::dialect::MigrationDialect;
9use super::operation::Operation;
10use super::state::MigrationState;
11
12pub trait Migration {
48 const ID: &'static str;
53
54 const DEPENDENCIES: &'static [&'static str] = &[];
58
59 fn up() -> Vec<Operation>;
63
64 fn down() -> Vec<Operation>;
69}
70
71pub struct RegisteredMigration {
73 pub id: &'static str,
75 pub dependencies: &'static [&'static str],
77 pub up: fn() -> Vec<Operation>,
79 pub down: fn() -> Vec<Operation>,
81}
82
83impl RegisteredMigration {
84 #[must_use]
86 pub const fn new<M: Migration>() -> Self {
87 Self {
88 id: M::ID,
89 dependencies: M::DEPENDENCIES,
90 up: M::up,
91 down: M::down,
92 }
93 }
94}
95
/// Applied/pending status of one registered migration.
///
/// Values of this type are produced by `MigrationRunner::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// The migration's ID.
    pub id: &'static str,
    /// Whether the migration state records this migration as applied.
    pub applied: bool,
    /// When the migration was applied, if known.
    ///
    /// NOTE(review): `MigrationRunner::status` currently always leaves this
    /// as `None`.
    pub applied_at: Option<String>,
}
106
107pub struct MigrationRunner<D: MigrationDialect> {
146 migrations: Vec<RegisteredMigration>,
147 dialect: D,
148}
149
150impl<D: MigrationDialect> MigrationRunner<D> {
151 #[must_use]
153 pub fn new(dialect: D) -> Self {
154 Self {
155 migrations: Vec::new(),
156 dialect,
157 }
158 }
159
160 pub fn register<M: Migration>(&mut self) -> &mut Self {
162 self.migrations.push(RegisteredMigration::new::<M>());
163 self
164 }
165
166 #[must_use]
168 pub fn migrations(&self) -> &[RegisteredMigration] {
169 &self.migrations
170 }
171
172 #[must_use]
174 pub fn dialect(&self) -> &D {
175 &self.dialect
176 }
177
178 #[must_use]
180 pub fn pending_migrations(&self, state: &MigrationState) -> Vec<&RegisteredMigration> {
181 self.migrations
182 .iter()
183 .filter(|m| !state.is_applied(m.id))
184 .collect()
185 }
186
187 #[must_use]
189 pub fn status(&self, state: &MigrationState) -> Vec<MigrationStatus> {
190 self.migrations
191 .iter()
192 .map(|m| MigrationStatus {
193 id: m.id,
194 applied: state.is_applied(m.id),
195 applied_at: None, })
197 .collect()
198 }
199
200 pub fn sorted_migrations(&self) -> Result<Vec<&RegisteredMigration>, MigrationError> {
204 let mut in_degree: HashMap<&str, usize> = HashMap::new();
206 let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
207 let migration_map: HashMap<&str, &RegisteredMigration> =
208 self.migrations.iter().map(|m| (m.id, m)).collect();
209
210 for m in &self.migrations {
211 in_degree.entry(m.id).or_insert(0);
212 for dep in m.dependencies {
213 *in_degree.entry(m.id).or_insert(0) += 1;
214 dependents.entry(*dep).or_default().push(m.id);
215 }
216 }
217
218 let mut queue: VecDeque<&str> = in_degree
220 .iter()
221 .filter(|(_, deg)| **deg == 0)
222 .map(|(id, _)| *id)
223 .collect();
224 let mut result = Vec::new();
225
226 while let Some(id) = queue.pop_front() {
227 if let Some(m) = migration_map.get(id) {
228 result.push(*m);
229 }
230
231 if let Some(deps) = dependents.get(id) {
232 for dep in deps {
233 if let Some(deg) = in_degree.get_mut(dep) {
234 *deg -= 1;
235 if *deg == 0 {
236 queue.push_back(dep);
237 }
238 }
239 }
240 }
241 }
242
243 if result.len() != self.migrations.len() {
244 return Err(MigrationError::CircularDependency);
245 }
246
247 Ok(result)
248 }
249
250 pub fn sql_for_pending(
254 &self,
255 state: &MigrationState,
256 ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
257 let sorted = self.sorted_migrations()?;
258 let pending: Vec<_> = sorted
259 .into_iter()
260 .filter(|m| !state.is_applied(m.id))
261 .collect();
262
263 let mut result = Vec::new();
264 for migration in pending {
265 let operations = (migration.up)();
266 let sqls: Vec<String> = operations
267 .iter()
268 .map(|op| self.dialect.generate_sql(op))
269 .collect();
270 result.push((migration.id, sqls));
271 }
272
273 Ok(result)
274 }
275
276 pub fn sql_for_rollback(
280 &self,
281 state: &MigrationState,
282 count: usize,
283 ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
284 let sorted = self.sorted_migrations()?;
285
286 let applied: Vec<_> = sorted
288 .into_iter()
289 .rev()
290 .filter(|m| state.is_applied(m.id))
291 .take(count)
292 .collect();
293
294 let mut result = Vec::new();
295 for migration in applied {
296 let operations = (migration.down)();
297 if operations.is_empty() {
298 return Err(MigrationError::NotReversible(migration.id.to_string()));
299 }
300 let sqls: Vec<String> = operations
301 .iter()
302 .map(|op| self.dialect.generate_sql(op))
303 .collect();
304 result.push((migration.id, sqls));
305 }
306
307 Ok(result)
308 }
309
310 pub fn validate(&self) -> Result<(), MigrationError> {
312 let ids: HashSet<&str> = self.migrations.iter().map(|m| m.id).collect();
313
314 for m in &self.migrations {
315 for dep in m.dependencies {
316 if !ids.contains(dep) {
317 return Err(MigrationError::MissingDependency {
318 migration: m.id.to_string(),
319 dependency: (*dep).to_string(),
320 });
321 }
322 }
323 }
324
325 let _ = self.sorted_migrations()?;
327
328 Ok(())
329 }
330}
331
/// Errors produced while ordering, validating, or rendering migrations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationError {
    /// The migrations could not be put into dependency order, typically
    /// because of a dependency cycle.
    CircularDependency,
    /// A migration lists a dependency ID that is not registered.
    MissingDependency {
        /// ID of the migration that declares the dependency.
        migration: String,
        /// The dependency ID that is not registered.
        dependency: String,
    },
    /// The migration with this ID has no `down` operations, so it cannot be
    /// rolled back.
    NotReversible(String),
    /// A database error, carrying its message.
    DatabaseError(String),
}

impl std::fmt::Display for MigrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CircularDependency => {
                f.write_str("Circular dependency detected in migrations")
            }
            Self::MissingDependency {
                migration,
                dependency,
            } => write!(
                f,
                "Migration '{migration}' depends on '{dependency}' which doesn't exist"
            ),
            Self::NotReversible(id) => write!(f, "Migration '{id}' is not reversible"),
            Self::DatabaseError(msg) => write!(f, "Database error: {msg}"),
        }
    }
}

impl std::error::Error for MigrationError {}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, boolean, varchar};
    use crate::migrations::dialect::SqliteDialect;
    use crate::migrations::table_builder::CreateTableBuilder;

    /// Creates the `users` table.
    struct Migration0001;

    impl Migration for Migration0001 {
        const ID: &'static str = "0001_initial";

        fn up() -> Vec<Operation> {
            let users = CreateTableBuilder::new()
                .name("users")
                .column(bigint("id").primary_key().autoincrement().build())
                .column(varchar("username", 255).not_null().build())
                .build();
            vec![users.into()]
        }

        fn down() -> Vec<Operation> {
            vec![Operation::drop_table("users")]
        }
    }

    /// Adds `users.email`; depends on `0001_initial`.
    struct Migration0002;

    impl Migration for Migration0002 {
        const ID: &'static str = "0002_add_email";
        const DEPENDENCIES: &'static [&'static str] = &["0001_initial"];

        fn up() -> Vec<Operation> {
            let email = varchar("email", 255).build();
            vec![Operation::add_column("users", email)]
        }

        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "email")]
        }
    }

    /// Adds `users.active`; depends on `0002_add_email`.
    struct Migration0003;

    impl Migration for Migration0003 {
        const ID: &'static str = "0003_add_active";
        const DEPENDENCIES: &'static [&'static str] = &["0002_add_email"];

        fn up() -> Vec<Operation> {
            let active = boolean("active").not_null().default_bool(true).build();
            vec![Operation::add_column("users", active)]
        }

        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "active")]
        }
    }

    /// Builds a SQLite runner with `0001_initial` and `0002_add_email` registered.
    fn runner_with_users_and_email() -> MigrationRunner<SqliteDialect> {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner
            .register::<Migration0001>()
            .register::<Migration0002>();
        runner
    }

    /// Builds a migration state in which each of `ids` is marked as applied.
    fn applied(ids: &[&'static str]) -> MigrationState {
        let mut state = MigrationState::new();
        for &id in ids {
            state.mark_applied(id);
        }
        state
    }

    #[test]
    fn test_register_migrations() {
        let runner = runner_with_users_and_email();
        assert_eq!(runner.migrations().len(), 2);
    }

    #[test]
    fn test_pending_migrations() {
        let runner = runner_with_users_and_email();

        assert_eq!(runner.pending_migrations(&applied(&[])).len(), 2);

        let state = applied(&["0001_initial"]);
        let pending = runner.pending_migrations(&state);
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, "0002_add_email");
    }

    #[test]
    fn test_topological_sort() {
        // Deliberately registered out of dependency order.
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner
            .register::<Migration0003>()
            .register::<Migration0001>()
            .register::<Migration0002>();

        let ids: Vec<&str> = runner
            .sorted_migrations()
            .unwrap()
            .iter()
            .map(|m| m.id)
            .collect();

        let position = |id: &str| ids.iter().position(|&found| found == id).unwrap();
        assert!(position("0001_initial") < position("0002_add_email"));
        assert!(position("0002_add_email") < position("0003_add_active"));
    }

    #[test]
    fn test_sql_generation() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();

        let sql = runner.sql_for_pending(&MigrationState::new()).unwrap();

        assert_eq!(sql.len(), 1);
        let (id, statements) = &sql[0];
        assert_eq!(*id, "0001_initial");
        assert!(!statements.is_empty());
        assert!(statements[0].contains("CREATE TABLE"));
    }

    #[test]
    fn test_rollback_sql() {
        let runner = runner_with_users_and_email();
        let state = applied(&["0001_initial", "0002_add_email"]);

        let sql = runner.sql_for_rollback(&state, 1).unwrap();

        assert_eq!(sql.len(), 1);
        let (id, statements) = &sql[0];
        assert_eq!(*id, "0002_add_email");
        assert!(statements[0].contains("DROP COLUMN"));
    }

    #[test]
    fn test_missing_dependency() {
        struct BadMigration;

        impl Migration for BadMigration {
            const ID: &'static str = "bad_migration";
            const DEPENDENCIES: &'static [&'static str] = &["nonexistent"];

            fn up() -> Vec<Operation> {
                Vec::new()
            }

            fn down() -> Vec<Operation> {
                Vec::new()
            }
        }

        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<BadMigration>();

        assert!(matches!(
            runner.validate(),
            Err(MigrationError::MissingDependency { .. })
        ));
    }

    #[test]
    fn test_status() {
        let runner = runner_with_users_and_email();
        let status = runner.status(&applied(&["0001_initial"]));
        assert_eq!(status.len(), 2);

        let applied_flag = |id: &str| status.iter().find(|s| s.id == id).unwrap().applied;
        assert!(applied_flag("0001_initial"));
        assert!(!applied_flag("0002_add_email"));
    }
}