oxide_sql_core/ast/
statement.rs

1//! SQL statement AST types.
2
3use core::fmt;
4
5use super::expression::Expr;
6
/// Sort direction attached to an ORDER BY entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OrderDirection {
    /// Ascending order; this is the `Default` variant.
    #[default]
    Asc,
    /// Descending order.
    Desc,
}

impl OrderDirection {
    /// Returns the SQL keyword for this direction: `"ASC"` or `"DESC"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Desc => "DESC",
            Self::Asc => "ASC",
        }
    }
}

impl fmt::Display for OrderDirection {
    /// Writes exactly the keyword returned by [`OrderDirection::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = self.as_str();
        f.write_str(keyword)
    }
}
33
/// Placement of NULL values within an ORDER BY entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NullOrdering {
    /// NULLs sort before all other values.
    First,
    /// NULLs sort after all other values.
    Last,
}

impl NullOrdering {
    /// Returns the SQL fragment for this ordering: `"NULLS FIRST"` or
    /// `"NULLS LAST"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Last => "NULLS LAST",
            Self::First => "NULLS FIRST",
        }
    }
}

impl fmt::Display for NullOrdering {
    /// Writes exactly the fragment returned by [`NullOrdering::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fragment = self.as_str();
        f.write_str(fragment)
    }
}
59
/// A single entry of an ORDER BY clause.
///
/// The `Display` impl in this file renders it as `<expr> <direction>`,
/// followed by the null ordering when one is set.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderBy {
    /// The expression to order by.
    pub expr: Expr,
    /// The direction (ASC or DESC); always rendered, even when ascending.
    pub direction: OrderDirection,
    /// Explicit null ordering; `None` means no NULLS clause is rendered.
    pub nulls: Option<NullOrdering>,
}
70
/// The kind of join a [`JoinClause`] performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join; rendered as `INNER JOIN`.
    Inner,
    /// Left outer join; rendered as `LEFT JOIN`.
    Left,
    /// Right outer join; rendered as `RIGHT JOIN`.
    Right,
    /// Full outer join; rendered as `FULL JOIN`.
    Full,
    /// Cross join; rendered as `CROSS JOIN`.
    Cross,
}

impl JoinType {
    /// Returns the SQL keywords for this join type, including the trailing
    /// `JOIN` word (for example `"LEFT JOIN"`).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Cross => "CROSS JOIN",
            Self::Full => "FULL JOIN",
            Self::Right => "RIGHT JOIN",
            Self::Left => "LEFT JOIN",
            Self::Inner => "INNER JOIN",
        }
    }
}

impl fmt::Display for JoinType {
    /// Writes exactly the keywords returned by [`JoinType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keywords = self.as_str();
        f.write_str(keywords)
    }
}
105
/// The right-hand part of a join: join type, joined table and condition.
///
/// The `Display` impl in this file renders `<join_type> <table>`, then
/// ` ON <expr>` when `on` is set, then ` USING (<cols>)` when `using` is
/// non-empty. The type itself does not prevent `on` and `using` from both
/// being populated; in that case both parts are rendered.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// The type of join.
    pub join_type: JoinType,
    /// The table to join.
    pub table: TableRef,
    /// The join condition (for non-CROSS joins); `None` renders no ON part.
    pub on: Option<Expr>,
    /// USING columns (alternative to ON); empty renders no USING part.
    pub using: Vec<String>,
}
118
/// A table reference, as used in a FROM clause or as a join operand.
#[derive(Debug, Clone, PartialEq)]
pub enum TableRef {
    /// A plain table name, rendered as `[schema.]name[ AS alias]`.
    Table {
        /// Schema name (optional).
        schema: Option<String>,
        /// Table name.
        name: String,
        /// Alias (optional).
        alias: Option<String>,
    },
    /// A subquery, rendered as `(<query>) AS <alias>`.
    Subquery {
        /// The subquery, boxed because `SelectStatement` itself holds a
        /// `TableRef`.
        query: Box<SelectStatement>,
        /// Alias (required for subqueries, hence not an `Option`).
        alias: String,
    },
    /// A joined table, rendered as `<left> <join>`.
    Join {
        /// Left side of the join.
        left: Box<TableRef>,
        /// The join clause (join type, right-hand table and condition).
        join: Box<JoinClause>,
    },
}
146
147impl TableRef {
148    /// Creates a simple table reference.
149    #[must_use]
150    pub fn table(name: impl Into<String>) -> Self {
151        Self::Table {
152            schema: None,
153            name: name.into(),
154            alias: None,
155        }
156    }
157
158    /// Creates a table reference with schema.
159    #[must_use]
160    pub fn with_schema(schema: impl Into<String>, name: impl Into<String>) -> Self {
161        Self::Table {
162            schema: Some(schema.into()),
163            name: name.into(),
164            alias: None,
165        }
166    }
167
168    /// Adds an alias to this table reference.
169    #[must_use]
170    pub fn alias(self, alias: impl Into<String>) -> Self {
171        match self {
172            Self::Table { schema, name, .. } => Self::Table {
173                schema,
174                name,
175                alias: Some(alias.into()),
176            },
177            Self::Subquery { query, .. } => Self::Subquery {
178                query,
179                alias: alias.into(),
180            },
181            Self::Join { left, join } => Self::Join {
182                left: Box::new((*left).alias(alias)),
183                join,
184            },
185        }
186    }
187}
188
/// A SELECT statement.
///
/// The `Display` impl in this file renders the parts in field order:
/// `SELECT [DISTINCT] <columns> [FROM] [WHERE] [GROUP BY] [HAVING]
/// [ORDER BY] [LIMIT] [OFFSET]`. Optional parts that are `None` or empty
/// are skipped.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectStatement {
    /// Whether to select DISTINCT values.
    pub distinct: bool,
    /// The columns to select, rendered as a comma-separated list.
    pub columns: Vec<SelectColumn>,
    /// The FROM clause; `None` renders no FROM part.
    pub from: Option<TableRef>,
    /// The WHERE clause.
    pub where_clause: Option<Expr>,
    /// GROUP BY expressions; empty renders no GROUP BY part.
    pub group_by: Vec<Expr>,
    /// HAVING clause.
    pub having: Option<Expr>,
    /// ORDER BY clauses; empty renders no ORDER BY part.
    pub order_by: Vec<OrderBy>,
    /// LIMIT clause.
    pub limit: Option<Expr>,
    /// OFFSET clause.
    pub offset: Option<Expr>,
}
211
/// One item of a SELECT column list: an expression plus an optional alias.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectColumn {
    /// The expression.
    pub expr: Expr,
    /// Column alias; when set it is rendered as ` AS <alias>`.
    pub alias: Option<String>,
}
220
221impl SelectColumn {
222    /// Creates a new select column.
223    #[must_use]
224    pub fn new(expr: Expr) -> Self {
225        Self { expr, alias: None }
226    }
227
228    /// Creates a select column with an alias.
229    #[must_use]
230    pub fn with_alias(expr: Expr, alias: impl Into<String>) -> Self {
231        Self {
232            expr,
233            alias: Some(alias.into()),
234        }
235    }
236}
237
/// An INSERT statement.
///
/// Rendered by the `Display` impl in this file as
/// `INSERT INTO [schema.]table [(columns)] <values> [<on_conflict>]`.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertStatement {
    /// Schema name (optional qualifier for `table`).
    pub schema: Option<String>,
    /// Table name.
    pub table: String,
    /// Column names (optional); empty renders no column list.
    pub columns: Vec<String>,
    /// Source of the rows to insert.
    pub values: InsertSource,
    /// ON CONFLICT clause (for UPSERT).
    pub on_conflict: Option<OnConflict>,
}
252
/// Where the rows of an INSERT come from.
#[derive(Debug, Clone, PartialEq)]
pub enum InsertSource {
    /// `VALUES (...), (...), ...`: the outer `Vec` holds rows, each inner
    /// `Vec` holds the value expressions of one row.
    Values(Vec<Vec<Expr>>),
    /// `SELECT ...`: rows produced by a query.
    Query(Box<SelectStatement>),
    /// `DEFAULT VALUES`
    DefaultValues,
}
263
/// ON CONFLICT clause for UPSERT: a conflict target plus the action to take.
#[derive(Debug, Clone, PartialEq)]
pub struct OnConflict {
    /// Conflict target columns, rendered as a comma-separated list.
    pub columns: Vec<String>,
    /// Action to take on conflict.
    pub action: ConflictAction,
}
272
/// Action to take when an INSERT hits a conflict.
#[derive(Debug, Clone, PartialEq)]
pub enum ConflictAction {
    /// `DO NOTHING`
    DoNothing,
    /// `DO UPDATE SET ...` with the given assignments, rendered as a
    /// comma-separated list.
    DoUpdate(Vec<UpdateAssignment>),
}
281
/// An UPDATE statement.
///
/// Rendered by the `Display` impl in this file as
/// `UPDATE [schema.]table [AS alias] SET <assignments> [FROM] [WHERE]`.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateStatement {
    /// Schema name (optional qualifier for `table`).
    pub schema: Option<String>,
    /// Table name.
    pub table: String,
    /// Alias for the target table.
    pub alias: Option<String>,
    /// SET assignments, rendered as a comma-separated list.
    pub assignments: Vec<UpdateAssignment>,
    /// FROM clause (for joins in UPDATE).
    pub from: Option<TableRef>,
    /// WHERE clause.
    pub where_clause: Option<Expr>,
}
298
/// A single `column = value` assignment, used by UPDATE SET and by
/// [`ConflictAction::DoUpdate`].
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateAssignment {
    /// Column name.
    pub column: String,
    /// Value expression.
    pub value: Expr,
}
307
/// A DELETE statement.
///
/// Rendered by the `Display` impl in this file as
/// `DELETE FROM [schema.]table [AS alias] [WHERE <expr>]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteStatement {
    /// Schema name (optional qualifier for `table`).
    pub schema: Option<String>,
    /// Table name.
    pub table: String,
    /// Alias for the target table.
    pub alias: Option<String>,
    /// WHERE clause; `None` renders no WHERE part.
    pub where_clause: Option<Expr>,
}
320
/// A SQL statement: one variant per supported statement kind.
///
/// The `Display` impl in this file delegates to the wrapped statement.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// SELECT statement.
    Select(SelectStatement),
    /// INSERT statement.
    Insert(InsertStatement),
    /// UPDATE statement.
    Update(UpdateStatement),
    /// DELETE statement.
    Delete(DeleteStatement),
}
333
334// ===================================================================
335// Display implementations
336// ===================================================================
337
338impl fmt::Display for OrderBy {
339    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340        write!(f, "{} {}", self.expr, self.direction)?;
341        if let Some(nulls) = &self.nulls {
342            write!(f, " {nulls}")?;
343        }
344        Ok(())
345    }
346}
347
348impl fmt::Display for JoinClause {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        write!(f, "{} {}", self.join_type, self.table)?;
351        if let Some(on) = &self.on {
352            write!(f, " ON {on}")?;
353        }
354        if !self.using.is_empty() {
355            write!(f, " USING (")?;
356            for (i, col) in self.using.iter().enumerate() {
357                if i > 0 {
358                    write!(f, ", ")?;
359                }
360                write!(f, "{col}")?;
361            }
362            write!(f, ")")?;
363        }
364        Ok(())
365    }
366}
367
368impl fmt::Display for TableRef {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        match self {
371            Self::Table {
372                schema,
373                name,
374                alias,
375            } => {
376                if let Some(s) = schema {
377                    write!(f, "{s}.")?;
378                }
379                write!(f, "{name}")?;
380                if let Some(a) = alias {
381                    write!(f, " AS {a}")?;
382                }
383                Ok(())
384            }
385            Self::Subquery { query, alias } => {
386                write!(f, "({query}) AS {alias}")
387            }
388            Self::Join { left, join } => {
389                write!(f, "{left} {join}")
390            }
391        }
392    }
393}
394
395impl fmt::Display for SelectColumn {
396    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397        write!(f, "{}", self.expr)?;
398        if let Some(a) = &self.alias {
399            write!(f, " AS {a}")?;
400        }
401        Ok(())
402    }
403}
404
405impl fmt::Display for SelectStatement {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        write!(f, "SELECT")?;
408        if self.distinct {
409            write!(f, " DISTINCT")?;
410        }
411        for (i, col) in self.columns.iter().enumerate() {
412            if i > 0 {
413                write!(f, ",")?;
414            }
415            write!(f, " {col}")?;
416        }
417        if let Some(from) = &self.from {
418            write!(f, " FROM {from}")?;
419        }
420        if let Some(w) = &self.where_clause {
421            write!(f, " WHERE {w}")?;
422        }
423        if !self.group_by.is_empty() {
424            write!(f, " GROUP BY")?;
425            for (i, g) in self.group_by.iter().enumerate() {
426                if i > 0 {
427                    write!(f, ",")?;
428                }
429                write!(f, " {g}")?;
430            }
431        }
432        if let Some(h) = &self.having {
433            write!(f, " HAVING {h}")?;
434        }
435        if !self.order_by.is_empty() {
436            write!(f, " ORDER BY")?;
437            for (i, o) in self.order_by.iter().enumerate() {
438                if i > 0 {
439                    write!(f, ",")?;
440                }
441                write!(f, " {o}")?;
442            }
443        }
444        if let Some(l) = &self.limit {
445            write!(f, " LIMIT {l}")?;
446        }
447        if let Some(o) = &self.offset {
448            write!(f, " OFFSET {o}")?;
449        }
450        Ok(())
451    }
452}
453
454impl fmt::Display for InsertSource {
455    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456        match self {
457            Self::Values(rows) => {
458                write!(f, "VALUES")?;
459                for (i, row) in rows.iter().enumerate() {
460                    if i > 0 {
461                        write!(f, ",")?;
462                    }
463                    write!(f, " (")?;
464                    for (j, val) in row.iter().enumerate() {
465                        if j > 0 {
466                            write!(f, ", ")?;
467                        }
468                        write!(f, "{val}")?;
469                    }
470                    write!(f, ")")?;
471                }
472                Ok(())
473            }
474            Self::Query(q) => write!(f, "{q}"),
475            Self::DefaultValues => write!(f, "DEFAULT VALUES"),
476        }
477    }
478}
479
480impl fmt::Display for OnConflict {
481    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
482        write!(f, "ON CONFLICT (")?;
483        for (i, col) in self.columns.iter().enumerate() {
484            if i > 0 {
485                write!(f, ", ")?;
486            }
487            write!(f, "{col}")?;
488        }
489        write!(f, ") {}", self.action)
490    }
491}
492
493impl fmt::Display for ConflictAction {
494    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
495        match self {
496            Self::DoNothing => write!(f, "DO NOTHING"),
497            Self::DoUpdate(assignments) => {
498                write!(f, "DO UPDATE SET")?;
499                for (i, a) in assignments.iter().enumerate() {
500                    if i > 0 {
501                        write!(f, ",")?;
502                    }
503                    write!(f, " {a}")?;
504                }
505                Ok(())
506            }
507        }
508    }
509}
510
511impl fmt::Display for InsertStatement {
512    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513        write!(f, "INSERT INTO ")?;
514        if let Some(s) = &self.schema {
515            write!(f, "{s}.")?;
516        }
517        write!(f, "{}", self.table)?;
518        if !self.columns.is_empty() {
519            write!(f, " (")?;
520            for (i, col) in self.columns.iter().enumerate() {
521                if i > 0 {
522                    write!(f, ", ")?;
523                }
524                write!(f, "{col}")?;
525            }
526            write!(f, ")")?;
527        }
528        write!(f, " {}", self.values)?;
529        if let Some(oc) = &self.on_conflict {
530            write!(f, " {oc}")?;
531        }
532        Ok(())
533    }
534}
535
536impl fmt::Display for UpdateAssignment {
537    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
538        write!(f, "{} = {}", self.column, self.value)
539    }
540}
541
542impl fmt::Display for UpdateStatement {
543    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
544        write!(f, "UPDATE ")?;
545        if let Some(s) = &self.schema {
546            write!(f, "{s}.")?;
547        }
548        write!(f, "{}", self.table)?;
549        if let Some(a) = &self.alias {
550            write!(f, " AS {a}")?;
551        }
552        write!(f, " SET")?;
553        for (i, a) in self.assignments.iter().enumerate() {
554            if i > 0 {
555                write!(f, ",")?;
556            }
557            write!(f, " {a}")?;
558        }
559        if let Some(from) = &self.from {
560            write!(f, " FROM {from}")?;
561        }
562        if let Some(w) = &self.where_clause {
563            write!(f, " WHERE {w}")?;
564        }
565        Ok(())
566    }
567}
568
569impl fmt::Display for DeleteStatement {
570    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
571        write!(f, "DELETE FROM ")?;
572        if let Some(s) = &self.schema {
573            write!(f, "{s}.")?;
574        }
575        write!(f, "{}", self.table)?;
576        if let Some(a) = &self.alias {
577            write!(f, " AS {a}")?;
578        }
579        if let Some(w) = &self.where_clause {
580            write!(f, " WHERE {w}")?;
581        }
582        Ok(())
583    }
584}
585
586impl fmt::Display for Statement {
587    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588        match self {
589            Self::Select(s) => write!(f, "{s}"),
590            Self::Insert(i) => write!(f, "{i}"),
591            Self::Update(u) => write!(f, "{u}"),
592            Self::Delete(d) => write!(f, "{d}"),
593        }
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Both directions map to their keyword, and `Asc` is the default.
    #[test]
    fn test_order_direction() {
        assert_eq!(OrderDirection::Asc.as_str(), "ASC");
        assert_eq!(OrderDirection::Desc.as_str(), "DESC");
        assert_eq!(OrderDirection::default(), OrderDirection::Asc);
    }

    /// Both null orderings map to their SQL fragment.
    #[test]
    fn test_null_ordering() {
        assert_eq!(NullOrdering::First.as_str(), "NULLS FIRST");
        assert_eq!(NullOrdering::Last.as_str(), "NULLS LAST");
    }

    /// Every join type maps to its keywords, not only the common two.
    #[test]
    fn test_join_type() {
        assert_eq!(JoinType::Inner.as_str(), "INNER JOIN");
        assert_eq!(JoinType::Left.as_str(), "LEFT JOIN");
        assert_eq!(JoinType::Right.as_str(), "RIGHT JOIN");
        assert_eq!(JoinType::Full.as_str(), "FULL JOIN");
        assert_eq!(JoinType::Cross.as_str(), "CROSS JOIN");
    }

    /// `Display` for the keyword enums agrees with `as_str`.
    #[test]
    fn test_keyword_display_matches_as_str() {
        assert_eq!(OrderDirection::Desc.to_string(), "DESC");
        assert_eq!(NullOrdering::First.to_string(), "NULLS FIRST");
        assert_eq!(JoinType::Cross.to_string(), "CROSS JOIN");
    }

    #[test]
    fn test_table_ref_builder() {
        let table = TableRef::table("users").alias("u");
        assert!(
            matches!(table, TableRef::Table { name, alias, .. } if name == "users" && alias == Some(String::from("u")))
        );
    }

    /// `with_schema` fills in the schema and leaves the alias unset.
    #[test]
    fn test_table_ref_with_schema() {
        let table = TableRef::with_schema("public", "users");
        assert_eq!(
            table,
            TableRef::Table {
                schema: Some(String::from("public")),
                name: String::from("users"),
                alias: None,
            }
        );
    }

    /// Plain table references render as `[schema.]name[ AS alias]`.
    #[test]
    fn test_table_ref_display() {
        assert_eq!(TableRef::table("users").to_string(), "users");
        assert_eq!(
            TableRef::with_schema("public", "users")
                .alias("u")
                .to_string(),
            "public.users AS u"
        );
    }
}