oxide_sql_core/migrations/dialect/postgres.rs

1//! PostgreSQL dialect for migrations.
2
3use super::MigrationDialect;
4use crate::ast::DataType;
5use crate::migrations::column_builder::{ColumnDefinition, DefaultValue};
6use crate::migrations::operation::{
7    AlterColumnChange, AlterColumnOp, DropIndexOp, RenameColumnOp, RenameTableOp,
8};
9use crate::schema::RustTypeMapping;
10
/// PostgreSQL dialect for migration SQL generation.
///
/// A stateless, zero-sized marker type: all behaviour lives in its
/// `MigrationDialect` and `RustTypeMapping` implementations. It is therefore
/// freely copyable, comparable, and hashable (e.g. usable as a map key when
/// selecting a dialect at runtime).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PostgresDialect;

impl PostgresDialect {
    /// Creates a new PostgreSQL dialect.
    ///
    /// Equivalent to `PostgresDialect::default()`, but usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
22
23impl MigrationDialect for PostgresDialect {
24    fn name(&self) -> &'static str {
25        "postgresql"
26    }
27
28    fn map_data_type(&self, dt: &DataType) -> String {
29        match dt {
30            DataType::Smallint => "SMALLINT".to_string(),
31            DataType::Integer => "INTEGER".to_string(),
32            DataType::Bigint => "BIGINT".to_string(),
33            DataType::Real => "REAL".to_string(),
34            DataType::Double => "DOUBLE PRECISION".to_string(),
35            DataType::Decimal { precision, scale } => match (precision, scale) {
36                (Some(p), Some(s)) => format!("DECIMAL({p}, {s})"),
37                (Some(p), None) => format!("DECIMAL({p})"),
38                _ => "DECIMAL".to_string(),
39            },
40            DataType::Numeric { precision, scale } => match (precision, scale) {
41                (Some(p), Some(s)) => format!("NUMERIC({p}, {s})"),
42                (Some(p), None) => format!("NUMERIC({p})"),
43                _ => "NUMERIC".to_string(),
44            },
45            DataType::Char(len) => match len {
46                Some(n) => format!("CHAR({n})"),
47                None => "CHAR".to_string(),
48            },
49            DataType::Varchar(len) => match len {
50                Some(n) => format!("VARCHAR({n})"),
51                None => "VARCHAR".to_string(),
52            },
53            DataType::Text => "TEXT".to_string(),
54            DataType::Blob => "BYTEA".to_string(), // PostgreSQL uses BYTEA
55            DataType::Binary(len) => match len {
56                Some(n) => format!("BIT({n})"),
57                None => "BYTEA".to_string(),
58            },
59            DataType::Varbinary(len) => match len {
60                Some(n) => format!("VARBIT({n})"),
61                None => "BYTEA".to_string(),
62            },
63            DataType::Date => "DATE".to_string(),
64            DataType::Time => "TIME".to_string(),
65            DataType::Timestamp => "TIMESTAMP".to_string(),
66            DataType::Datetime => "TIMESTAMP".to_string(), // PostgreSQL uses TIMESTAMP
67            DataType::Boolean => "BOOLEAN".to_string(),
68            DataType::Custom(name) => name.clone(),
69        }
70    }
71
72    fn autoincrement_keyword(&self) -> String {
73        // PostgreSQL uses SERIAL types instead of AUTOINCREMENT keyword
74        // However, when PRIMARY KEY is specified with BIGINT, we don't change the type
75        // The application should use SERIAL/BIGSERIAL types directly
76        String::new()
77    }
78
79    fn column_definition(&self, col: &ColumnDefinition) -> String {
80        // PostgreSQL uses SERIAL/BIGSERIAL for auto-increment
81        let data_type = if col.autoincrement && col.primary_key {
82            match col.data_type {
83                DataType::Integer | DataType::Smallint => "SERIAL".to_string(),
84                DataType::Bigint => "BIGSERIAL".to_string(),
85                _ => self.map_data_type(&col.data_type),
86            }
87        } else {
88            self.map_data_type(&col.data_type)
89        };
90
91        let mut sql = format!("{} {}", self.quote_identifier(&col.name), data_type);
92
93        if col.primary_key {
94            sql.push_str(" PRIMARY KEY");
95        } else {
96            if !col.nullable {
97                sql.push_str(" NOT NULL");
98            }
99            if col.unique {
100                sql.push_str(" UNIQUE");
101            }
102        }
103
104        if let Some(ref default) = col.default {
105            sql.push_str(" DEFAULT ");
106            sql.push_str(&self.render_default(default));
107        }
108
109        if let Some(ref fk) = col.references {
110            sql.push_str(" REFERENCES ");
111            sql.push_str(&self.quote_identifier(&fk.table));
112            sql.push_str(" (");
113            sql.push_str(&self.quote_identifier(&fk.column));
114            sql.push(')');
115            if let Some(action) = fk.on_delete {
116                sql.push_str(" ON DELETE ");
117                sql.push_str(action.as_sql());
118            }
119            if let Some(action) = fk.on_update {
120                sql.push_str(" ON UPDATE ");
121                sql.push_str(action.as_sql());
122            }
123        }
124
125        if let Some(ref check) = col.check {
126            sql.push_str(&format!(" CHECK ({})", check));
127        }
128
129        if let Some(ref collation) = col.collation {
130            sql.push_str(&format!(" COLLATE \"{}\"", collation));
131        }
132
133        sql
134    }
135
136    fn render_default(&self, default: &DefaultValue) -> String {
137        match default {
138            DefaultValue::Boolean(b) => {
139                if *b {
140                    "TRUE".to_string()
141                } else {
142                    "FALSE".to_string()
143                }
144            }
145            _ => default.to_sql(),
146        }
147    }
148
149    fn rename_table(&self, op: &RenameTableOp) -> String {
150        format!(
151            "ALTER TABLE {} RENAME TO {}",
152            self.quote_identifier(&op.old_name),
153            self.quote_identifier(&op.new_name)
154        )
155    }
156
157    fn rename_column(&self, op: &RenameColumnOp) -> String {
158        format!(
159            "ALTER TABLE {} RENAME COLUMN {} TO {}",
160            self.quote_identifier(&op.table),
161            self.quote_identifier(&op.old_name),
162            self.quote_identifier(&op.new_name)
163        )
164    }
165
166    fn alter_column(&self, op: &AlterColumnOp) -> String {
167        let table = self.quote_identifier(&op.table);
168        let column = self.quote_identifier(&op.column);
169
170        match &op.change {
171            AlterColumnChange::SetDataType(dt) => {
172                format!(
173                    "ALTER TABLE {} ALTER COLUMN {} TYPE {}",
174                    table,
175                    column,
176                    self.map_data_type(dt)
177                )
178            }
179            AlterColumnChange::SetNullable(nullable) => {
180                if *nullable {
181                    format!(
182                        "ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL",
183                        table, column
184                    )
185                } else {
186                    format!("ALTER TABLE {} ALTER COLUMN {} SET NOT NULL", table, column)
187                }
188            }
189            AlterColumnChange::SetDefault(default) => {
190                format!(
191                    "ALTER TABLE {} ALTER COLUMN {} SET DEFAULT {}",
192                    table,
193                    column,
194                    self.render_default(default)
195                )
196            }
197            AlterColumnChange::DropDefault => {
198                format!("ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT", table, column)
199            }
200        }
201    }
202
203    fn drop_index(&self, op: &DropIndexOp) -> String {
204        let mut sql = String::from("DROP INDEX ");
205        if op.if_exists {
206            sql.push_str("IF EXISTS ");
207        }
208        sql.push_str(&self.quote_identifier(&op.name));
209        sql
210    }
211
212    fn drop_foreign_key(&self, op: &super::super::operation::DropForeignKeyOp) -> String {
213        format!(
214            "ALTER TABLE {} DROP CONSTRAINT {}",
215            self.quote_identifier(&op.table),
216            self.quote_identifier(&op.name)
217        )
218    }
219}
220
221impl RustTypeMapping for PostgresDialect {
222    fn map_type(&self, rust_type: &str) -> DataType {
223        match rust_type {
224            "bool" => DataType::Boolean,
225            "i8" | "i16" | "u8" | "u16" => DataType::Smallint,
226            "i32" | "u32" => DataType::Integer,
227            "i64" | "u64" | "i128" | "u128" | "isize" | "usize" => DataType::Bigint,
228            "f32" => DataType::Real,
229            "f64" => DataType::Double,
230            "String" => DataType::Varchar(Some(255)),
231            "Vec<u8>" => DataType::Blob,
232            s if s.contains("DateTime") => DataType::Timestamp,
233            s if s.contains("NaiveDate") => DataType::Date,
234            _ => DataType::Text,
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, varchar};
    use crate::migrations::operation::DropForeignKeyOp;
    use crate::migrations::table_builder::CreateTableBuilder;

    /// Builds an `ALTER COLUMN` operation for `table.column` with the given change.
    fn alter_op(table: &str, column: &str, change: AlterColumnChange) -> AlterColumnOp {
        AlterColumnOp {
            table: table.to_string(),
            column: column.to_string(),
            change,
        }
    }

    #[test]
    fn test_postgres_data_types() {
        let pg = PostgresDialect::new();
        let cases = [
            (DataType::Integer, "INTEGER"),
            (DataType::Bigint, "BIGINT"),
            (DataType::Text, "TEXT"),
            (DataType::Varchar(Some(255)), "VARCHAR(255)"),
            (DataType::Blob, "BYTEA"),
            (DataType::Boolean, "BOOLEAN"),
            (DataType::Timestamp, "TIMESTAMP"),
            (
                DataType::Decimal {
                    precision: Some(10),
                    scale: Some(2),
                },
                "DECIMAL(10, 2)",
            ),
        ];

        for (data_type, expected) in &cases {
            assert_eq!(pg.map_data_type(data_type), *expected);
        }
    }

    #[test]
    fn test_create_table_with_serial() {
        let pg = PostgresDialect::new();
        let create = CreateTableBuilder::new()
            .name("users")
            .column(bigint("id").primary_key().autoincrement().build())
            .column(varchar("username", 255).not_null().unique().build())
            .build();

        let sql = pg.create_table(&create);
        for fragment in [
            "CREATE TABLE \"users\"",
            "\"id\" BIGSERIAL PRIMARY KEY",
            "\"username\" VARCHAR(255) NOT NULL UNIQUE",
        ] {
            assert!(sql.contains(fragment), "missing `{fragment}` in: {sql}");
        }
    }

    #[test]
    fn test_alter_column_sql() {
        let pg = PostgresDialect::new();

        // Set NOT NULL
        assert_eq!(
            pg.alter_column(&alter_op("users", "email", AlterColumnChange::SetNullable(false))),
            "ALTER TABLE \"users\" ALTER COLUMN \"email\" SET NOT NULL"
        );

        // Drop NOT NULL
        assert_eq!(
            pg.alter_column(&alter_op("users", "email", AlterColumnChange::SetNullable(true))),
            "ALTER TABLE \"users\" ALTER COLUMN \"email\" DROP NOT NULL"
        );

        // Change type
        assert_eq!(
            pg.alter_column(&alter_op(
                "users",
                "age",
                AlterColumnChange::SetDataType(DataType::Bigint)
            )),
            "ALTER TABLE \"users\" ALTER COLUMN \"age\" TYPE BIGINT"
        );
    }

    #[test]
    fn test_drop_foreign_key() {
        let pg = PostgresDialect::new();
        let drop = DropForeignKeyOp {
            table: "invoices".to_string(),
            name: "fk_invoices_user".to_string(),
        };

        assert_eq!(
            pg.drop_foreign_key(&drop),
            "ALTER TABLE \"invoices\" DROP CONSTRAINT \"fk_invoices_user\""
        );
    }
}