1use std::collections::BTreeSet;
8
9use crate::schema::{RustTypeMapping, TableSchema};
10
11use super::column_builder::ColumnDefinition;
12use super::operation::{
13 AddColumnOp, AlterColumnChange, AlterColumnOp, CreateTableOp, DropColumnOp, DropTableOp,
14 Operation,
15};
16use super::snapshot::{ColumnSnapshot, SchemaSnapshot, TableSnapshot};
17
/// A difference the differ cannot classify on its own and that needs a human
/// decision before it is turned into operations.
///
/// The differ emits no add/drop operations for the columns or tables named in
/// an ambiguous change, so each entry must be resolved by the caller (for
/// example as a rename, or as an explicit drop plus add).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AmbiguousChange {
    /// Exactly one column disappeared from `table` and exactly one column with
    /// the same data type appeared; this may be a rename of `old_column` to
    /// `new_column`, or an unrelated drop and add.
    PossibleRename {
        /// Table containing both columns.
        table: String,
        /// Column present only in the current snapshot.
        old_column: String,
        /// Column present only in the desired snapshot.
        new_column: String,
    },
    /// Exactly one table disappeared and exactly one table with matching
    /// columns appeared; this may be a rename of `old_table` to `new_table`.
    PossibleTableRename {
        /// Table present only in the current schema.
        old_table: String,
        /// Table present only in the desired schema.
        new_table: String,
    },
}
41
42#[derive(Debug, Clone, PartialEq)]
44pub struct SchemaDiff {
45 pub operations: Vec<Operation>,
47 pub ambiguous: Vec<AmbiguousChange>,
49}
50
51impl SchemaDiff {
52 #[must_use]
54 pub fn is_empty(&self) -> bool {
55 self.operations.is_empty() && self.ambiguous.is_empty()
56 }
57}
58
59fn diff_table(table_name: &str, old: &TableSnapshot, new: &TableSnapshot) -> SchemaDiff {
62 let old_names: BTreeSet<&str> = old.columns.iter().map(|c| c.name.as_str()).collect();
63 let new_names: BTreeSet<&str> = new.columns.iter().map(|c| c.name.as_str()).collect();
64
65 let dropped: Vec<&str> = old_names.difference(&new_names).copied().collect();
66 let added: Vec<&str> = new_names.difference(&old_names).copied().collect();
67 let common: BTreeSet<&str> = old_names.intersection(&new_names).copied().collect();
68
69 let mut operations = Vec::new();
70 let mut ambiguous = Vec::new();
71
72 let mut rename_dropped = BTreeSet::new();
75 let mut rename_added = BTreeSet::new();
76
77 if dropped.len() == 1 && added.len() == 1 {
78 let old_col = old.column(dropped[0]).unwrap();
79 let new_col = new.column(added[0]).unwrap();
80 if old_col.data_type == new_col.data_type {
81 ambiguous.push(AmbiguousChange::PossibleRename {
82 table: table_name.to_string(),
83 old_column: dropped[0].to_string(),
84 new_column: added[0].to_string(),
85 });
86 rename_dropped.insert(dropped[0]);
87 rename_added.insert(added[0]);
88 }
89 }
90
91 for &name in &added {
93 if rename_added.contains(name) {
94 continue;
95 }
96 let col = new.column(name).unwrap();
97 operations.push(Operation::AddColumn(AddColumnOp {
98 table: table_name.to_string(),
99 column: snapshot_to_column_def(col),
100 }));
101 }
102
103 for &name in &common {
105 let old_col = old.column(name).unwrap();
106 let new_col = new.column(name).unwrap();
107
108 if old_col.data_type != new_col.data_type {
109 operations.push(Operation::AlterColumn(AlterColumnOp {
110 table: table_name.to_string(),
111 column: name.to_string(),
112 change: AlterColumnChange::SetDataType(new_col.data_type.clone()),
113 }));
114 }
115
116 if old_col.nullable != new_col.nullable {
117 operations.push(Operation::AlterColumn(AlterColumnOp {
118 table: table_name.to_string(),
119 column: name.to_string(),
120 change: AlterColumnChange::SetNullable(new_col.nullable),
121 }));
122 }
123
124 match (&old_col.default, &new_col.default) {
125 (None, Some(new_default)) => {
126 operations.push(Operation::AlterColumn(AlterColumnOp {
127 table: table_name.to_string(),
128 column: name.to_string(),
129 change: AlterColumnChange::SetDefault(new_default.clone()),
130 }));
131 }
132 (Some(_), None) => {
133 operations.push(Operation::AlterColumn(AlterColumnOp {
134 table: table_name.to_string(),
135 column: name.to_string(),
136 change: AlterColumnChange::DropDefault,
137 }));
138 }
139 (Some(old_def), Some(new_def)) if old_def != new_def => {
140 operations.push(Operation::AlterColumn(AlterColumnOp {
141 table: table_name.to_string(),
142 column: name.to_string(),
143 change: AlterColumnChange::SetDefault(new_def.clone()),
144 }));
145 }
146 _ => {}
147 }
148 }
149
150 for &name in &dropped {
152 if rename_dropped.contains(name) {
153 continue;
154 }
155 operations.push(Operation::DropColumn(DropColumnOp {
156 table: table_name.to_string(),
157 column: name.to_string(),
158 }));
159 }
160
161 SchemaDiff {
162 operations,
163 ambiguous,
164 }
165}
166
167fn snapshot_to_column_def(col: &ColumnSnapshot) -> ColumnDefinition {
170 ColumnDefinition {
171 name: col.name.clone(),
172 data_type: col.data_type.clone(),
173 nullable: col.nullable,
174 default: col.default.clone(),
175 primary_key: col.primary_key,
176 unique: col.unique,
177 autoincrement: col.autoincrement,
178 references: None,
179 check: None,
180 collation: None,
181 }
182}
183
184pub fn auto_diff_schema(current: &SchemaSnapshot, desired: &SchemaSnapshot) -> SchemaDiff {
190 let current_tables: BTreeSet<&str> = current.tables.keys().map(String::as_str).collect();
191 let desired_tables: BTreeSet<&str> = desired.tables.keys().map(String::as_str).collect();
192
193 let dropped_tables: Vec<&str> = current_tables
194 .difference(&desired_tables)
195 .copied()
196 .collect();
197 let added_tables: Vec<&str> = desired_tables
198 .difference(¤t_tables)
199 .copied()
200 .collect();
201 let common_tables: Vec<&str> = current_tables
202 .intersection(&desired_tables)
203 .copied()
204 .collect();
205
206 let mut create_ops = Vec::new();
207 let mut add_ops = Vec::new();
208 let mut alter_ops = Vec::new();
209 let mut drop_col_ops = Vec::new();
210 let mut drop_table_ops = Vec::new();
211 let mut ambiguous = Vec::new();
212
213 let mut rename_dropped = BTreeSet::new();
215 let mut rename_added = BTreeSet::new();
216
217 if dropped_tables.len() == 1 && added_tables.len() == 1 {
218 let old_table = ¤t.tables[dropped_tables[0]];
219 let new_table = &desired.tables[added_tables[0]];
220 if tables_have_same_columns(old_table, new_table) {
221 ambiguous.push(AmbiguousChange::PossibleTableRename {
222 old_table: dropped_tables[0].to_string(),
223 new_table: added_tables[0].to_string(),
224 });
225 rename_dropped.insert(dropped_tables[0]);
226 rename_added.insert(added_tables[0]);
227 }
228 }
229
230 for &name in &added_tables {
232 if rename_added.contains(name) {
233 continue;
234 }
235 let table = &desired.tables[name];
236 let columns = table.columns.iter().map(snapshot_to_column_def).collect();
237 create_ops.push(Operation::CreateTable(CreateTableOp {
238 name: name.to_string(),
239 columns,
240 constraints: vec![],
241 if_not_exists: false,
242 }));
243 }
244
245 for &name in &common_tables {
247 let old_table = ¤t.tables[name];
248 let new_table = &desired.tables[name];
249 let table_diff = diff_table(name, old_table, new_table);
250
251 for op in table_diff.operations {
252 match &op {
253 Operation::AddColumn(_) => add_ops.push(op),
254 Operation::AlterColumn(_) => alter_ops.push(op),
255 Operation::DropColumn(_) => drop_col_ops.push(op),
256 _ => add_ops.push(op),
257 }
258 }
259 ambiguous.extend(table_diff.ambiguous);
260 }
261
262 for &name in &dropped_tables {
264 if rename_dropped.contains(name) {
265 continue;
266 }
267 drop_table_ops.push(Operation::DropTable(DropTableOp {
268 name: name.to_string(),
269 if_exists: false,
270 cascade: false,
271 }));
272 }
273
274 let mut operations = Vec::new();
276 operations.extend(create_ops);
277 operations.extend(add_ops);
278 operations.extend(alter_ops);
279 operations.extend(drop_col_ops);
280 operations.extend(drop_table_ops);
281
282 SchemaDiff {
283 operations,
284 ambiguous,
285 }
286}
287
288pub fn auto_diff_table<T: TableSchema>(
291 current: &TableSnapshot,
292 dialect: &impl RustTypeMapping,
293) -> SchemaDiff {
294 let desired = TableSnapshot::from_table_schema::<T>(dialect);
295 diff_table(&desired.name, current, &desired)
296}
297
298fn tables_have_same_columns(a: &TableSnapshot, b: &TableSnapshot) -> bool {
301 if a.columns.len() != b.columns.len() {
302 return false;
303 }
304 a.columns.iter().zip(b.columns.iter()).all(|(ac, bc)| {
305 ac.name == bc.name
306 && ac.data_type == bc.data_type
307 && ac.nullable == bc.nullable
308 && ac.primary_key == bc.primary_key
309 && ac.unique == bc.unique
310 && ac.autoincrement == bc.autoincrement
311 && ac.default == bc.default
312 })
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::DataType;
    use crate::migrations::column_builder::DefaultValue;

    /// Builds an ordinary column: not a key, not unique, not autoincrement,
    /// no default.
    fn col(name: &str, data_type: DataType, nullable: bool) -> ColumnSnapshot {
        ColumnSnapshot {
            name: name.to_string(),
            data_type,
            nullable,
            primary_key: false,
            unique: false,
            autoincrement: false,
            default: None,
        }
    }

    /// Builds a non-null, autoincrementing primary-key column.
    fn pk_col(name: &str, data_type: DataType) -> ColumnSnapshot {
        ColumnSnapshot {
            name: name.to_string(),
            data_type,
            nullable: false,
            primary_key: true,
            unique: false,
            autoincrement: true,
            default: None,
        }
    }

    fn table(name: &str, columns: Vec<ColumnSnapshot>) -> TableSnapshot {
        TableSnapshot {
            name: name.to_string(),
            columns,
        }
    }

    fn schema(tables: Vec<TableSnapshot>) -> SchemaSnapshot {
        let mut s = SchemaSnapshot::new();
        for t in tables {
            s.add_table(t);
        }
        s
    }

    #[test]
    fn no_changes_produces_empty_diff() {
        let t = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        );
        let diff = diff_table("users", &t, &t);
        assert!(diff.is_empty());
    }

    #[test]
    fn new_table_detected() {
        let current = schema(vec![]);
        let desired = schema(vec![table("users", vec![pk_col("id", DataType::Bigint)])]);
        let diff = auto_diff_schema(&current, &desired);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::CreateTable(op) => {
                assert_eq!(op.name, "users");
            }
            other => panic!("Expected CreateTable, got {other:?}"),
        }
    }

    #[test]
    fn dropped_table_detected() {
        let current = schema(vec![table("users", vec![pk_col("id", DataType::Bigint)])]);
        let desired = schema(vec![]);
        let diff = auto_diff_schema(&current, &desired);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::DropTable(op) => {
                assert_eq!(op.name, "users");
            }
            other => panic!("Expected DropTable, got {other:?}"),
        }
    }

    #[test]
    fn added_column_detected() {
        let old = table("users", vec![pk_col("id", DataType::Bigint)]);
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("email", DataType::Text, true),
            ],
        );
        let diff = diff_table("users", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AddColumn(op) => {
                assert_eq!(op.table, "users");
                assert_eq!(op.column.name, "email");
            }
            other => panic!("Expected AddColumn, got {other:?}"),
        }
    }

    #[test]
    fn dropped_column_detected() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("email", DataType::Text, true),
            ],
        );
        let new = table("users", vec![pk_col("id", DataType::Bigint)]);
        let diff = diff_table("users", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::DropColumn(op) => {
                assert_eq!(op.table, "users");
                assert_eq!(op.column, "email");
            }
            other => panic!("Expected DropColumn, got {other:?}"),
        }
    }

    #[test]
    fn type_change_detected() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("score", DataType::Integer, false),
            ],
        );
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("score", DataType::Bigint, false),
            ],
        );
        let diff = diff_table("users", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AlterColumn(op) => {
                assert_eq!(op.column, "score");
                assert_eq!(op.change, AlterColumnChange::SetDataType(DataType::Bigint));
            }
            other => panic!("Expected AlterColumn, got {other:?}"),
        }
    }

    #[test]
    fn nullable_change_detected() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("email", DataType::Text, false),
            ],
        );
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("email", DataType::Text, true),
            ],
        );
        let diff = diff_table("users", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AlterColumn(op) => {
                assert_eq!(op.column, "email");
                assert_eq!(op.change, AlterColumnChange::SetNullable(true));
            }
            other => panic!("Expected AlterColumn, got {other:?}"),
        }
    }

    #[test]
    fn default_added() {
        let old = table("t", vec![col("active", DataType::Boolean, false)]);
        let mut new_col = col("active", DataType::Boolean, false);
        new_col.default = Some(DefaultValue::Expression("TRUE".into()));
        let new = table("t", vec![new_col]);
        let diff = diff_table("t", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AlterColumn(op) => {
                assert_eq!(
                    op.change,
                    AlterColumnChange::SetDefault(DefaultValue::Expression("TRUE".into()))
                );
            }
            other => panic!("Expected AlterColumn, got {other:?}"),
        }
    }

    #[test]
    fn default_changed() {
        let mut old_col = col("count", DataType::Integer, false);
        old_col.default = Some(DefaultValue::Integer(0));
        let old = table("t", vec![old_col]);

        let mut new_col = col("count", DataType::Integer, false);
        new_col.default = Some(DefaultValue::Integer(1));
        let new = table("t", vec![new_col]);

        let diff = diff_table("t", &old, &new);
        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AlterColumn(op) => {
                assert_eq!(
                    op.change,
                    AlterColumnChange::SetDefault(DefaultValue::Integer(1))
                );
            }
            other => panic!("Expected AlterColumn, got {other:?}"),
        }
    }

    #[test]
    fn default_removed() {
        let mut old_col = col("active", DataType::Boolean, false);
        old_col.default = Some(DefaultValue::Expression("TRUE".into()));
        let old = table("t", vec![old_col]);
        let new = table("t", vec![col("active", DataType::Boolean, false)]);
        let diff = diff_table("t", &old, &new);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AlterColumn(op) => {
                assert_eq!(op.change, AlterColumnChange::DropDefault);
            }
            other => panic!("Expected AlterColumn, got {other:?}"),
        }
    }

    #[test]
    fn ambiguous_rename_detected() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        );
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("full_name", DataType::Text, false),
            ],
        );
        let diff = diff_table("users", &old, &new);

        // A possible rename suppresses the add/drop pair but still makes the
        // diff non-empty.
        assert!(diff.operations.is_empty());
        assert!(!diff.is_empty());
        assert_eq!(diff.ambiguous.len(), 1);
        match &diff.ambiguous[0] {
            AmbiguousChange::PossibleRename {
                table,
                old_column,
                new_column,
            } => {
                assert_eq!(table, "users");
                assert_eq!(old_column, "name");
                assert_eq!(new_column, "full_name");
            }
            other => panic!("Expected PossibleRename, got {other:?}"),
        }
    }

    #[test]
    fn ambiguous_rename_not_triggered_different_types() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        );
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("full_name", DataType::Integer, false),
            ],
        );
        let diff = diff_table("users", &old, &new);

        // Different data types: treated as an independent drop + add.
        assert!(diff.ambiguous.is_empty());
        assert_eq!(diff.operations.len(), 2);
    }

    #[test]
    fn multiple_changes_combined() {
        let old = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
                col("old_field", DataType::Integer, false),
            ],
        );
        let new = table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Varchar(Some(255)), true),
                col("new_field", DataType::Boolean, false),
            ],
        );
        let diff = diff_table("users", &old, &new);

        // old_field/new_field differ in type, so no rename is suggested; the
        // expected order is add, alter (type, then nullability), drop.
        assert!(diff.ambiguous.is_empty());
        assert_eq!(diff.operations.len(), 4);
        match diff.operations.as_slice() {
            [Operation::AddColumn(added), Operation::AlterColumn(retyped), Operation::AlterColumn(renulled), Operation::DropColumn(removed)] =>
            {
                assert_eq!(added.column.name, "new_field");
                assert_eq!(retyped.column, "name");
                assert_eq!(
                    retyped.change,
                    AlterColumnChange::SetDataType(DataType::Varchar(Some(255)))
                );
                assert_eq!(renulled.column, "name");
                assert_eq!(renulled.change, AlterColumnChange::SetNullable(true));
                assert_eq!(removed.column, "old_field");
            }
            other => panic!("Unexpected operations: {other:?}"),
        }
    }

    #[test]
    fn operation_ordering_in_schema_diff() {
        // One table is created, one dropped, and one altered with two
        // dropped + two added columns (too many for the rename heuristic),
        // so every stage of the ordering is exercised.
        let current = schema(vec![
            table(
                "to_drop",
                vec![
                    pk_col("id", DataType::Bigint),
                    col("legacy", DataType::Text, false),
                ],
            ),
            table(
                "to_alter",
                vec![
                    pk_col("id", DataType::Bigint),
                    col("old_a", DataType::Text, false),
                    col("old_b", DataType::Integer, false),
                ],
            ),
        ]);
        let desired = schema(vec![
            table("to_create", vec![pk_col("id", DataType::Bigint)]),
            table(
                "to_alter",
                vec![
                    pk_col("id", DataType::Bigint),
                    col("new_a", DataType::Text, false),
                    col("new_b", DataType::Integer, false),
                ],
            ),
        ]);
        let diff = auto_diff_schema(&current, &desired);

        let mut saw_create = false;
        let mut saw_add = false;
        let mut saw_drop_col = false;
        let mut saw_drop_table = false;

        for op in &diff.operations {
            match op {
                Operation::CreateTable(_) => {
                    assert!(!saw_add && !saw_drop_col && !saw_drop_table);
                    saw_create = true;
                }
                Operation::AddColumn(_) => {
                    assert!(
                        !saw_drop_col && !saw_drop_table,
                        "AddColumn must come before DropColumn/DropTable"
                    );
                    saw_add = true;
                }
                Operation::DropColumn(_) => {
                    assert!(!saw_drop_table, "DropColumn must come before DropTable");
                    saw_drop_col = true;
                }
                Operation::DropTable(_) => {
                    saw_drop_table = true;
                }
                _ => {}
            }
        }

        assert!(saw_create);
        assert!(saw_add);
        assert!(saw_drop_col);
        assert!(saw_drop_table);
    }

    #[test]
    fn possible_table_rename_detected() {
        let current = schema(vec![table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        )]);
        let desired = schema(vec![table(
            "accounts",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        )]);
        let diff = auto_diff_schema(&current, &desired);

        assert!(diff.operations.is_empty());
        assert_eq!(diff.ambiguous.len(), 1);
        match &diff.ambiguous[0] {
            AmbiguousChange::PossibleTableRename {
                old_table,
                new_table,
            } => {
                assert_eq!(old_table, "users");
                assert_eq!(new_table, "accounts");
            }
            other => {
                panic!("Expected PossibleTableRename, got {other:?}")
            }
        }
    }

    #[test]
    fn table_rename_not_suggested_when_columns_differ() {
        let current = schema(vec![table(
            "users",
            vec![
                pk_col("id", DataType::Bigint),
                col("name", DataType::Text, false),
            ],
        )]);
        let desired = schema(vec![table(
            "accounts",
            vec![
                pk_col("id", DataType::Bigint),
                col("email", DataType::Text, false),
            ],
        )]);
        let diff = auto_diff_schema(&current, &desired);

        assert!(diff.ambiguous.is_empty());
        match diff.operations.as_slice() {
            [Operation::CreateTable(created), Operation::DropTable(dropped)] => {
                assert_eq!(created.name, "accounts");
                assert_eq!(dropped.name, "users");
            }
            other => panic!("Expected CreateTable then DropTable, got {other:?}"),
        }
    }

    #[test]
    fn auto_diff_table_works() {
        use crate::migrations::SqliteDialect;
        use crate::schema::{ColumnSchema, Table};

        struct MyTable;
        struct MyRow;

        impl Table for MyTable {
            type Row = MyRow;
            const NAME: &'static str = "items";
            const COLUMNS: &'static [&'static str] = &["id", "title"];
            const PRIMARY_KEY: Option<&'static str> = Some("id");
        }

        impl TableSchema for MyTable {
            const SCHEMA: &'static [ColumnSchema] = &[
                ColumnSchema {
                    name: "id",
                    rust_type: "i64",
                    nullable: false,
                    primary_key: true,
                    unique: false,
                    autoincrement: true,
                    default_expr: None,
                },
                ColumnSchema {
                    name: "title",
                    rust_type: "String",
                    nullable: false,
                    primary_key: false,
                    unique: false,
                    autoincrement: false,
                    default_expr: None,
                },
            ];
        }

        let dialect = SqliteDialect::new();
        let current = table("items", vec![pk_col("id", DataType::Bigint)]);
        let diff = auto_diff_table::<MyTable>(&current, &dialect);

        assert_eq!(diff.operations.len(), 1);
        match &diff.operations[0] {
            Operation::AddColumn(op) => {
                assert_eq!(op.column.name, "title");
                assert_eq!(op.column.data_type, DataType::Text);
            }
            other => panic!("Expected AddColumn, got {other:?}"),
        }
    }
}