oxide_sql_core/parser/core.rs

1//! SQL Parser implementation.
2
3use super::error::ParseError;
4use super::pratt::{
5    infix_binding_power, prefix_binding_power, token_to_binary_op, token_to_unary_op,
6};
7use crate::ast::{
8    DataType, DeleteStatement, Expr, FunctionCall, InsertSource, InsertStatement, JoinClause,
9    JoinType, Literal, OrderBy, OrderDirection, SelectColumn, SelectStatement, Statement, TableRef,
10    UpdateAssignment, UpdateStatement,
11};
12use crate::lexer::{Keyword, Lexer, Span, Token, TokenKind};
13
14/// SQL Parser.
15pub struct Parser<'a> {
16    lexer: Lexer<'a>,
17    current: Token,
18    previous: Token,
19    /// Parameter counter for ? placeholders.
20    param_counter: usize,
21}
22
23impl<'a> Parser<'a> {
24    /// Creates a new parser for the given input.
25    #[must_use]
26    pub fn new(input: &'a str) -> Self {
27        let mut lexer = Lexer::new(input);
28        let current = lexer.next_token();
29        Self {
30            lexer,
31            current,
32            previous: Token::new(TokenKind::Eof, Span::new(0, 0)),
33            param_counter: 0,
34        }
35    }
36
37    /// Parses a single SQL statement.
38    ///
39    /// # Errors
40    ///
41    /// Returns a `ParseError` if the input is not a valid SQL statement.
42    pub fn parse_statement(&mut self) -> Result<Statement, ParseError> {
43        match &self.current.kind {
44            TokenKind::Keyword(Keyword::Select) => {
45                Ok(Statement::Select(self.parse_select_statement()?))
46            }
47            TokenKind::Keyword(Keyword::Insert) => {
48                Ok(Statement::Insert(self.parse_insert_statement()?))
49            }
50            TokenKind::Keyword(Keyword::Update) => {
51                Ok(Statement::Update(self.parse_update_statement()?))
52            }
53            TokenKind::Keyword(Keyword::Delete) => {
54                Ok(Statement::Delete(self.parse_delete_statement()?))
55            }
56            _ => Err(ParseError::unexpected(
57                "SELECT, INSERT, UPDATE, or DELETE",
58                self.current.kind.clone(),
59                self.current.span,
60            )),
61        }
62    }
63
64    /// Parses a SELECT statement.
65    fn parse_select_statement(&mut self) -> Result<SelectStatement, ParseError> {
66        self.expect_keyword(Keyword::Select)?;
67
68        // DISTINCT or ALL
69        let distinct = if self.check_keyword(Keyword::Distinct) {
70            self.advance();
71            true
72        } else if self.check_keyword(Keyword::All) {
73            self.advance();
74            false
75        } else {
76            false
77        };
78
79        // SELECT columns
80        let columns = self.parse_select_columns()?;
81
82        // FROM clause (optional for expressions like SELECT 1+1)
83        let from = if self.check_keyword(Keyword::From) {
84            self.advance();
85            Some(self.parse_table_ref()?)
86        } else {
87            None
88        };
89
90        // WHERE clause
91        let where_clause = if self.check_keyword(Keyword::Where) {
92            self.advance();
93            Some(self.parse_expression(0)?)
94        } else {
95            None
96        };
97
98        // GROUP BY clause
99        let group_by = if self.check_keyword(Keyword::Group) {
100            self.advance();
101            self.expect_keyword(Keyword::By)?;
102            self.parse_expression_list()?
103        } else {
104            vec![]
105        };
106
107        // HAVING clause
108        let having = if self.check_keyword(Keyword::Having) {
109            self.advance();
110            Some(self.parse_expression(0)?)
111        } else {
112            None
113        };
114
115        // ORDER BY clause
116        let order_by = if self.check_keyword(Keyword::Order) {
117            self.advance();
118            self.expect_keyword(Keyword::By)?;
119            self.parse_order_by_list()?
120        } else {
121            vec![]
122        };
123
124        // LIMIT clause
125        let limit = if self.check_keyword(Keyword::Limit) {
126            self.advance();
127            Some(self.parse_expression(0)?)
128        } else {
129            None
130        };
131
132        // OFFSET clause
133        let offset = if self.check_keyword(Keyword::Offset) {
134            self.advance();
135            Some(self.parse_expression(0)?)
136        } else {
137            None
138        };
139
140        Ok(SelectStatement {
141            distinct,
142            columns,
143            from,
144            where_clause,
145            group_by,
146            having,
147            order_by,
148            limit,
149            offset,
150        })
151    }
152
153    /// Parses SELECT columns.
154    fn parse_select_columns(&mut self) -> Result<Vec<SelectColumn>, ParseError> {
155        let mut columns = vec![];
156
157        loop {
158            let expr = self.parse_expression(0)?;
159
160            // Check for alias (AS name or just name)
161            let alias = if self.check_keyword(Keyword::As) {
162                self.advance();
163                Some(self.expect_identifier()?)
164            } else if matches!(&self.current.kind, TokenKind::Identifier(_)) {
165                Some(self.expect_identifier()?)
166            } else {
167                None
168            };
169
170            columns.push(SelectColumn { expr, alias });
171
172            if !self.check(&TokenKind::Comma) {
173                break;
174            }
175            self.advance();
176        }
177
178        Ok(columns)
179    }
180
181    /// Parses a table reference.
182    fn parse_table_ref(&mut self) -> Result<TableRef, ParseError> {
183        let mut table_ref = if self.check(&TokenKind::LeftParen) {
184            // Subquery or grouped table ref
185            self.advance();
186            if self.check_keyword(Keyword::Select) {
187                let query = self.parse_select_statement()?;
188                self.expect(&TokenKind::RightParen)?;
189                let alias = self.parse_optional_alias()?;
190                TableRef::Subquery {
191                    query: Box::new(query),
192                    alias: alias.unwrap_or_else(|| String::from("subquery")),
193                }
194            } else {
195                let inner = self.parse_table_ref()?;
196                self.expect(&TokenKind::RightParen)?;
197                inner
198            }
199        } else {
200            // Simple table name
201            let first = self.expect_identifier()?;
202            let (schema, name) = if self.check(&TokenKind::Dot) {
203                self.advance();
204                let table_name = self.expect_identifier()?;
205                (Some(first), table_name)
206            } else {
207                (None, first)
208            };
209
210            let alias = self.parse_optional_alias()?;
211
212            TableRef::Table {
213                schema,
214                name,
215                alias,
216            }
217        };
218
219        // Parse joins
220        while self.is_join_keyword() {
221            let join_type = self.parse_join_type()?;
222            let right = self.parse_simple_table_ref()?;
223
224            let (on, using) = if join_type == JoinType::Cross {
225                (None, vec![])
226            } else if self.check_keyword(Keyword::On) {
227                self.advance();
228                (Some(self.parse_expression(0)?), vec![])
229            } else if self.check_keyword(Keyword::Using) {
230                self.advance();
231                self.expect(&TokenKind::LeftParen)?;
232                let cols = self.parse_identifier_list()?;
233                self.expect(&TokenKind::RightParen)?;
234                (None, cols)
235            } else {
236                return Err(ParseError::new(
237                    "Expected ON or USING clause",
238                    self.current.span,
239                ));
240            };
241
242            table_ref = TableRef::Join {
243                left: Box::new(table_ref),
244                join: Box::new(JoinClause {
245                    join_type,
246                    table: right,
247                    on,
248                    using,
249                }),
250            };
251        }
252
253        Ok(table_ref)
254    }
255
256    /// Parses a simple table reference (no joins).
257    fn parse_simple_table_ref(&mut self) -> Result<TableRef, ParseError> {
258        let first = self.expect_identifier()?;
259        let (schema, name) = if self.check(&TokenKind::Dot) {
260            self.advance();
261            let table_name = self.expect_identifier()?;
262            (Some(first), table_name)
263        } else {
264            (None, first)
265        };
266
267        let alias = self.parse_optional_alias()?;
268
269        Ok(TableRef::Table {
270            schema,
271            name,
272            alias,
273        })
274    }
275
276    /// Checks if current token is a join keyword.
277    fn is_join_keyword(&self) -> bool {
278        matches!(
279            &self.current.kind,
280            TokenKind::Keyword(
281                Keyword::Join
282                    | Keyword::Inner
283                    | Keyword::Left
284                    | Keyword::Right
285                    | Keyword::Full
286                    | Keyword::Cross
287            )
288        )
289    }
290
291    /// Parses a join type.
292    fn parse_join_type(&mut self) -> Result<JoinType, ParseError> {
293        let join_type = match &self.current.kind {
294            TokenKind::Keyword(Keyword::Join) => {
295                self.advance();
296                JoinType::Inner
297            }
298            TokenKind::Keyword(Keyword::Inner) => {
299                self.advance();
300                self.expect_keyword(Keyword::Join)?;
301                JoinType::Inner
302            }
303            TokenKind::Keyword(Keyword::Left) => {
304                self.advance();
305                if self.check_keyword(Keyword::Outer) {
306                    self.advance();
307                }
308                self.expect_keyword(Keyword::Join)?;
309                JoinType::Left
310            }
311            TokenKind::Keyword(Keyword::Right) => {
312                self.advance();
313                if self.check_keyword(Keyword::Outer) {
314                    self.advance();
315                }
316                self.expect_keyword(Keyword::Join)?;
317                JoinType::Right
318            }
319            TokenKind::Keyword(Keyword::Full) => {
320                self.advance();
321                if self.check_keyword(Keyword::Outer) {
322                    self.advance();
323                }
324                self.expect_keyword(Keyword::Join)?;
325                JoinType::Full
326            }
327            TokenKind::Keyword(Keyword::Cross) => {
328                self.advance();
329                self.expect_keyword(Keyword::Join)?;
330                JoinType::Cross
331            }
332            _ => {
333                return Err(ParseError::unexpected(
334                    "JOIN keyword",
335                    self.current.kind.clone(),
336                    self.current.span,
337                ));
338            }
339        };
340        Ok(join_type)
341    }
342
343    /// Parses an optional table alias.
344    fn parse_optional_alias(&mut self) -> Result<Option<String>, ParseError> {
345        if self.check_keyword(Keyword::As) {
346            self.advance();
347            Ok(Some(self.expect_identifier()?))
348        } else if matches!(&self.current.kind, TokenKind::Identifier(_)) && !self.is_reserved_word()
349        {
350            Ok(Some(self.expect_identifier()?))
351        } else {
352            Ok(None)
353        }
354    }
355
356    /// Checks if current identifier is a reserved word.
357    fn is_reserved_word(&self) -> bool {
358        matches!(
359            &self.current.kind,
360            TokenKind::Keyword(
361                Keyword::Where
362                    | Keyword::Order
363                    | Keyword::Group
364                    | Keyword::Having
365                    | Keyword::Limit
366                    | Keyword::Offset
367                    | Keyword::Join
368                    | Keyword::Inner
369                    | Keyword::Left
370                    | Keyword::Right
371                    | Keyword::Full
372                    | Keyword::Cross
373                    | Keyword::On
374                    | Keyword::Using
375                    | Keyword::Union
376                    | Keyword::Intersect
377                    | Keyword::Except
378            )
379        )
380    }
381
382    /// Parses an INSERT statement.
383    fn parse_insert_statement(&mut self) -> Result<InsertStatement, ParseError> {
384        self.expect_keyword(Keyword::Insert)?;
385        self.expect_keyword(Keyword::Into)?;
386
387        let first = self.expect_identifier()?;
388        let (schema, table) = if self.check(&TokenKind::Dot) {
389            self.advance();
390            let table_name = self.expect_identifier()?;
391            (Some(first), table_name)
392        } else {
393            (None, first)
394        };
395
396        // Column list (optional)
397        let columns = if self.check(&TokenKind::LeftParen) {
398            self.advance();
399            let cols = self.parse_identifier_list()?;
400            self.expect(&TokenKind::RightParen)?;
401            cols
402        } else {
403            vec![]
404        };
405
406        // VALUES, SELECT, or DEFAULT VALUES
407        let values = if self.check_keyword(Keyword::Values) {
408            self.advance();
409            let mut rows = vec![];
410            loop {
411                self.expect(&TokenKind::LeftParen)?;
412                let row = self.parse_expression_list()?;
413                self.expect(&TokenKind::RightParen)?;
414                rows.push(row);
415                if !self.check(&TokenKind::Comma) {
416                    break;
417                }
418                self.advance();
419            }
420            InsertSource::Values(rows)
421        } else if self.check_keyword(Keyword::Select) {
422            InsertSource::Query(Box::new(self.parse_select_statement()?))
423        } else if self.check_keyword(Keyword::Default) {
424            self.advance();
425            self.expect_keyword(Keyword::Values)?;
426            InsertSource::DefaultValues
427        } else {
428            return Err(ParseError::unexpected(
429                "VALUES, SELECT, or DEFAULT VALUES",
430                self.current.kind.clone(),
431                self.current.span,
432            ));
433        };
434
435        Ok(InsertStatement {
436            schema,
437            table,
438            columns,
439            values,
440            on_conflict: None,
441        })
442    }
443
444    /// Parses an UPDATE statement.
445    fn parse_update_statement(&mut self) -> Result<UpdateStatement, ParseError> {
446        self.expect_keyword(Keyword::Update)?;
447
448        let first = self.expect_identifier()?;
449        let (schema, table) = if self.check(&TokenKind::Dot) {
450            self.advance();
451            let table_name = self.expect_identifier()?;
452            (Some(first), table_name)
453        } else {
454            (None, first)
455        };
456
457        let alias = self.parse_optional_alias()?;
458
459        self.expect_keyword(Keyword::Set)?;
460
461        // Parse SET assignments
462        let mut assignments = vec![];
463        loop {
464            let column = self.expect_identifier()?;
465            self.expect(&TokenKind::Eq)?;
466            let value = self.parse_expression(0)?;
467            assignments.push(UpdateAssignment { column, value });
468
469            if !self.check(&TokenKind::Comma) {
470                break;
471            }
472            self.advance();
473        }
474
475        // FROM clause (optional, for joins)
476        let from = if self.check_keyword(Keyword::From) {
477            self.advance();
478            Some(self.parse_table_ref()?)
479        } else {
480            None
481        };
482
483        // WHERE clause
484        let where_clause = if self.check_keyword(Keyword::Where) {
485            self.advance();
486            Some(self.parse_expression(0)?)
487        } else {
488            None
489        };
490
491        Ok(UpdateStatement {
492            schema,
493            table,
494            alias,
495            assignments,
496            from,
497            where_clause,
498        })
499    }
500
501    /// Parses a DELETE statement.
502    fn parse_delete_statement(&mut self) -> Result<DeleteStatement, ParseError> {
503        self.expect_keyword(Keyword::Delete)?;
504        self.expect_keyword(Keyword::From)?;
505
506        let first = self.expect_identifier()?;
507        let (schema, table) = if self.check(&TokenKind::Dot) {
508            self.advance();
509            let table_name = self.expect_identifier()?;
510            (Some(first), table_name)
511        } else {
512            (None, first)
513        };
514
515        let alias = self.parse_optional_alias()?;
516
517        // WHERE clause
518        let where_clause = if self.check_keyword(Keyword::Where) {
519            self.advance();
520            Some(self.parse_expression(0)?)
521        } else {
522            None
523        };
524
525        Ok(DeleteStatement {
526            schema,
527            table,
528            alias,
529            where_clause,
530        })
531    }
532
533    /// Parses an ORDER BY list.
534    fn parse_order_by_list(&mut self) -> Result<Vec<OrderBy>, ParseError> {
535        let mut items = vec![];
536        loop {
537            let expr = self.parse_expression(0)?;
538            let direction = if self.check_keyword(Keyword::Desc) {
539                self.advance();
540                OrderDirection::Desc
541            } else if self.check_keyword(Keyword::Asc) {
542                self.advance();
543                OrderDirection::Asc
544            } else {
545                OrderDirection::Asc
546            };
547
548            items.push(OrderBy {
549                expr,
550                direction,
551                nulls: None,
552            });
553
554            if !self.check(&TokenKind::Comma) {
555                break;
556            }
557            self.advance();
558        }
559        Ok(items)
560    }
561
562    /// Parses an expression using Pratt parsing.
563    #[allow(clippy::while_let_loop)]
564    fn parse_expression(&mut self, min_bp: u8) -> Result<Expr, ParseError> {
565        // Parse prefix (primary expression or unary operator)
566        let mut lhs = self.parse_prefix()?;
567
568        // Parse infix operators
569        loop {
570            // Check if current token is an infix operator
571            let (l_bp, r_bp) = match infix_binding_power(&self.current.kind) {
572                Some(bp) => bp,
573                None => break,
574            };
575
576            if l_bp < min_bp {
577                break;
578            }
579
580            // Handle special infix operators
581            match &self.current.kind {
582                TokenKind::Keyword(Keyword::Is) => {
583                    self.advance();
584                    let negated = if self.check_keyword(Keyword::Not) {
585                        self.advance();
586                        true
587                    } else {
588                        false
589                    };
590                    self.expect_keyword(Keyword::Null)?;
591                    lhs = Expr::IsNull {
592                        expr: Box::new(lhs),
593                        negated,
594                    };
595                }
596                TokenKind::Keyword(Keyword::In) => {
597                    self.advance();
598                    self.expect(&TokenKind::LeftParen)?;
599                    let list = self.parse_expression_list()?;
600                    self.expect(&TokenKind::RightParen)?;
601                    lhs = Expr::In {
602                        expr: Box::new(lhs),
603                        list,
604                        negated: false,
605                    };
606                }
607                TokenKind::Keyword(Keyword::Between) => {
608                    self.advance();
609                    let low = self.parse_expression(r_bp)?;
610                    self.expect_keyword(Keyword::And)?;
611                    let high = self.parse_expression(r_bp)?;
612                    lhs = Expr::Between {
613                        expr: Box::new(lhs),
614                        low: Box::new(low),
615                        high: Box::new(high),
616                        negated: false,
617                    };
618                }
619                _ => {
620                    // Standard binary operator
621                    if let Some(op) = token_to_binary_op(&self.current.kind) {
622                        self.advance();
623                        let rhs = self.parse_expression(r_bp)?;
624                        lhs = Expr::Binary {
625                            left: Box::new(lhs),
626                            op,
627                            right: Box::new(rhs),
628                        };
629                    } else {
630                        break;
631                    }
632                }
633            }
634        }
635
636        Ok(lhs)
637    }
638
639    /// Parses a prefix expression.
640    fn parse_prefix(&mut self) -> Result<Expr, ParseError> {
641        // Check for unary operators
642        if let Some(op) = token_to_unary_op(&self.current.kind) {
643            let bp = prefix_binding_power(&self.current.kind).unwrap_or(15);
644            self.advance();
645            let operand = self.parse_expression(bp)?;
646            return Ok(Expr::Unary {
647                op,
648                operand: Box::new(operand),
649            });
650        }
651
652        self.parse_primary()
653    }
654
655    /// Parses a primary expression.
656    fn parse_primary(&mut self) -> Result<Expr, ParseError> {
657        let token = self.current.clone();
658
659        match &token.kind {
660            // Literals
661            TokenKind::Integer(n) => {
662                self.advance();
663                Ok(Expr::Literal(Literal::Integer(*n)))
664            }
665            TokenKind::Float(f) => {
666                self.advance();
667                Ok(Expr::Literal(Literal::Float(*f)))
668            }
669            TokenKind::String(s) => {
670                let value = s.clone();
671                self.advance();
672                Ok(Expr::Literal(Literal::String(value)))
673            }
674            TokenKind::Blob(b) => {
675                let value = b.clone();
676                self.advance();
677                Ok(Expr::Literal(Literal::Blob(value)))
678            }
679            TokenKind::Keyword(Keyword::True) => {
680                self.advance();
681                Ok(Expr::Literal(Literal::Boolean(true)))
682            }
683            TokenKind::Keyword(Keyword::False) => {
684                self.advance();
685                Ok(Expr::Literal(Literal::Boolean(false)))
686            }
687            TokenKind::Keyword(Keyword::Null) => {
688                self.advance();
689                Ok(Expr::Literal(Literal::Null))
690            }
691
692            // Parameter placeholders
693            TokenKind::Question => {
694                self.param_counter += 1;
695                let position = self.param_counter;
696                self.advance();
697                Ok(Expr::Parameter {
698                    name: None,
699                    position,
700                })
701            }
702            TokenKind::Colon => {
703                self.advance();
704                let name = self.expect_identifier()?;
705                Ok(Expr::Parameter {
706                    name: Some(name),
707                    position: 0,
708                })
709            }
710
711            // Wildcard
712            TokenKind::Star => {
713                self.advance();
714                Ok(Expr::Wildcard { table: None })
715            }
716
717            // Parenthesized expression or subquery
718            TokenKind::LeftParen => {
719                self.advance();
720                if self.check_keyword(Keyword::Select) {
721                    let subquery = self.parse_select_statement()?;
722                    self.expect(&TokenKind::RightParen)?;
723                    Ok(Expr::Subquery(Box::new(subquery)))
724                } else {
725                    let expr = self.parse_expression(0)?;
726                    self.expect(&TokenKind::RightParen)?;
727                    Ok(Expr::Paren(Box::new(expr)))
728                }
729            }
730
731            // Aggregate functions
732            TokenKind::Keyword(
733                kw @ (Keyword::Count | Keyword::Sum | Keyword::Avg | Keyword::Min | Keyword::Max),
734            ) => {
735                let name = kw.as_str().to_string();
736                self.advance();
737                self.parse_function_call(name)
738            }
739
740            // Other functions
741            TokenKind::Keyword(kw @ (Keyword::Coalesce | Keyword::Nullif | Keyword::Cast)) => {
742                let name = kw.as_str().to_string();
743                self.advance();
744                if matches!(kw, Keyword::Cast) {
745                    self.parse_cast_expression()
746                } else {
747                    self.parse_function_call(name)
748                }
749            }
750
751            // CASE expression
752            TokenKind::Keyword(Keyword::Case) => self.parse_case_expression(),
753
754            // EXISTS
755            TokenKind::Keyword(Keyword::Exists) => {
756                self.advance();
757                self.expect(&TokenKind::LeftParen)?;
758                let subquery = self.parse_select_statement()?;
759                self.expect(&TokenKind::RightParen)?;
760                Ok(Expr::Function(FunctionCall {
761                    name: String::from("EXISTS"),
762                    args: vec![Expr::Subquery(Box::new(subquery))],
763                    distinct: false,
764                }))
765            }
766
767            // Identifier (column reference or function call)
768            TokenKind::Identifier(name) => {
769                let name = name.clone();
770                let span = token.span;
771                self.advance();
772
773                // Check for function call
774                if self.check(&TokenKind::LeftParen) {
775                    return self.parse_function_call(name);
776                }
777
778                // Check for qualified name (table.column or table.*)
779                if self.check(&TokenKind::Dot) {
780                    self.advance();
781                    if self.check(&TokenKind::Star) {
782                        self.advance();
783                        return Ok(Expr::Wildcard { table: Some(name) });
784                    }
785                    let column = self.expect_identifier()?;
786                    return Ok(Expr::Column {
787                        table: Some(name),
788                        name: column,
789                        span,
790                    });
791                }
792
793                Ok(Expr::Column {
794                    table: None,
795                    name,
796                    span,
797                })
798            }
799
800            _ => Err(ParseError::unexpected(
801                "expression",
802                self.current.kind.clone(),
803                self.current.span,
804            )),
805        }
806    }
807
808    /// Parses a function call.
809    fn parse_function_call(&mut self, name: String) -> Result<Expr, ParseError> {
810        self.expect(&TokenKind::LeftParen)?;
811
812        let distinct = if self.check_keyword(Keyword::Distinct) {
813            self.advance();
814            true
815        } else {
816            false
817        };
818
819        let args = if self.check(&TokenKind::RightParen) {
820            vec![]
821        } else if self.check(&TokenKind::Star) {
822            self.advance();
823            vec![Expr::Wildcard { table: None }]
824        } else {
825            self.parse_expression_list()?
826        };
827
828        self.expect(&TokenKind::RightParen)?;
829
830        Ok(Expr::Function(FunctionCall {
831            name,
832            args,
833            distinct,
834        }))
835    }
836
837    /// Parses a CAST expression.
838    fn parse_cast_expression(&mut self) -> Result<Expr, ParseError> {
839        self.expect(&TokenKind::LeftParen)?;
840        let expr = self.parse_expression(0)?;
841        self.expect_keyword(Keyword::As)?;
842        let data_type = self.parse_data_type()?;
843        self.expect(&TokenKind::RightParen)?;
844
845        Ok(Expr::Cast {
846            expr: Box::new(expr),
847            data_type,
848        })
849    }
850
851    /// Parses a CASE expression.
852    fn parse_case_expression(&mut self) -> Result<Expr, ParseError> {
853        self.expect_keyword(Keyword::Case)?;
854
855        // Check for simple CASE (CASE expr WHEN ...)
856        let operand = if !self.check_keyword(Keyword::When) {
857            Some(Box::new(self.parse_expression(0)?))
858        } else {
859            None
860        };
861
862        // Parse WHEN/THEN clauses
863        let mut when_clauses = vec![];
864        while self.check_keyword(Keyword::When) {
865            self.advance();
866            let when_expr = self.parse_expression(0)?;
867            self.expect_keyword(Keyword::Then)?;
868            let then_expr = self.parse_expression(0)?;
869            when_clauses.push((when_expr, then_expr));
870        }
871
872        // Parse ELSE clause
873        let else_clause = if self.check_keyword(Keyword::Else) {
874            self.advance();
875            Some(Box::new(self.parse_expression(0)?))
876        } else {
877            None
878        };
879
880        self.expect_keyword(Keyword::End)?;
881
882        Ok(Expr::Case {
883            operand,
884            when_clauses,
885            else_clause,
886        })
887    }
888
889    /// Parses a data type.
890    fn parse_data_type(&mut self) -> Result<DataType, ParseError> {
891        let data_type = match &self.current.kind {
892            TokenKind::Keyword(Keyword::Int | Keyword::Integer) => {
893                self.advance();
894                DataType::Integer
895            }
896            TokenKind::Keyword(Keyword::Smallint) => {
897                self.advance();
898                DataType::Smallint
899            }
900            TokenKind::Keyword(Keyword::Bigint) => {
901                self.advance();
902                DataType::Bigint
903            }
904            TokenKind::Keyword(Keyword::Real) => {
905                self.advance();
906                DataType::Real
907            }
908            TokenKind::Keyword(Keyword::Double) => {
909                self.advance();
910                DataType::Double
911            }
912            TokenKind::Keyword(Keyword::Float) => {
913                self.advance();
914                DataType::Double
915            }
916            TokenKind::Keyword(Keyword::Decimal) => {
917                self.advance();
918                let (precision, scale) = self.parse_optional_precision_scale()?;
919                DataType::Decimal { precision, scale }
920            }
921            TokenKind::Keyword(Keyword::Numeric) => {
922                self.advance();
923                let (precision, scale) = self.parse_optional_precision_scale()?;
924                DataType::Numeric { precision, scale }
925            }
926            TokenKind::Keyword(Keyword::Char) => {
927                self.advance();
928                let len = self.parse_optional_length()?;
929                DataType::Char(len)
930            }
931            TokenKind::Keyword(Keyword::Varchar) => {
932                self.advance();
933                let len = self.parse_optional_length()?;
934                DataType::Varchar(len)
935            }
936            TokenKind::Keyword(Keyword::Text) => {
937                self.advance();
938                DataType::Text
939            }
940            TokenKind::Keyword(Keyword::Blob) => {
941                self.advance();
942                DataType::Blob
943            }
944            TokenKind::Keyword(Keyword::Boolean) => {
945                self.advance();
946                DataType::Boolean
947            }
948            TokenKind::Keyword(Keyword::Date) => {
949                self.advance();
950                DataType::Date
951            }
952            TokenKind::Keyword(Keyword::Time) => {
953                self.advance();
954                DataType::Time
955            }
956            TokenKind::Keyword(Keyword::Timestamp) => {
957                self.advance();
958                DataType::Timestamp
959            }
960            TokenKind::Keyword(Keyword::Datetime) => {
961                self.advance();
962                DataType::Datetime
963            }
964            TokenKind::Identifier(name) => {
965                let name = name.clone();
966                self.advance();
967                DataType::Custom(name)
968            }
969            _ => {
970                return Err(ParseError::unexpected(
971                    "data type",
972                    self.current.kind.clone(),
973                    self.current.span,
974                ));
975            }
976        };
977
978        Ok(data_type)
979    }
980
981    /// Parses optional precision and scale (for DECIMAL/NUMERIC).
982    fn parse_optional_precision_scale(&mut self) -> Result<(Option<u16>, Option<u16>), ParseError> {
983        if !self.check(&TokenKind::LeftParen) {
984            return Ok((None, None));
985        }
986        self.advance();
987
988        let precision = match &self.current.kind {
989            TokenKind::Integer(n) => {
990                let p = u16::try_from(*n)
991                    .map_err(|_| ParseError::new("Precision too large", self.current.span))?;
992                self.advance();
993                Some(p)
994            }
995            _ => {
996                return Err(ParseError::unexpected(
997                    "integer",
998                    self.current.kind.clone(),
999                    self.current.span,
1000                ));
1001            }
1002        };
1003
1004        let scale = if self.check(&TokenKind::Comma) {
1005            self.advance();
1006            match &self.current.kind {
1007                TokenKind::Integer(n) => {
1008                    let s = u16::try_from(*n)
1009                        .map_err(|_| ParseError::new("Scale too large", self.current.span))?;
1010                    self.advance();
1011                    Some(s)
1012                }
1013                _ => {
1014                    return Err(ParseError::unexpected(
1015                        "integer",
1016                        self.current.kind.clone(),
1017                        self.current.span,
1018                    ));
1019                }
1020            }
1021        } else {
1022            None
1023        };
1024
1025        self.expect(&TokenKind::RightParen)?;
1026        Ok((precision, scale))
1027    }
1028
1029    /// Parses optional length (for CHAR/VARCHAR).
1030    fn parse_optional_length(&mut self) -> Result<Option<u32>, ParseError> {
1031        if !self.check(&TokenKind::LeftParen) {
1032            return Ok(None);
1033        }
1034        self.advance();
1035
1036        let length = match &self.current.kind {
1037            TokenKind::Integer(n) => {
1038                let len = u32::try_from(*n)
1039                    .map_err(|_| ParseError::new("Length too large", self.current.span))?;
1040                self.advance();
1041                len
1042            }
1043            _ => {
1044                return Err(ParseError::unexpected(
1045                    "integer",
1046                    self.current.kind.clone(),
1047                    self.current.span,
1048                ));
1049            }
1050        };
1051
1052        self.expect(&TokenKind::RightParen)?;
1053        Ok(Some(length))
1054    }
1055
1056    /// Parses a comma-separated list of expressions.
1057    fn parse_expression_list(&mut self) -> Result<Vec<Expr>, ParseError> {
1058        let mut exprs = vec![];
1059        loop {
1060            exprs.push(self.parse_expression(0)?);
1061            if !self.check(&TokenKind::Comma) {
1062                break;
1063            }
1064            self.advance();
1065        }
1066        Ok(exprs)
1067    }
1068
1069    /// Parses a comma-separated list of identifiers.
1070    fn parse_identifier_list(&mut self) -> Result<Vec<String>, ParseError> {
1071        let mut idents = vec![];
1072        loop {
1073            idents.push(self.expect_identifier()?);
1074            if !self.check(&TokenKind::Comma) {
1075                break;
1076            }
1077            self.advance();
1078        }
1079        Ok(idents)
1080    }
1081
1082    // --- Helper methods ---
1083
1084    /// Advances to the next token.
1085    fn advance(&mut self) {
1086        self.previous = core::mem::replace(&mut self.current, self.lexer.next_token());
1087    }
1088
1089    /// Checks if the current token matches the given kind.
1090    fn check(&self, kind: &TokenKind) -> bool {
1091        core::mem::discriminant(&self.current.kind) == core::mem::discriminant(kind)
1092    }
1093
1094    /// Checks if the current token is the given keyword.
1095    fn check_keyword(&self, keyword: Keyword) -> bool {
1096        matches!(&self.current.kind, TokenKind::Keyword(kw) if *kw == keyword)
1097    }
1098
1099    /// Expects the current token to be the given kind.
1100    fn expect(&mut self, kind: &TokenKind) -> Result<(), ParseError> {
1101        if self.check(kind) {
1102            self.advance();
1103            Ok(())
1104        } else {
1105            Err(ParseError::unexpected(
1106                format!("{kind:?}"),
1107                self.current.kind.clone(),
1108                self.current.span,
1109            ))
1110        }
1111    }
1112
1113    /// Expects the current token to be the given keyword.
1114    fn expect_keyword(&mut self, keyword: Keyword) -> Result<(), ParseError> {
1115        if self.check_keyword(keyword) {
1116            self.advance();
1117            Ok(())
1118        } else {
1119            Err(ParseError::unexpected(
1120                keyword.as_str(),
1121                self.current.kind.clone(),
1122                self.current.span,
1123            ))
1124        }
1125    }
1126
1127    /// Expects and returns an identifier.
1128    fn expect_identifier(&mut self) -> Result<String, ParseError> {
1129        match &self.current.kind {
1130            TokenKind::Identifier(name) => {
1131                let name = name.clone();
1132                self.advance();
1133                Ok(name)
1134            }
1135            _ => Err(ParseError::unexpected(
1136                "identifier",
1137                self.current.kind.clone(),
1138                self.current.span,
1139            )),
1140        }
1141    }
1142}
1143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::BinaryOp;

    /// Parses `sql` as a single statement with a fresh parser.
    fn parse(sql: &str) -> Result<Statement, ParseError> {
        Parser::new(sql).parse_statement()
    }

    /// Parses `sql` and returns its SELECT statement.
    ///
    /// Panics if parsing fails or the statement is not a SELECT, so tests can
    /// never silently skip their assertions on an unexpected parse result.
    fn parse_select(sql: &str) -> SelectStatement {
        let Statement::Select(select) = parse(sql).unwrap() else {
            panic!("Expected SELECT statement");
        };
        select
    }

    #[test]
    fn test_simple_select() {
        let select = parse_select("SELECT id, name FROM users");
        assert_eq!(select.columns.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let select = parse_select("SELECT * FROM users WHERE id = 1");
        assert!(select.where_clause.is_some());
    }

    #[test]
    fn test_select_with_join() {
        let select =
            parse_select("SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id");
        assert_eq!(select.columns.len(), 2);
    }

    #[test]
    fn test_expression_precedence() {
        // 1 + 2 * 3 should be parsed as 1 + (2 * 3)
        let select = parse_select("SELECT 1 + 2 * 3");
        let Expr::Binary { op, right, .. } = &select.columns[0].expr else {
            panic!("Expected binary expression");
        };
        assert_eq!(*op, BinaryOp::Add);
        assert!(matches!(
            right.as_ref(),
            Expr::Binary {
                op: BinaryOp::Mul,
                ..
            }
        ));
    }

    #[test]
    fn test_insert_values() {
        let stmt =
            parse("INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')").unwrap();
        let Statement::Insert(insert) = stmt else {
            panic!("Expected INSERT statement");
        };
        assert_eq!(insert.table, "users");
        assert_eq!(insert.columns.len(), 2);
        assert!(matches!(insert.values, InsertSource::Values(_)));
    }

    #[test]
    fn test_update() {
        let stmt = parse("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        let Statement::Update(update) = stmt else {
            panic!("Expected UPDATE statement");
        };
        assert_eq!(update.table, "users");
        assert_eq!(update.assignments.len(), 1);
        assert!(update.where_clause.is_some());
    }

    #[test]
    fn test_delete() {
        let stmt = parse("DELETE FROM users WHERE id = 1").unwrap();
        let Statement::Delete(delete) = stmt else {
            panic!("Expected DELETE statement");
        };
        assert_eq!(delete.table, "users");
        assert!(delete.where_clause.is_some());
    }

    #[test]
    fn test_parameter_placeholders() {
        let select = parse_select("SELECT * FROM users WHERE id = ? AND name = :name");
        let Some(Expr::Binary { left, right, .. }) = &select.where_clause else {
            panic!("Expected Binary expression in WHERE clause");
        };
        // First condition: id = ?
        let Expr::Binary { right: param1, .. } = left.as_ref() else {
            panic!("Expected `id = ?` comparison on the left of AND");
        };
        assert!(matches!(
            param1.as_ref(),
            Expr::Parameter {
                name: None,
                position: 1
            }
        ));
        // Second condition: name = :name
        let Expr::Binary { right: param2, .. } = right.as_ref() else {
            panic!("Expected `name = :name` comparison on the right of AND");
        };
        assert!(matches!(
            param2.as_ref(),
            Expr::Parameter { name: Some(n), .. } if n == "name"
        ));
    }

    #[test]
    fn test_case_expression() {
        let select =
            parse_select("SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM users");
        let Expr::Case {
            when_clauses,
            else_clause,
            ..
        } = &select.columns[0].expr
        else {
            panic!("Expected CASE expression");
        };
        assert_eq!(when_clauses.len(), 1);
        assert!(else_clause.is_some());
    }

    #[test]
    fn test_aggregate_functions() {
        let select = parse_select("SELECT COUNT(*), SUM(amount), AVG(price) FROM orders");
        assert_eq!(select.columns.len(), 3);
        assert!(matches!(select.columns[0].expr, Expr::Function(_)));
    }
}